From 6f932482f435f7f80c176afbd9f429c09bce381f Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Sat, 19 May 2018 16:48:29 +0200
Subject: [PATCH] MKL elementwise_add: BLAS version compiles with integral
 types

---
 paddle/fluid/operators/elementwise_add_op.cc |  8 ++--
 paddle/fluid/operators/elementwise_add_op.h  | 39 +++++++++++++++-----
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc
index d51a845b4..d2c205371 100644
--- a/paddle/fluid/operators/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise_add_op.cc
@@ -18,10 +18,10 @@ namespace ops = paddle::operators;
 REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>);
-//  ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>);
-//  ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
-//  ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 316fd7568..1f8735b7b 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -26,6 +26,34 @@ struct AddFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                        AddFunctor<T>(), z);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
+    const framework::ExecutionContext& ctx, const framework::Tensor* x,
+    const framework::Tensor* y, framework::Tensor* z) {
+  auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+  auto eigen_y = framework::EigenVector<T>::Flatten(*y);
+  auto eigen_z = framework::EigenVector<T>::Flatten(*z);
+
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<!std::is_floating_point<T>::value>::type elementwise_add(
+    const framework::ExecutionContext& ctx, const framework::Tensor* x,
+    const framework::Tensor* y, framework::Tensor* z) {
+  default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
@@ -36,19 +64,12 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     const auto y = ctx.Input<Tensor>("Y");
     auto z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
 
     auto dims_equal = x->dims() == y->dims();
     if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
-      auto eigen_x = framework::EigenVector<T>::Flatten(*x);
-      auto eigen_y = framework::EigenVector<T>::Flatten(*y);
-      auto eigen_z = framework::EigenVector<T>::Flatten(*z);
-
-      auto blas = math::GetBlas<DeviceContext, T>(ctx);
-      blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+      elementwise_add<DeviceContext, T>(ctx, x, y, z);
     } else {
-      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                            AddFunctor<T>(), z);
+      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
     }
   }
 };
--
GitLab
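
Note on the dispatch pattern (reviewer sketch, not part of the commit): the patch splits elementwise_add into two std::enable_if overloads so that the BLAS VADD path is only instantiated for floating-point element types, while integral types (int, int64_t) fall back to the generic ElementwiseComputeEx path, which is what lets the integer kernel registrations compile. Below is a minimal, self-contained illustration of the same SFINAE technique; blas_vadd and the vector-based elementwise_add here are illustrative stand-ins, not Paddle APIs.

#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>

// Stand-in for a BLAS vector add (the role blas.VADD plays in the patch);
// assume a real implementation only exists for float/double.
template <typename T>
void blas_vadd(int n, const T* x, const T* y, T* z) {
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

// Floating-point element types: selected overload uses the BLAS-style path.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_add(const std::vector<T>& x, const std::vector<T>& y,
                std::vector<T>& z) {
  blas_vadd(static_cast<int>(x.size()), x.data(), y.data(), z.data());
}

// Any other element type: selected overload is a plain loop, mirroring the
// role of default_elementwise_add in the patch.
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_add(const std::vector<T>& x, const std::vector<T>& y,
                std::vector<T>& z) {
  for (std::size_t i = 0; i < x.size(); ++i) z[i] = x[i] + y[i];
}

int main() {
  std::vector<float> fx{1.f, 2.f}, fy{3.f, 4.f}, fz(2);
  std::vector<int> ix{1, 2}, iy{3, 4}, iz(2);
  elementwise_add(fx, fy, fz);  // resolves to the BLAS-style overload
  elementwise_add(ix, iy, iz);  // resolves to the generic overload
  std::cout << fz[0] << " " << iz[1] << "\n";  // prints: 4 6
  return 0;
}

For any given T exactly one of the two overloads has a valid return type, so overload resolution is unambiguous and instantiating the integer kernels never forces the BLAS call to compile for types the library does not provide.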