From 211d81863daefee0757f9cd5e8146382d99d58ec Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 22 Aug 2018 06:17:28 +0000
Subject: [PATCH] Process elemwise grad op's lod. mul_op's lod

---
 .../operators/elementwise_add_mkldnn_op.cc  |  3 +-
 paddle/fluid/operators/elementwise_add_op.h |  5 +++-
 paddle/fluid/operators/elementwise_div_op.h |  5 ++--
 paddle/fluid/operators/elementwise_max_op.h |  4 ++-
 paddle/fluid/operators/elementwise_min_op.h |  5 ++--
 paddle/fluid/operators/elementwise_mul_op.h |  5 ++--
 paddle/fluid/operators/elementwise_op.h     | 14 +++++++++
 paddle/fluid/operators/elementwise_sub_op.h |  4 ++-
 paddle/fluid/operators/mul_op.h             | 30 ++++++++++++-------
 9 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
index c86cd57316..9ad82aec81 100644
--- a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
 };
 
 template <typename T>
-class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
+class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 5356105e2e..c60cb1f92e 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 
@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
 }
 
 template <typename DeviceContext, typename T>
-class ElementwiseAddGradKernel : public framework::OpKernel<T> {
+class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
+
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h
index 95649ac46e..41a7950bf0 100644
--- a/paddle/fluid/operators/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise_div_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
 
@@ -53,9 +53,10 @@ struct DivGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseDivGradKernel : public framework::OpKernel<T> {
+class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h
index 527a18ee3b..bfb5c93195 100644
--- a/paddle/fluid/operators/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise_max_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -55,9 +56,10 @@ struct MaxGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h
index d4e5831463..db035ffb52 100644
--- a/paddle/fluid/operators/elementwise_min_op.h
+++ b/paddle/fluid/operators/elementwise_min_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
 
@@ -55,9 +55,10 @@ struct MinGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMinGradKernel : public framework::OpKernel<T> {
+class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h
index dc73cb6f23..82c5fa0472 100644
--- a/paddle/fluid/operators/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise_mul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
-
 namespace paddle {
 namespace operators {
 
@@ -50,9 +50,10 @@ struct MulGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMulGradKernel : public framework::OpKernel<T> {
+class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h
index d8a12e800a..a79b900b98 100644
--- a/paddle/fluid/operators/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise_op.h
@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
   }
 };
 
+template <typename T>
+class ElemwiseGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dx =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (dx != nullptr) {
+      auto& dout =
+          *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+      dx->set_lod(dout.lod());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h
index 11c7e3fe62..3385df0897 100644
--- a/paddle/fluid/operators/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise_sub_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -50,9 +51,10 @@ struct SubGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseSubGradKernel : public framework::OpKernel<T> {
+class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h
index 15dd975e3b..f72824806e 100644
--- a/paddle/fluid/operators/mul_op.h
+++ b/paddle/fluid/operators/mul_op.h
@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
     int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* y = ctx.Input<Tensor>("Y");
-    const Tensor x_matrix = x->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                                : *x;
-    const Tensor y_matrix = y->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                                : *y;
-    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto x_matrix = x->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+                        : static_cast<const Tensor&>(*x);
+    auto y_matrix = y->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+                        : static_cast<const Tensor&>(*y);
+    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
 
     Tensor dout_mat;
     dout_mat.ShareDataWith(*dout);
     dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                      framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
 
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
+
+    if (dx != nullptr) {
+      dx->set_lod(x->lod());
+    }
+    if (dy != nullptr) {
+      dy->set_lod(y->lod());
+    }
+
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     if (dx) {
--
GitLab
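
Note on the pattern this patch establishes: ElemwiseGradKernel<T>::Compute() copies dOut's LoD (Paddle's level-of-detail sequence offsets) onto dX, so any elementwise grad kernel that derives from it and calls the base Compute() first preserves sequence information through the backward pass; previously the grad outputs were created without any LoD. MulGradKernel cannot reuse the base class as-is because its gradients are matrix products whose sequence structure comes from the forward inputs, so it copies x->lod() to dX and y->lod() to dY directly.

Below is a minimal sketch of how a hypothetical new elementwise grad kernel would opt in. The op name "foo" and the identity-gradient body (valid when dX simply equals dOut, as in the same-shape add case) are illustrative only and not part of this patch; operator and kernel registration (REGISTER_OP_CPU_KERNEL etc.) is omitted.

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ElementwiseFooGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // The base class copies dOut's LoD onto dX before any gradient math runs.
    ElemwiseGradKernel<T>::Compute(ctx);

    using Tensor = framework::Tensor;
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    if (dx != nullptr) {
      // Illustrative identity gradient: dX = dOut. dX's LoD is already set,
      // so only the values need to be produced here.
      dx->mutable_data<T>(ctx.GetPlace());
      framework::TensorCopy(*dout, ctx.GetPlace(), dx);
    }
  }
};

}  // namespace operators
}  // namespace paddle

Centralizing the LoD copy in one base class keeps the seven grad kernels touched above from repeating it and makes future elementwise grad ops propagate LoD correctly by default.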