diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 253964562c8d34e0fda3b4760761206895f749aa..baf04c30b17cb333fc8a6544afd6c479442f835b 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -24,19 +26,57 @@ struct AddFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                        AddFunctor<T>(), z);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+  auto eigen_y = framework::EigenVector<T>::Flatten(*y);
+  auto eigen_z = framework::EigenVector<T>::Flatten(*z);
+
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
+    const auto x = ctx.Input<Tensor>("X");
+    const auto y = ctx.Input<Tensor>("Y");
+    auto z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          AddFunctor<T>(), z);
+
+    auto dims_equal = x->dims() == y->dims();
+    if (dims_equal) {
+      elementwise_add<DeviceContext, T>(ctx, x, y, z);
+    } else {
+      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+    }
   }
 };
 
@@ -45,6 +85,55 @@ struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* x,
+                                  const framework::Tensor* y,
+                                  const framework::Tensor* out,
+                                  const framework::Tensor* dout,
+                                  framework::Tensor* dx,
+                                  framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+      IdentityGrad<T>());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x, const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout, framework::Tensor* dx,
+                     framework::Tensor* dy) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  if (dx) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dx->mutable_data<T>(ctx.GetPlace()));
+  }
+
+  if (dy) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dy->mutable_data<T>(ctx.GetPlace()));
+  }
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x, const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout, framework::Tensor* dx,
+                     framework::Tensor* dy) {
+  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  public:
@@ -57,10 +146,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-        ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-        IdentityGrad<T>());
+
+    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
+      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+    } else {
+      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
+                                                     dy);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index dabde43850db770d286b13cacd32bee181328d5c..1a37cb39d56066b8380338b9710a441e41518c39 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -125,6 +125,12 @@ class Blas {
   template <typename T>
   void AXPY(int n, T alpha, const T* x, T* y) const;
 
+  template <typename T>
+  void VADD(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+  void VCOPY(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B,
             T beta, T* C) const;
@@ -163,6 +169,16 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template AXPY<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VADD(ARGS... args) const {
+    Base()->template VADD<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VCOPY(ARGS... args) const {
+    Base()->template VCOPY<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 14b3624b420cb883b36268c0a5a9e8692dbb5b43..ae20406bc21d5e08359be8295cd98495dda7813b 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -34,6 +34,18 @@ struct CBlas<float> {
     cblas_saxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vsAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_scopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_sgemv(args...);
@@ -59,6 +71,18 @@ struct CBlas<double> {
     cblas_daxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vdAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_dcopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_dgemv(args...);
@@ -139,6 +163,24 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
   CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
+  CBlas<T>::VCOPY(n, x, 1, y, 1);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
+                                            T *z) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VADD(n, x, y, z);
+#else
+  this->template VCOPY<T>(n, y, z);
+  this->template AXPY<T>(n, 1., x, z);
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
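
Three notes on the patch above; the code sketches are editorial illustrations under stated assumptions, not code from the patch itself.

1. Compile-time dispatch. The forward and backward kernels choose between a BLAS-backed fast path and the generic broadcast path through a pair of mutually exclusive std::enable_if overloads: the fast path is viable only when T is a floating-point type and DeviceContext is platform::CPUDeviceContext, and the fallback overload carries the exact negation of that condition, so exactly one candidate survives overload resolution for any instantiation. A minimal standalone sketch of the same technique (the device-context types here are hypothetical stand-ins, not Paddle's):

    #include <cstdio>
    #include <type_traits>

    struct CPUDeviceContext {};   // stand-in for platform::CPUDeviceContext
    struct CUDADeviceContext {};  // stand-in for the GPU context

    // Fast path: enabled only for floating-point T on the CPU context.
    template <typename DeviceContext, typename T>
    typename std::enable_if<
        std::is_floating_point<T>::value &&
        std::is_same<DeviceContext, CPUDeviceContext>::value>::type
    add(const T* x, const T* y, T* z, int n) {
      std::printf("fast path\n");
      for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];  // would be blas.VADD
    }

    // Generic path: the exact negation, so the overloads never overlap.
    template <typename DeviceContext, typename T>
    typename std::enable_if<
        !std::is_floating_point<T>::value ||
        !std::is_same<DeviceContext, CPUDeviceContext>::value>::type
    add(const T* x, const T* y, T* z, int n) {
      std::printf("generic path\n");
      for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
    }

    int main() {
      float xf[] = {1.f, 2.f}, yf[] = {3.f, 4.f}, zf[2];
      int xi[] = {1, 2}, yi[] = {3, 4}, zi[2];
      add<CPUDeviceContext>(xf, yf, zf, 2);   // fast path
      add<CPUDeviceContext>(xi, yi, zi, 2);   // int element type: generic path
      add<CUDADeviceContext>(xf, yf, zf, 2);  // non-CPU context: generic path
    }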
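
2. The VADD fallback. With PADDLE_WITH_MKLML defined, Blas<CPUDeviceContext>::VADD maps to MKL's vector-math routines vsAdd/vdAdd, which compute z = x + y in a single call. Without MKLML, the patch composes the same result from two level-1 BLAS calls: VCOPY sets z = y, then AXPY with alpha = 1 accumulates x. A sketch of that fallback against the standard CBLAS interface (assuming a CBLAS implementation is linked; vadd_fallback is a hypothetical name, not the Paddle wrapper):

    #include <cblas.h>

    // z = y, then z += 1.0f * x, so z = x + y overall.
    // Note: z must not alias x, since copying y into z would
    // clobber x before the AXPY reads it.
    void vadd_fallback(int n, const float* x, const float* y, float* z) {
      cblas_scopy(n, y, 1, z, 1);        // z = y (unit strides)
      cblas_saxpy(n, 1.0f, x, 1, z, 1);  // z += 1.0f * x
    }

Composing the fallback from VCOPY and AXPY keeps the single-call MKL routine as the fast case while remaining correct against any conforming CBLAS.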
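
3. Why a plain VCOPY suffices in the backward pass. For z = x + y, the local derivatives are dz/dx = dz/dy = 1, so by the chain rule dL/dx = dL/dz * 1 = dout, and likewise dL/dy = dout. That is exactly what the IdentityGrad functor encodes (it returns dout unchanged), and when x and y have identical dims no broadcast reduction is needed, so the gradient kernel can skip ElemwiseGradCompute and copy dout straight into dx and dy with VCOPY.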