提交 d53401f1 编写于 作者: S sandyhouse

add support for memcpy

上级 75644caf
......@@ -4287,6 +4287,18 @@ class PipelineOptimizer(object):
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif op.type == "memcpy" and not self._is_optimize_op(op):
assert len(op.input_arg_names) == 1 and len(
op.output_arg_names) == 1
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
if '@Fetch' in output_name:
post_op = self._find_real_post_op(block.ops, op, output_name)
op._set_attr('op_device', post_op.attr('op_device'))
else:
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr('op_device', prev_op.attr('op_device'))
elif self._is_loss_op(op):
# For loss * loss_scaling op added by AMP
offset = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册