未验证 提交 7915d180 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix bug in elementwise_mul/div_grad when inplace strategy (#38840)

* fix bug when inplace strategy

* fix

* fix

* fix

* fix

* fix
上级 3eaf8d2c
......@@ -31,20 +31,10 @@ ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, out, y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, DivGradXFunctor<T>());
......
......@@ -74,20 +74,10 @@ ElementwiseMulGrad(const framework::ExecutionContext& ctx,
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, y, x};
GetGradXAndYOut<ElementwiseType::kBinary, T>(
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, MulGradFunctor<T>());
......
......@@ -2575,6 +2575,7 @@ void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx,
framework::Tensor *dy, Functor func) {
framework::Tensor tmp_dx;
framework::Tensor tmp_dy;
dx->mutable_data<T>(place);
dy->mutable_data<T>(place);
std::vector<framework::Tensor *> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册