提交 19ff254d 编写于 作者: S sneaxiy

Merge branch 'refine_elementwise_add' of https://github.com/sneaxiy/Paddle...

Merge branch 'refine_elementwise_add' of https://github.com/sneaxiy/Paddle into refine_elementwise_add
......@@ -145,13 +145,22 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx != nullptr) dx->ShareDataWith(*dout);
if (dx != nullptr) {
// Sharing memory with dout would be sufficient here, but doing so may trigger
// a bug in the memory optimizer, so the data is copied instead.
// dx->ShareDataWith(*dout);
framework::TensorCopy(*dout, ctx.GetPlace(),
ctx.template device_context<DeviceContext>(), dx);
}
if (dy == nullptr) return;
const framework::DDim& x_dim = dout->dims();
framework::DDim y_dim = dy->dims();
if (x_dim == y_dim) {
dy->ShareDataWith(*dout);
// dy->ShareDataWith(*dout);
framework::TensorCopy(*dout, ctx.GetPlace(),
ctx.template device_context<DeviceContext>(), dy);
} else {
dy->mutable_data<T>(ctx.GetPlace());
// Perform reduction to dout to calculate dy
......@@ -160,15 +169,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
y_dim = trim_trailing_singular_dims(y_dim);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
auto* device =
ctx.template device_context<DeviceContext>().eigen_device();
auto& device =
*(ctx.template device_context<DeviceContext>().eigen_device());
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
auto eigen_dout = framework::EigenTensor<T, 3>::From(
*dout, framework::make_ddim({pre, n, post}));
auto eigen_dy =
framework::EigenTensor<T, 1>::From(*dy, framework::make_ddim({n}));
eigen_dy.device(*device) = eigen_dout.sum(
eigen_dy.device(device) = eigen_dout.sum(
framework::EigenDim<2>::From(framework::make_ddim({0, 2})));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册