From 9241011b31bbfac0d99cd89f4545e0f905276914 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Tue, 22 May 2018 02:12:54 +0200
Subject: [PATCH] MKL elementwise add backward: backward works for integral types with fall back to default impl

---
 paddle/fluid/operators/elementwise_add_op.cc |  6 +-
 paddle/fluid/operators/elementwise_add_op.h  | 69 +++++++++++++++-----
 2 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc
index 7824dea5d..d2c205371 100644
--- a/paddle/fluid/operators/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise_add_op.cc
@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL(
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>);
-// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
-// ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 3286aa848..d85f78528 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -85,6 +85,57 @@ struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* x,
+                                  const framework::Tensor* y,
+                                  const framework::Tensor* out,
+                                  const framework::Tensor* dout,
+                                  framework::Tensor* dx,
+                                  framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+      IdentityGrad<T>());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  if (dx) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dx->mutable_data<T>(ctx.GetPlace()));
+  }
+
+  if (dy) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dy->mutable_data<T>(ctx.GetPlace()));
+  }
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  public:
@@ -97,24 +148,12 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
 
     if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
-      auto blas = math::GetBlas<DeviceContext, T>(ctx);
-
-      if (dx) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dx->mutable_data<T>(ctx.GetPlace()));
-      }
-
-      if (dy) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dy->mutable_data<T>(ctx.GetPlace()));
-      }
+      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
-      ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-          IdentityGrad<T>());
+      default_elementwise_add_grad<DeviceContext, T>(
+          ctx, x, y, out, dout, dx, dy);
     }
   }
 };
--
GitLab
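
Note (not part of the patch): the grad kernel above uses std::enable_if so that
floating-point instantiations on CPU take the BLAS VCOPY fast path, while every
other instantiation (e.g. int, int64_t) falls back to the default gradient
computation. A minimal standalone sketch of that overload-selection pattern,
with hypothetical names (add_grad, plain std::vector buffers) standing in for
the Paddle types:

// Illustrative only: hypothetical add_grad overloads mirroring the
// enable_if dispatch in elementwise_add_op.h, with std::vector standing
// in for framework::Tensor and a plain copy standing in for blas.VCOPY.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>

// Fast path: selected only when T is a floating-point type.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
add_grad(const std::vector<T>& dout, std::vector<T>* dx) {
  // d(x + y)/dx is identity, so the gradient is just a copy of dout
  // (the role blas.VCOPY plays in the patch).
  std::copy(dout.begin(), dout.end(), dx->begin());
}

// Fallback: selected for every non-floating-point T (int, int64_t, ...).
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type
add_grad(const std::vector<T>& dout, std::vector<T>* dx) {
  // Stands in for default_elementwise_add_grad in the patch.
  for (std::size_t i = 0; i < dout.size(); ++i) (*dx)[i] = dout[i];
}

int main() {
  std::vector<float> fdout = {1.f, 2.f, 3.f}, fdx(3);
  std::vector<int> idout = {1, 2, 3}, idx(3);
  add_grad(fdout, &fdx);  // resolves to the floating-point overload
  add_grad(idout, &idx);  // resolves to the integral fallback
  std::cout << fdx[1] << " " << idx[2] << "\n";  // prints: 2 3
  return 0;
}

Because the two enable_if conditions are mutually exclusive, exactly one
overload survives substitution for any given T, so the call site in the kernel
stays a single unconditional function call.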