提交 7b10bf0e 编写于 作者: Y Yu Yang

Use mkl

上级 b2b5241e
......@@ -150,19 +150,27 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
label.data<int64_t>()));
}
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(out_grad);
// softrelu derivative
Eigen::array<int, 2> bcast{1, static_cast<int>(pre_out_grad.dims()[1])};
auto blas = math::GetBlas<DeviceContext, T>(ctx);
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.data<T>();
auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
for (int64_t i = 0; i < n; ++i) {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
}
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
auto* out_grad_data = out_grad.data<T>();
int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
T tmp = out_grad_data[i];
blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
}
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
......
......@@ -181,6 +181,9 @@ class Blas {
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
T alpha, framework::Tensor* mat_out, T beta) const;
template <typename T>
void VINV(int n, const T* a, T* y) const;
private:
const DeviceContext& context_;
};
......@@ -282,6 +285,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template BatchedGEMM<T>(args...);
}
template <typename... ARGS>
void VINV(ARGS... args) const {
Base()->template VINV<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);
......
......@@ -118,6 +118,11 @@ struct CBlas<float> {
static void VPOW(ARGS... args) {
platform::dynload::vsPowx(args...);
}
template <typename... ARGS>
static void VINV(ARGS... args) {
platform::dynload::vsInv(args...);
}
};
template <>
......@@ -213,6 +218,11 @@ struct CBlas<double> {
static void VPOW(ARGS... args) {
platform::dynload::vdPowx(args...);
}
template <typename... ARGS>
static void VINV(ARGS... args) {
platform::dynload::vdInv(args...);
}
};
#else
......@@ -603,6 +613,17 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
dim_a.stride_, dim_b.stride_);
}
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VINV(n, a, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = 1.0 / a[i];
}
#endif
}
} // namespace math
} // namespace operators
......
......@@ -82,6 +82,8 @@ extern void* mklml_dso_handle;
__macro(vdSqr); \
__macro(vsPowx); \
__macro(vdPowx); \
__macro(vsInv); \
__macro(vdInv); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册