diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index bd31f8d20a8c69a30865fa2c0eda8c5aa2b8c58e..3630cfd0eafeb916df6ed0f722755fd925085a0e 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4287,6 +4287,18 @@ class PipelineOptimizer(object): prev_op = self._find_real_prev_op(block.ops, op, op.desc.input("X")[0]) op._set_attr('op_device', prev_op.attr('op_device')) + elif op.type == "memcpy" and not self._is_optimize_op(op): + assert len(op.input_arg_names) == 1 and len( + op.output_arg_names) == 1 + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + if '@Fetch' in output_name: + post_op = self._find_real_post_op(block.ops, op, output_name) + op._set_attr('op_device', post_op.attr('op_device')) + else: + prev_op = self._find_real_prev_op(block.ops, op, + op.desc.input("X")[0]) + op._set_attr('op_device', prev_op.attr('op_device')) elif self._is_loss_op(op): # For loss * loss_scaling op added by AMP offset = 1