diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
index c86cd57316078778e5930c9b524b931d523028d7..9ad82aec8182d6ba06b67391d71317a3d0df1833 100644
--- a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
 };
 
 template <typename T>
-class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
+class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 5356105e2e551c0528694091608fc7585dce66d2..c60cb1f92e99329d52f6ed39dccde406a5f83563 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 
@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
 }
 
 template <typename DeviceContext, typename T>
-class ElementwiseAddGradKernel : public framework::OpKernel<T> {
+class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
+
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h
index 95649ac46e6bd41b9e1a865794cdec3ae1e6e247..41a7950bf0c598507c0fda48c6a43f2fd38c41d2 100644
--- a/paddle/fluid/operators/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise_div_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
 
@@ -53,9 +53,10 @@ struct DivGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseDivGradKernel : public framework::OpKernel<T> {
+class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h
index 527a18ee3ba88a158a13266a7fbcdafe59ec69d9..bfb5c931958b4ca890ea720af42dad91d5625abb 100644
--- a/paddle/fluid/operators/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise_max_op.h
@@ -14,6 +14,7 @@ limitations under the License.
 */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -55,9 +56,10 @@ struct MaxGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h
index d4e5831463f3e54c72789b6876ea696cf1b4ef4b..db035ffb52e619b337c8190af4ed0e155aaac48d 100644
--- a/paddle/fluid/operators/elementwise_min_op.h
+++ b/paddle/fluid/operators/elementwise_min_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
 
@@ -55,9 +55,10 @@ struct MinGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMinGradKernel : public framework::OpKernel<T> {
+class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h
index dc73cb6f23614504640283af01981d3f69e89126..82c5fa0472bcc3b4d2d12b7a80c3418da5d6dd7b 100644
--- a/paddle/fluid/operators/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise_mul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
@@ -50,9 +50,10 @@ struct MulGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMulGradKernel : public framework::OpKernel<T> {
+class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h
index d8a12e800ad733800c1ec333f15d31d4dcd1a3a5..a79b900b9801e6b80e4433a9acdd4dab6c34859d 100644
--- a/paddle/fluid/operators/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise_op.h
@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
   }
 };
 
+template <typename T>
+class ElemwiseGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dx =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (dx != nullptr) {
+      auto& dout =
+          *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+      dx->set_lod(dout.lod());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h
index 11c7e3fe628001f095836a788f2bcc7c4ee7ad4b..3385df0897700d37d60d8804a01db777ebc02a7e 100644
--- a/paddle/fluid/operators/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise_sub_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #pragma once
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -50,9 +51,10 @@ struct SubGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseSubGradKernel : public framework::OpKernel<T> {
+class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h
index 15dd975e3bbf80b2e616e6628555e812d025f70a..f72824806ed6ee3a4490938403d441326f8a3d4a 100644
--- a/paddle/fluid/operators/mul_op.h
+++ b/paddle/fluid/operators/mul_op.h
@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
     int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* y = ctx.Input<Tensor>("Y");
-    const Tensor x_matrix = x->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                                : *x;
-    const Tensor y_matrix = y->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                                : *y;
-    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto x_matrix = x->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+                        : static_cast<const Tensor&>(*x);
+    auto y_matrix = y->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+                        : static_cast<const Tensor&>(*y);
+    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
 
     Tensor dout_mat;
     dout_mat.ShareDataWith(*dout);
     dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                      framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
 
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
+
+    if (dx != nullptr) {
+      dx->set_lod(x->lod());
+    }
+    if (dy != nullptr) {
+      dy->set_lod(y->lod());
+    }
+
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     if (dx) {
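
Note: the patch adds a shared base class, ElemwiseGradKernel<T>, in
elementwise_op.h. Its Compute() copies the LoD (Paddle's "level of details"
sequence metadata) from Out@GRAD onto X@GRAD whenever X@GRAD is requested, so
any elementwise grad kernel that derives from it and calls the base Compute()
first preserves sequence structure through the backward pass. A minimal
sketch of how a kernel opts in, assuming the headers above; FooGradKernel is
a hypothetical name and is not part of this patch:

    // Hypothetical kernel following the pattern established by this patch.
    #include "paddle/fluid/operators/elementwise_op.h"

    namespace paddle {
    namespace operators {

    template <typename DeviceContext, typename T>
    class FooGradKernel : public ElemwiseGradKernel<T> {
     public:
      void Compute(const framework::ExecutionContext& ctx) const override {
        // Base class first: if X@GRAD exists, give it the same LoD as
        // Out@GRAD, so variable-length sequence info is not dropped.
        ElemwiseGradKernel<T>::Compute(ctx);

        // ... the usual elementwise gradient math would go here ...
      }
    };

    }  // namespace operators
    }  // namespace paddle

MulGradKernel in mul_op.h is not an elementwise kernel and cannot reuse the
base class; there the patch sets the LoD of dx and dy directly from the
forward inputs X and Y instead.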