Unverified commit bab11969 · Authored by: Tao Luo · Committed by: GitHub

Merge pull request #10913 from tpatejko/tpatejko/optimized-elementwise-add

Blas optimized elementwise_add forward and backward passes
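In short: for floating-point tensors on CPU whose dimensions match, the forward kernel now computes Out = X + Y through a BLAS VADD routine instead of the generic ElementwiseComputeEx path, and the backward pass reduces to copying dOut into dX and dY with VCOPY. Broadcasting, integer types, and non-CPU devices keep the previous default implementations, selected at compile time via std::enable_if overloads. VADD and VCOPY are added to the Blas/BlasT interface; VADD maps to MKL's vsAdd/vdAdd when PADDLE_WITH_MKLML is defined and otherwise falls back to VCOPY followed by AXPY.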
...@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
...@@ -24,19 +26,57 @@ struct AddFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.VADD(x->numel(), eigen_x.data(), eigen_y.data(), eigen_z.data());
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
}
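The two elementwise_add overloads above use SFINAE (std::enable_if) to pick an implementation at compile time: only floating-point element types paired with platform::CPUDeviceContext get the BLAS path, and every other combination resolves to the generic fallback. A minimal, self-contained sketch of that dispatch pattern (hypothetical names, not Paddle code):

#include <cstdio>
#include <type_traits>

struct CPUDeviceContext {};   // stand-ins for the real device context types
struct CUDADeviceContext {};

// Selected only when T is floating point AND the device context is the CPU one.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
add_dispatch() {
  std::puts("optimized CPU floating-point path");
}

// Selected for every other combination (the negated condition).
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, CPUDeviceContext>::value>::type
add_dispatch() {
  std::puts("generic fallback path");
}

int main() {
  add_dispatch<CPUDeviceContext, float>();   // optimized CPU floating-point path
  add_dispatch<CPUDeviceContext, int>();     // generic fallback path
  add_dispatch<CUDADeviceContext, float>();  // generic fallback path
  return 0;
}

Because the two enable_if conditions are exact negations of each other, every <DeviceContext, T> combination matches exactly one overload and the call is never ambiguous.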
template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    const auto x = ctx.Input<Tensor>("X");
    const auto y = ctx.Input<Tensor>("Y");
    auto z = ctx.Output<Tensor>("Out");
    z->mutable_data<T>(ctx.GetPlace());

    auto dims_equal = x->dims() == y->dims();
    if (dims_equal) {
      elementwise_add<DeviceContext, T>(ctx, x, y, z);
    } else {
      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
    }
  }
};
...@@ -45,6 +85,55 @@ struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx,
framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
IdentityGrad<T>());
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace()));
}
if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
}
}
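Why a plain copy is a correct gradient here: with Out = X + Y computed elementwise on equally-shaped tensors, dOut/dX = dOut/dY = 1, so dX and dY are both simply dOut, and the optimized path only needs VCOPY. A standalone loop-level sketch of the same computation (illustrative only, not the Paddle API):

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> dout = {0.1f, -0.2f, 0.3f};
  std::vector<float> dx(dout.size()), dy(dout.size());

  // Equivalent of blas.VCOPY(n, dout, dx) and blas.VCOPY(n, dout, dy) above.
  for (std::size_t i = 0; i < dout.size(); ++i) {
    dx[i] = dout[i];
    dy[i] = dout[i];
  }

  std::printf("dx[0]=%g dy[2]=%g\n", dx[0], dy[2]);  // prints dx[0]=0.1 dy[2]=0.3
  return 0;
}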
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
 public:
...@@ -57,10 +146,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
                                                     dy);
    }
  }
};
......
...@@ -125,6 +125,12 @@ class Blas {
  template <typename T>
  void AXPY(int n, T alpha, const T* x, T* y) const;
template <typename T>
void VADD(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
  template <typename T>
  void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
            T* C) const;
...@@ -163,6 +169,16 @@ class BlasT : private Blas<DeviceContext> {
    Base()->template AXPY<T>(args...);
  }
template <typename... ARGS>
void VADD(ARGS... args) const {
Base()->template VADD<T>(args...);
}
template <typename... ARGS>
void VCOPY(ARGS... args) const {
Base()->template VCOPY<T>(args...);
}
  template <typename... ARGS>
  void GEMV(ARGS... args) const {
    Base()->template GEMV<T>(args...);
......
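The new VADD/VCOPY members of BlasT follow the file's existing forwarding idiom: BlasT<DeviceContext, T> privately inherits Blas<DeviceContext> and re-exposes each routine with the element type T already bound, via Base()->template VADD<T>(args...). A stripped-down, self-contained sketch of that idiom (hypothetical names, plain loop instead of a real BLAS call):

#include <cstdio>

struct CPUDevice {};  // placeholder device tag

template <typename Device>
class Blas {
 public:
  template <typename T>
  void VADD(int n, const T* x, const T* y, T* z) const {
    for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];  // stand-in for the BLAS kernel
  }
};

template <typename Device, typename T>
class BlasT : private Blas<Device> {
 public:
  template <typename... ARGS>
  void VADD(ARGS... args) const {
    Base()->template VADD<T>(args...);  // forwards with T already fixed
  }

 private:
  const Blas<Device>* Base() const {
    return static_cast<const Blas<Device>*>(this);
  }
};

int main() {
  BlasT<CPUDevice, float> blas;
  float x[3] = {1, 2, 3}, y[3] = {4, 5, 6}, z[3];
  blas.VADD(3, x, y, z);
  std::printf("%g %g %g\n", z[0], z[1], z[2]);  // prints 5 7 9
  return 0;
}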
...@@ -34,6 +34,18 @@ struct CBlas<float> {
    cblas_saxpy(args...);
  }
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void VADD(ARGS... args) {
vsAdd(args...);
}
#endif
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_scopy(args...);
}
  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    cblas_sgemv(args...);
...@@ -59,6 +71,18 @@ struct CBlas<double> {
    cblas_daxpy(args...);
  }
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
static void VADD(ARGS... args) {
vdAdd(args...);
}
#endif
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_dcopy(args...);
}
  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    cblas_dgemv(args...);
...@@ -139,6 +163,24 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
  CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
CBlas<T>::VCOPY(n, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VADD(n, x, y, z);
#else
this->template VCOPY<T>(n, y, z);
this->template AXPY<T>(n, 1., x, z);
#endif
}
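The #ifdef above exists because vsAdd/vdAdd are Intel MKL VML routines, while cblas_scopy/cblas_dcopy and cblas_saxpy/cblas_daxpy are available in any CBLAS. Without MKL, VADD is therefore composed from the two routines that are present: VCOPY(n, y, z) sets z = y, and AXPY(n, 1.0, x, z) then accumulates x into it, giving z = x + y. A standalone loop-level sketch of that composition (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> y = {10.f, 20.f, 30.f};
  std::vector<float> z(x.size());

  // VCOPY(n, y, z): z <- y
  for (std::size_t i = 0; i < z.size(); ++i) z[i] = y[i];
  // AXPY(n, 1.0, x, z): z <- 1.0 * x + z
  for (std::size_t i = 0; i < z.size(); ++i) z[i] += 1.f * x[i];

  for (float v : z) std::printf("%g ", v);  // prints 11 22 33, i.e. x + y
  std::printf("\n");
  return 0;
}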
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......