Commit 19ff254d authored by sneaxiy

Merge branch 'refine_elementwise_add' of https://github.com/sneaxiy/Paddle into refine_elementwise_add
@@ -145,13 +145,22 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    if (dx != nullptr) dx->ShareDataWith(*dout);
+    if (dx != nullptr) {
+      // In fact, we can just share memory, but it may cause a bug of memory
+      // optimizer
+      // dx->ShareDataWith(*dout);
+      framework::TensorCopy(*dout, ctx.GetPlace(),
+                            ctx.template device_context<DeviceContext>(), dx);
+    }
     if (dy == nullptr) return;

     const framework::DDim& x_dim = dout->dims();
     framework::DDim y_dim = dy->dims();
     if (x_dim == y_dim) {
-      dy->ShareDataWith(*dout);
+      // dy->ShareDataWith(*dout);
+      framework::TensorCopy(*dout, ctx.GetPlace(),
+                            ctx.template device_context<DeviceContext>(), dy);
     } else {
       dy->mutable_data<T>(ctx.GetPlace());
       // Perform reduction to dout to calculate dy
@@ -160,15 +169,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
       y_dim = trim_trailing_singular_dims(y_dim);
       axis = (y_dim.size() == 0) ? x_dim.size() : axis;

-      auto* device =
-          ctx.template device_context<DeviceContext>().eigen_device();
+      auto& device =
+          *(ctx.template device_context<DeviceContext>().eigen_device());
       int pre, n, post;
       get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
       auto eigen_dout = framework::EigenTensor<T, 3>::From(
           *dout, framework::make_ddim({pre, n, post}));
       auto eigen_dy =
           framework::EigenTensor<T, 1>::From(*dy, framework::make_ddim({n}));
-      eigen_dy.device(*device) = eigen_dout.sum(
+      eigen_dy.device(device) = eigen_dout.sum(
           framework::EigenDim<2>::From(framework::make_ddim({0, 2})));
     }
   }
...
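For reference, the reduction this kernel performs when Y is broadcast against X can be read as: dX is a copy of dOut, and dY[j] is the sum of dOut over the leading (pre) and trailing (post) dimensions after dOut is viewed as a [pre, n, post] block. The following is a minimal standalone sketch of that reduction with plain loops; it is illustrative only (the helper name ReduceDoutToDy and the flat std::vector layout are assumptions, not Paddle code).

#include <cstdio>
#include <vector>

// Sum dout (flattened [pre, n, post] block) over the pre and post
// dimensions, leaving a length-n gradient for the broadcast input Y.
std::vector<float> ReduceDoutToDy(const std::vector<float>& dout,
                                  int pre, int n, int post) {
  std::vector<float> dy(n, 0.0f);
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        dy[j] += dout[(i * n + j) * post + k];
      }
    }
  }
  return dy;
}

int main() {
  // dout of shape [2, 3, 2], all ones; Y is broadcast along dims 0 and 2.
  std::vector<float> dout(2 * 3 * 2, 1.0f);
  std::vector<float> dy = ReduceDoutToDy(dout, /*pre=*/2, /*n=*/3, /*post=*/2);
  for (float v : dy) std::printf("%g ", v);  // prints: 4 4 4
  std::printf("\n");
  return 0;
}

Running this prints 4 4 4, which matches summing a [2, 3, 2] block of ones over dims {0, 2}, i.e. the same reduction expressed by the Eigen sum in the diff above.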