diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 08efaedccd4f40033dfa02a801911f6666e14ec8..1b2f1db1b07cdd883417fb5f98e4c685fe32c515 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -68,7 +68,9 @@ class FlattenGradKernel : public framework::OpKernel { auto in_dims = ctx.Input("X")->dims(); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(in_dims); } }; @@ -107,7 +109,9 @@ class Flatten2GradKernel : public framework::OpKernel { auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(x_dims); } }; @@ -175,7 +179,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel { auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(x_dims); } };