diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc
index d51a845b41f3e7405704d84635d4a7a96de45199..d2c20537136fc3ac9d1bece24a2238f26215c922 100644
--- a/paddle/fluid/operators/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise_add_op.cc
@@ -18,10 +18,10 @@ namespace ops = paddle::operators;
 REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>);
-// ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>);
-// ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
-// ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 316fd7568e6ba322d108c7cda973a95531486686..1f8735b7b17e9b212889532ca1677d678bfcff2f 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -26,6 +26,34 @@ struct AddFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                        AddFunctor<T>(), z);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
+    const framework::ExecutionContext& ctx, const framework::Tensor* x,
+    const framework::Tensor* y, framework::Tensor* z) {
+  auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+  auto eigen_y = framework::EigenVector<T>::Flatten(*y);
+  auto eigen_z = framework::EigenVector<T>::Flatten(*z);
+
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<!std::is_floating_point<T>::value>::type elementwise_add(
+    const framework::ExecutionContext& ctx, const framework::Tensor* x,
+    const framework::Tensor* y, framework::Tensor* z) {
+  default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
@@ -36,19 +64,12 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     const auto y = ctx.Input<framework::Tensor>("Y");
     auto z = ctx.Output<framework::Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
 
     auto dims_equal = x->dims() == y->dims();
     if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
-      auto eigen_x = framework::EigenVector<T>::Flatten(*x);
-      auto eigen_y = framework::EigenVector<T>::Flatten(*y);
-      auto eigen_z = framework::EigenVector<T>::Flatten(*z);
-
-      auto blas = math::GetBlas<DeviceContext, T>(ctx);
-      blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+      elementwise_add<DeviceContext, T>(ctx, x, y, z);
     } else {
-      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                            AddFunctor<T>(), z);
+      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
    }
  }
};
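
The core of the patch is the pair of std::enable_if overloads: floating-point element types take the optimized blas.VADD path, while all other types fall back to default_elementwise_add, which is what lets the int and int64_t CPU kernels be registered again. The following is a minimal standalone sketch of that dispatch pattern, not Paddle code; the names add/default_add and the std::vector arguments are stand-ins chosen purely for illustration.

// Toy sketch of the enable_if dispatch used in the patch (not Paddle code).
#include <cstdio>
#include <type_traits>
#include <vector>

// Generic fallback: plain element-wise loop (stands in for
// default_elementwise_add / ElementwiseComputeEx in the patch).
template <typename T>
void default_add(const std::vector<T>& x, const std::vector<T>& y,
                 std::vector<T>* z) {
  for (size_t i = 0; i < x.size(); ++i) (*z)[i] = x[i] + y[i];
}

// Enabled only when T is floating point (stands in for the blas.VADD path).
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type add(
    const std::vector<T>& x, const std::vector<T>& y, std::vector<T>* z) {
  std::printf("floating-point path\n");
  default_add(x, y, z);  // a real kernel would call an optimized VADD here
}

// Enabled for every other T; simply forwards to the generic fallback.
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type add(
    const std::vector<T>& x, const std::vector<T>& y, std::vector<T>* z) {
  std::printf("generic path\n");
  default_add(x, y, z);
}

int main() {
  std::vector<float> fx{1.f, 2.f}, fy{3.f, 4.f}, fz(2);
  std::vector<int> ix{1, 2}, iy{3, 4}, iz(2);
  add(fx, fy, &fz);  // resolves to the floating-point overload
  add(ix, iy, &iz);  // resolves to the generic overload
  std::printf("%g %d\n", fz[1], iz[1]);  // prints: 6 6
  return 0;
}

In this toy version T is deduced from the std::vector<T> arguments, so the overload is picked automatically. In the patch itself T and DeviceContext never appear in a deducible parameter position, only in the enable_if return type and the body, which is why the kernel's call sites spell the template arguments out explicitly (elementwise_add<DeviceContext, T>(ctx, x, y, z)).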