diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 7a25f65366901e23b3972b29b2f1eca76f19471c..06f9107db27b4f2cce54bbcabe3c53e81e4167d1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -31,20 +31,10 @@ ElementwiseDivGrad(const framework::ExecutionContext& ctx, const auto& dev_ctx = ctx.template device_context(); const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, out, y}; GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, DivGradXFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index a8b6c2abe3bf99188bd4dcf9201d0a4655121f74..5ece5cadc603fa9598fd7741067f680202c81eb4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -74,20 +74,10 @@ ElementwiseMulGrad(const framework::ExecutionContext& ctx, const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y, x}; - GetGradXAndYOut( + GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, MulGradFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 3929699955a17f63d5fa2deead9ee0a3659e267f..41cb2696f5492a94966f23ad47402e1b57f77367 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -2575,6 +2575,7 @@ void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx, framework::Tensor *dy, Functor func) { framework::Tensor tmp_dx; framework::Tensor tmp_dy; + dx->mutable_data(place); dy->mutable_data(place); std::vector outs; if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {