CNTK:model.clone() 对于深度不为 0 的节点
CNTK: model.clone() for nodes not in depth 0
我正在尝试使用 model.clone() 替换图中的一些节点。我通过调用以下函数获得了我想要替换的这些节点:
times_nodes = find_all_with_name(型号, 姓名, -1)
有问题的节点在图中不在深度 0 中。我注意到当我按如下方式调用 clone 时:
模型 = model.clone('clone', substitutions=subst)
none 个节点实际上已被替换。有没有办法做到这一点或者这是预期的行为?
目前,作为块函数的节点不能用克隆替换它们的内脏。我们可能会在克隆中添加这种操作模式。现在您可以尝试展平图形。以下代码可能足以展平大多数网络。不过我还没有彻底测试它:
def break_a_block(root):
blocks = C.logging.graph.depth_first_search(root, lambda node: isinstance(node, C.cntk_py.Function) and node.is_block, depth=-1)
if len(blocks) == 0:
return False, root
block = blocks[0]
composite = C.as_composite(block.block_root)
output_dict = dict(zip(block.outputs, composite.outputs))
mapping = dict(block.block_arguments_mapping)
items = list(mapping.items())
owners = set(C.as_composite(arg.owner) for _, arg in items if arg.is_output)
for owner in owners:
clone = owner.clone('share', output_dict)
out = dict(zip(owner.outputs,clone.outputs))
for p,a in items:
if mapping[p] in out:
mapping[p] = out[mapping[p]]
composite.replace_placeholders(mapping)
return True, root.clone('share',output_dict)
def flatten(root):
changed, root = break_a_block(root)
while changed:
changed, root = break_a_block(root)
return root
我正在尝试使用 model.clone() 替换图中的一些节点。我通过调用以下函数获得了我想要替换的这些节点: times_nodes = find_all_with_name(型号, 姓名, -1) 有问题的节点在图中不在深度 0 中。我注意到当我按如下方式调用 clone 时: 模型 = model.clone('clone', substitutions=subst) none 个节点实际上已被替换。有没有办法做到这一点或者这是预期的行为?
目前,作为块函数的节点不能用克隆替换它们的内脏。我们可能会在克隆中添加这种操作模式。现在您可以尝试展平图形。以下代码可能足以展平大多数网络。不过我还没有彻底测试它:
def break_a_block(root):
blocks = C.logging.graph.depth_first_search(root, lambda node: isinstance(node, C.cntk_py.Function) and node.is_block, depth=-1)
if len(blocks) == 0:
return False, root
block = blocks[0]
composite = C.as_composite(block.block_root)
output_dict = dict(zip(block.outputs, composite.outputs))
mapping = dict(block.block_arguments_mapping)
items = list(mapping.items())
owners = set(C.as_composite(arg.owner) for _, arg in items if arg.is_output)
for owner in owners:
clone = owner.clone('share', output_dict)
out = dict(zip(owner.outputs,clone.outputs))
for p,a in items:
if mapping[p] in out:
mapping[p] = out[mapping[p]]
composite.replace_placeholders(mapping)
return True, root.clone('share',output_dict)
def flatten(root):
changed, root = break_a_block(root)
while changed:
changed, root = break_a_block(root)
return root