Commit e43c8f33 authored by Tomasz Patejko

MKL elementwise add: elementwise_add uses vAdd VML function when MKL is used

Parent 174d884d
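The diff below switches the CPU elementwise_add kernel to MKL's VML vAdd routine (vsAdd for float, vdAdd for double) whenever the two inputs have identical shapes, and keeps the existing ElementwiseComputeEx path otherwise. As a point of reference, here is a minimal standalone sketch of what vsAdd computes; it is not part of the commit, and the buffer names and values are made up:

#include <mkl.h>
#include <vector>

int main() {
  // vsAdd performs y[i] = a[i] + b[i] for i in [0, n), which is exactly
  // the same-shape elementwise_add case this commit targets.
  std::vector<float> a = {1.f, 2.f, 3.f};
  std::vector<float> b = {4.f, 5.f, 6.f};
  std::vector<float> y(a.size());
  vsAdd(static_cast<MKL_INT>(a.size()), a.data(), b.data(), y.data());
  // y now holds {5.f, 7.f, 9.f}
  return 0;
}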
@@ -18,10 +18,10 @@ namespace ops = paddle::operators;
 REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>);
+    // ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>);
+    // ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
+    // ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -14,7 +14,9 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -30,14 +32,25 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
+    const auto x = ctx.Input<Tensor>("X");
+    const auto y = ctx.Input<Tensor>("Y");
+    auto z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          AddFunctor<T>(), z);
+
+    auto dims_equal = x->dims() == y->dims();
+    if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
+      auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+      auto eigen_y = framework::EigenVector<T>::Flatten(*y);
+      auto eigen_z = framework::EigenVector<T>::Flatten(*z);
+
+      auto blas = math::GetBlas<DeviceContext, T>(ctx);
+      blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
+    } else {
+      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                            AddFunctor<T>(), z);
+    }
   }
 };
 
 template <typename T>
...
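Note that the fast path above is taken only on a CPU place and only when X and Y have identical dimensions; any broadcast case still falls through to ElementwiseComputeEx with AddFunctor, exactly as before. Illustrative shapes (not from the commit):

// fast path:  X dims {4, 5}, Y dims {4, 5}           -> blas.VADD
// fallback:   X dims {4, 5}, Y dims {5}, axis = 1    -> ElementwiseComputeEx (broadcast)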
@@ -125,6 +125,12 @@ class Blas {
   template <typename T>
   void AXPY(int n, T alpha, const T* x, T* y) const;
 
+  template <typename T>
+  void VADD(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+  void VCOPY(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
@@ -163,6 +169,16 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template AXPY<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VADD(ARGS... args) const {
+    Base()->template VADD<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VCOPY(ARGS... args) const {
+    Base()->template VCOPY<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
...
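The new VADD and VCOPY entries follow the same pattern as the existing AXPY and GEMV wrappers: the type-erased Blas<DeviceContext> declares them, and BlasT forwards to the T-specialized implementation. A caller obtains a handle via math::GetBlas, as the elementwise_add kernel above does. A rough usage sketch, assuming float data on the CPU place (the pointer names are illustrative only):

auto blas = paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, float>(ctx);
blas.VADD(n, x_ptr, y_ptr, z_ptr);   // z[i] = x[i] + y[i] for i in [0, n)
blas.VCOPY(n, y_ptr, z_ptr);         // z[i] = y[i]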
@@ -34,6 +34,18 @@ struct CBlas<float> {
     cblas_saxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vsAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_scopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_sgemv(args...);
@@ -59,6 +71,18 @@ struct CBlas<double> {
     cblas_daxpy(args...);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  static void VADD(ARGS... args) {
+    vdAdd(args...);
+  }
+#endif
+
+  template <typename... ARGS>
+  static void VCOPY(ARGS... args) {
+    cblas_dcopy(args...);
+  }
+
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
     cblas_dgemv(args...);
@@ -139,6 +163,24 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
   CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
+  CBlas<T>::VCOPY(n, x, 1, y, 1);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
+                                            T *z) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VADD(n, x, y, z);
+#else
+  this->template VCOPY<T>(n, y, z);
+  this->template AXPY<T>(n, 1., x, z);
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
...
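When PADDLE_WITH_MKLML is not defined, VADD falls back to two standard BLAS level-1 calls: VCOPY first sets z = y, then AXPY with alpha = 1 accumulates x into z, so z ends up holding x + y. A minimal sketch of the same idea against plain CBLAS (illustrative only, not part of the commit; any CBLAS implementation will do):

#include <cblas.h>
#include <vector>

int main() {
  const int n = 3;
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> y = {4.f, 5.f, 6.f};
  std::vector<float> z(n);
  cblas_scopy(n, y.data(), 1, z.data(), 1);       // z = y
  cblas_saxpy(n, 1.f, x.data(), 1, z.data(), 1);  // z = 1 * x + z
  // z now holds {5.f, 7.f, 9.f}
  return 0;
}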