未验证 提交 52c1a950 编写于 作者: S ShenLiang 提交者: GitHub

[Hybrid Parallel]add op_device in seed op for recompute

上级 cb6510ff
......@@ -189,12 +189,20 @@ class ProgramStats(object):
persistable=False,
stop_gradient=False)
seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed"))
op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName(
)
op_device = ""
if op.desc.has_attr(op_device_attr_name):
op_device = op.desc.attr(op_device_attr_name)
added_op = self.block._insert_op(
index=op.idx,
type='seed',
inputs={},
outputs={'Out': [added_var]},
attrs={'seed': seed})
attrs={'seed': seed,
'op_device': op_device})
self.ops.insert(op_idx, added_op)
# modify dropout op desc so that it accept a seed var as input
op.desc.set_input("Seed", [var_unique_name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册