diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h
index dc73cb6f23614504640283af01981d3f69e89126..329d2d129a9ea450cd211f0c6d2ea5e37ff8491d 100644
--- a/paddle/fluid/operators/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise_mul_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -23,6 +24,37 @@ struct MulFunctor {
   inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_mul(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                        MulFunctor<T>(), z);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_mul(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas.VMUL(x->numel(), x->data<T>(), y->data<T>(),
+            z->mutable_data<T>(ctx.GetPlace()));
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_mul(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
+  default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseMulKernel : public framework::OpKernel<T> {
  public:
@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
     auto* y = ctx.Input<Tensor>("Y");
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          MulFunctor<T>(), z);
+    if (x->numel() == y->numel()) {
+      elementwise_mul<DeviceContext, T>(ctx, x, y, z);
+    } else {
+      default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 2558154e0b39a4281bfaa59ba75867589d73be5d..8dcf7c99f3860789dee834787eeb8b7ad4cc3530 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -134,6 +134,9 @@ class Blas {
   template <typename T>
   void VADD(int n, const T* x, const T* y, T* z) const;
 
+  template <typename T>
+  void VMUL(int n, const T* x, const T* y, T* z) const;
+
   template <typename T>
   void VCOPY(int n, const T* x, T* y) const;
 
@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VADD<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VMUL(ARGS... args) const {
+    Base()->template VMUL<T>(args...);
+  }
+
   template <typename... ARGS>
   void VCOPY(ARGS... args) const {
     Base()->template VCOPY<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index bf3382107960dfd8b52f94b421b49022dcb6d291..dc77b6d793702458a22a2f59b68e9d9f2c23b4ff 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -82,6 +82,11 @@ struct CBlas<float> {
   static void VADD(ARGS... args) {
     platform::dynload::vsAdd(args...);
   }
+
+  template <typename... ARGS>
+  static void VMUL(ARGS... args) {
+    platform::dynload::vsMul(args...);
+  }
 };
 
 template <>
@@ -142,6 +147,11 @@ struct CBlas<double> {
   static void VADD(ARGS... args) {
     platform::dynload::vdAdd(args...);
  }
+
+  template <typename... ARGS>
+  static void VMUL(ARGS... args) {
+    platform::dynload::vdMul(args...);
+  }
 };
 
 #else
@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
   static void SMM_GEMM(...) {
     PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
   }
+  static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
   }
@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
+                                            T *z) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VMUL(n, x, y, z);
+#else
+  // TODO: check whether OpenBLAS offers a vectorized multiply; use a
+  // scalar loop until then.
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index 9e7a616094e184695de521aa035257bde4170a91..f2e55ed52f4b98be34661cd8d9bc897cba484356 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
 
 #define MKLML_ROUTINE_EACH(__macro) \
   __macro(cblas_sgemm);             \
-  __macro(cblas_saxpy);             \
-  __macro(cblas_scopy);             \
-  __macro(cblas_sgemv);             \
-  __macro(cblas_sgemm_batch);       \
   __macro(cblas_dgemm);             \
+  __macro(cblas_saxpy);             \
   __macro(cblas_daxpy);             \
+  __macro(cblas_scopy);             \
   __macro(cblas_dcopy);             \
+  __macro(cblas_sgemv);             \
   __macro(cblas_dgemv);             \
-  __macro(cblas_dgemm_batch);       \
-  __macro(vsAdd);                   \
-  __macro(vdAdd);                   \
   __macro(cblas_sgemm_alloc);       \
-  __macro(cblas_sgemm_pack);        \
-  __macro(cblas_sgemm_compute);     \
-  __macro(cblas_sgemm_free);        \
   __macro(cblas_dgemm_alloc);       \
+  __macro(cblas_sgemm_pack);        \
   __macro(cblas_dgemm_pack);        \
+  __macro(cblas_sgemm_compute);     \
   __macro(cblas_dgemm_compute);     \
+  __macro(cblas_sgemm_free);        \
   __macro(cblas_dgemm_free);        \
+  __macro(cblas_sgemm_batch);       \
+  __macro(cblas_dgemm_batch);       \
+  __macro(vsAdd);                   \
+  __macro(vdAdd);                   \
+  __macro(vsMul);                   \
+  __macro(vdMul);                   \
   __macro(MKL_Set_Num_Threads)
 
 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
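
Reviewer note: the kernel above picks between two `elementwise_mul` overloads with `std::enable_if`. Same-shape float/double tensors on a CPU context go through `blas.VMUL` (MKL `vsMul`/`vdMul`); every other combination falls back to the broadcast-capable `ElementwiseComputeEx` path. The sketch below is a minimal, self-contained illustration of that compile-time dispatch; the `CPUDeviceContext`/`CUDADeviceContext` tags and the scalar loops are stand-ins invented for the example, not Paddle's real types or the real BLAS call.

```cpp
#include <cstdio>
#include <type_traits>
#include <vector>

struct CPUDeviceContext {};   // hypothetical stand-in for platform::CPUDeviceContext
struct CUDADeviceContext {};  // hypothetical non-CPU tag; it would take the fallback

// Fast path: viable only when T is floating point AND the context is CPU,
// mirroring the enable_if guards on elementwise_mul in the patch.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
elementwise_mul(const T* x, const T* y, T* z, int n) {
  // The real kernel calls blas.VMUL here; a scalar loop stands in for it.
  for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
  std::printf("fast path (would call blas.VMUL)\n");
}

// Fallback: the negated condition, so exactly one overload survives
// substitution for any <DeviceContext, T> pair.
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, CPUDeviceContext>::value>::type
elementwise_mul(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) z[i] = x[i] * y[i];
  std::printf("default path (ElementwiseComputeEx)\n");
}

int main() {
  std::vector<float> x{1, 2, 3}, y{4, 5, 6}, z(3);
  elementwise_mul<CPUDeviceContext, float>(x.data(), y.data(), z.data(), 3);
  std::vector<int> xi{1, 2, 3}, yi{4, 5, 6}, zi(3);
  elementwise_mul<CPUDeviceContext, int>(xi.data(), yi.data(), zi.data(), 3);
  std::printf("%g %g %g\n", z[0], z[1], z[2]);  // prints: 4 10 18
  return 0;
}
```

Compiled with `g++ -std=c++11`, the first call resolves to the fast-path overload and the second to the fallback, which is the same selection the patch relies on; the runtime `x->numel() == y->numel()` check in `ElementwiseMulKernel::Compute` then guards the equal-shape precondition that `VMUL` needs, since it has no broadcast support.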