From e85be1b1b22aa4c75ed1b7cbf78acce0edaa933e Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Thu, 14 Jan 2021 16:21:43 +0800 Subject: [PATCH] fix flatten api grad (#30426) --- paddle/fluid/operators/flatten_op.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 08efaedccd..1b2f1db1b0 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -68,7 +68,9 @@ class FlattenGradKernel : public framework::OpKernel { auto in_dims = ctx.Input("X")->dims(); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(in_dims); } }; @@ -107,7 +109,9 @@ class Flatten2GradKernel : public framework::OpKernel { auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(x_dims); } }; @@ -175,7 +179,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel { auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); d_x->Resize(x_dims); } }; -- GitLab