From 7b10bf0e60e9ac0f56ff532fe58cbf5c538a81b6 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 13 Dec 2018 16:40:51 +0800
Subject: [PATCH] Use mkl

---
 .../fluid/operators/hierarchical_sigmoid_op.h | 28 ++++++++++++-------
 paddle/fluid/operators/math/blas.h            |  8 ++++++
 paddle/fluid/operators/math/blas_impl.h       | 21 ++++++++++++++
 paddle/fluid/platform/dynload/mklml.h         |  2 ++
 4 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index b73a32af89e..d212e6f8437 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -150,19 +150,27 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
                               label.data<int64_t>()));
     }
 
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
-    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
-    auto out_grad_mat = EigenMatrix<T>::From(out_grad);
+    // softrelu derivative
 
-    Eigen::array<int, 2> bcast{1, static_cast<int>(pre_out_grad.dims()[1])};
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
 
-    // softrelu derivative
-    pre_out_grad_mat.device(place) =
-        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
+    auto* pre_out_grad_data = pre_out_grad.data<T>();
+    auto* pre_out_data = pre_out.data<T>();
+    auto n = pre_out.numel();
+    blas.VEXP(n, pre_out_data, pre_out_grad_data);
+    blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
+    for (int64_t i = 0; i < n; ++i) {
+      pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
+    }
     bit_code->Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
-    pre_out_grad_mat.device(place) =
-        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
+    auto* out_grad_data = out_grad.data<T>();
+
+    int64_t dim0 = pre_out_grad.dims()[0];
+    int64_t dim1 = pre_out_grad.dims()[1];
+    for (int64_t i = 0; i < dim0; ++i) {
+      T tmp = out_grad_data[i];
+      blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
+    }
     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
     // be consistent with the clipping in forward.
 
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 9f3a81f22cc..f67f57827bc 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -181,6 +181,9 @@ class Blas {
                    const framework::Tensor& mat_b, const MatDescriptor& dim_b,
                    T alpha, framework::Tensor* mat_out, T beta) const;
 
+  template <typename T>
+  void VINV(int n, const T* a, T* y) const;
+
  private:
   const DeviceContext& context_;
 };
@@ -282,6 +285,11 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template BatchedGEMM<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VINV(ARGS... args) const {
+    Base()->template VINV<T>(args...);
+  }
+
  private:
   const Blas<DeviceContext>* Base() const {
     return static_cast<const Blas<DeviceContext>*>(this);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index c84087bb1e4..972366bc093 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -118,6 +118,11 @@ struct CBlas<float> {
   static void VPOW(ARGS... args) {
     platform::dynload::vsPowx(args...);
   }
+
+  template <typename... ARGS>
+  static void VINV(ARGS... args) {
+    platform::dynload::vsInv(args...);
+  }
 };
 
 template <>
@@ -213,6 +218,11 @@ struct CBlas<double> {
   static void VPOW(ARGS... args) {
     platform::dynload::vdPowx(args...);
   }
+
+  template <typename... ARGS>
+  static void VINV(ARGS... args) {
+    platform::dynload::vdInv(args...);
+  }
 };
 
 #else
@@ -603,6 +613,17 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                     dim_a.stride_, dim_b.stride_);
   }
 }
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VINV(n, a, y);
+#else
+  for (int i = 0; i < n; ++i) {
+    y[i] = 1.0 / a[i];
+  }
+#endif
+}
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index f0a97366236..c3f9433503a 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -82,6 +82,8 @@ extern void* mklml_dso_handle;
   __macro(vdSqr);                 \
   __macro(vsPowx);                \
   __macro(vdPowx);                \
+  __macro(vsInv);                 \
+  __macro(vdInv);                 \
   __macro(MKL_Set_Num_Threads)
 
 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
-- 
GitLab
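
Reviewer note, not part of the patch (placed after the signature so `git am` ignores it):
the hierarchical_sigmoid_op.h hunk computes the softrelu derivative 1 - 1/exp(pre_out)
with a VEXP/VINV pair plus a scalar loop, then scales each row of pre_out_grad by the
matching out_grad scalar via SCAL instead of Eigen's broadcast multiply. Below is a
minimal standalone sketch of the same arithmetic. It uses plain-loop stand-ins (vexp,
vinv, scal are illustrative names, not Paddle or MKL APIs) and omits the
bit_code->Sub step that the kernel runs between the two stages.

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Plain-loop stand-ins for blas.VEXP / blas.VINV / blas.SCAL. With MKL
    // enabled the patch dispatches to vsExp/vsInv and cblas_sscal instead,
    // which apply the same operation to the whole buffer in one call.
    static void vexp(int64_t n, const float* a, float* y) {
      for (int64_t i = 0; i < n; ++i) y[i] = std::exp(a[i]);
    }
    static void vinv(int64_t n, const float* a, float* y) {
      for (int64_t i = 0; i < n; ++i) y[i] = 1.0f / a[i];
    }
    static void scal(int64_t n, float alpha, float* x) {
      for (int64_t i = 0; i < n; ++i) x[i] *= alpha;
    }

    int main() {
      // pre_out is dim0 x dim1 (row-major); out_grad holds one scalar per row.
      const int64_t dim0 = 2, dim1 = 3;
      std::vector<float> pre_out = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
      std::vector<float> out_grad = {2.0f, -1.0f};
      std::vector<float> grad(pre_out.size());
      const int64_t n = dim0 * dim1;

      // Stage 1, as in the patch: grad = 1 - 1 / exp(pre_out).
      vexp(n, pre_out.data(), grad.data());
      vinv(n, grad.data(), grad.data());
      for (int64_t i = 0; i < n; ++i) grad[i] = 1.0f - grad[i];

      // Stage 2: per-row scaling that replaces out_grad_mat.broadcast(bcast).
      for (int64_t i = 0; i < dim0; ++i) {
        scal(dim1, out_grad[i], grad.data() + i * dim1);
      }

      // Check elementwise against the original Eigen formula.
      for (int64_t i = 0; i < n; ++i) {
        float expected = (1.0f - 1.0f / std::exp(pre_out[i])) * out_grad[i / dim1];
        assert(std::fabs(grad[i] - expected) < 1e-6f);
      }
      return 0;
    }

Since 1 - 1/exp(x) equals 1 - exp(-x), the old Eigen path and the new MKL path are
algebraically identical; the change is purely about dispatching the exponential and
reciprocal to vectorized MKL routines over the whole buffer rather than evaluating
an Eigen expression coefficient by coefficient.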