From 7915d18056d4f4284f5f415d5f9111c157b782c7 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Tue, 11 Jan 2022 20:00:12 +0800 Subject: [PATCH] Fix bug in elementwise_mul/div_grad when inplace strategy (#38840) * fix bug when inplace strategy * fix * fix * fix * fix * fix --- .../operators/elementwise/elementwise_div_op.cu | 10 ---------- .../operators/elementwise/elementwise_mul_op.cu | 12 +----------- .../operators/elementwise/elementwise_op_function.h | 1 + 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 7a25f65366..06f9107db2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -31,20 +31,10 @@ ElementwiseDivGrad(const framework::ExecutionContext& ctx, const auto& dev_ctx = ctx.template device_context(); const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, out, y}; GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, DivGradXFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index a8b6c2abe3..5ece5cadc6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -74,20 +74,10 @@ ElementwiseMulGrad(const framework::ExecutionContext& ctx, const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y, x}; - GetGradXAndYOut( + GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, MulGradFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 3929699955..41cb2696f5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -2575,6 +2575,7 @@ void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx, framework::Tensor *dy, Functor func) { framework::Tensor tmp_dx; framework::Tensor tmp_dy; + dx->mutable_data(place); dy->mutable_data(place); std::vector outs; if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) { -- GitLab