diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 56269f051861a87c45dfce7e556edb81be0ea684..1597690275932bff7628f3f6dc3ec6448988a772 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -178,6 +178,7 @@ bool VActJitCode::init(int d, operand_type type) { if (type == operand_type::relu) { return ok; } else { + // TODO(TJ): support more return ok && d == 8; // only 8 yet } } diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 1d443bdbe2bae4be7919423c6ae29f3af5010557..b023ef096ade568acb1e922480d55465c64f344f 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -98,42 +98,23 @@ class VAddBiasKernel : public Kernel { template class VActKernel : public Kernel { public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; + void (*Compute)(const T *, T *, int); }; template -class VReluKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; - void (*Compute)(const T *, T *, int); -}; +class VReluKernel : public VActKernel {}; template -class VIdentityKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; -}; +class VIdentityKernel : public VActKernel {}; template -class VExpKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; - void (*Compute)(const T *, T *, int); -}; +class VExpKernel : public VActKernel {}; template -class VSigmoidKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; - void (*Compute)(const T *, T *, int); -}; +class VSigmoidKernel : public VActKernel {}; template -class VTanhKernel : public VActKernel { - public: - virtual void ComputeDeprecated(const T *x, T *y) const = 0; - void (*Compute)(const T *, T *, int); -}; +class VTanhKernel : public VActKernel {}; template class LSTMKernel : public Kernel { diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 05af7432c5787db28f919858a4319f9e989f5038..e9e7eec445c13ce78bcf1b71e4e0cb9926e623a9 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -346,7 +346,6 @@ class VReluKernelImpl : public VReluKernel { public: JITKERNEL_DECLARE_STATIC_FUNC; explicit VReluKernelImpl(int d) : VReluKernel() { - this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 /* init size */ + @@ -361,9 +360,6 @@ class VReluKernelImpl : public VReluKernel { this->Compute = VReluRefer; } - void ComputeDeprecated(const T* x, T* y) const override { - VReluRefer(x, y, this->num_); - } #ifdef PADDLE_WITH_XBYAK private: @@ -378,22 +374,26 @@ bool VReluKernelImpl::useJIT(int d) { } #endif -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); -REGISTER_JITKERNEL(vscal, VScalKernel); -REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); -REGISTER_JITKERNEL(vrelu, VReluKernel); +template +inline void VIdentityRefer(const T* x, T* y, int n) {} /* An empty JitKernel */ -template +template class VIdentityKernelImpl : public VIdentityKernel { public: - explicit VIdentityKernelImpl(int d) : VIdentityKernel() { this->num_ = d; } - void ComputeDeprecated(const T* x, T* y) const override {} + JITKERNEL_DECLARE_STATIC_FUNC; + explicit VIdentityKernelImpl(int d) : VIdentityKernel() { + this->Compute = VIdentityRefer; + } }; -REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); +REGISTER_JITKERNEL(vrelu, VReluKernel); +REGISTER_JITKERNEL(videntity, VIdentityKernel); } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 28059ad270fd13ff6008464dd5914d2f53cbb223..0e2cdad4700ee8602f8ed8c3824f366a7ce32806 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -36,6 +36,7 @@ namespace jitkernel { namespace jit = platform::jit; // TODO(TJ): move refer codes to one file +// Refer code only focus on correctness template void VExpRefer(const T* x, T* y, int n) { for (int i = 0; i < n; ++i) { @@ -67,6 +68,7 @@ void VTanhRefer(const T* x, T* y, int n) { } #ifdef PADDLE_WITH_MKLML +// try to use MKL to speedup template void VExpMKL(const T* x, T* y, int n); @@ -112,7 +114,6 @@ class VExpKernelImpl : public VExpKernel { public: JITKERNEL_DECLARE_STATIC_FUNC; explicit VExpKernelImpl(int d) : VExpKernel() { - this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change @@ -130,9 +131,7 @@ class VExpKernelImpl : public VExpKernel { #endif this->Compute = VExpRefer; } - void ComputeDeprecated(const T* x, T* y) const override { - VExpRefer(x, y, this->num_); - } + #ifdef PADDLE_WITH_XBYAK private: @@ -166,7 +165,6 @@ class VSigmoidKernelImpl : public VSigmoidKernel { public: JITKERNEL_DECLARE_STATIC_FUNC; explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { - this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change @@ -186,9 +184,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { #endif this->Compute = VSigmoidRefer; } - void ComputeDeprecated(const T* x, T* y) const override { - VSigmoidRefer(x, y, this->num_); - } + #ifdef PADDLE_WITH_XBYAK private: @@ -221,7 +217,6 @@ class VTanhKernelImpl : public VTanhKernel { public: JITKERNEL_DECLARE_STATIC_FUNC; explicit VTanhKernelImpl(int d) : VTanhKernel() { - this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change @@ -241,9 +236,7 @@ class VTanhKernelImpl : public VTanhKernel { #endif this->Compute = VTanhRefer; } - void ComputeDeprecated(const T* x, T* y) const override { - VTanhRefer(x, y, this->num_); - } + #ifdef PADDLE_WITH_XBYAK private: diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index 926221f0a75c461e275a72f16b4339ae28a8e988..e79b0400ab75d1488a26450bd8cde4a0979fc995 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel { void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, T* checked) const override { // gates: W_ch, W_ih, W_fh, W_oh - act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); + act_gate_d3_->Compute(gates + d_, gates + d_, d3_); /* C_t = C_t-1 * fgated + cand_gated * igated */ - act_cand_d_->ComputeDeprecated(gates, gates); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); - act_cand_d_->ComputeDeprecated(gates, gates); + act_gate_d_->Compute(gates + d_, gates + d_, d_); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel { vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); - act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); + act_gate_d2_->Compute(gates + d_, gates + d_, d2_); /* C_t = C_t-1 * fgated + cand_gated * igated*/ - act_cand_d_->ComputeDeprecated(gates, gates); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* get ogated*/ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); - act_cand_d_->ComputeDeprecated(gates, gates); + act_gate_d_->Compute(gates + d_, gates + d_, d_); + act_cand_d_->Compute(gates, gates, d_); vmul_d_->Compute(gates, gates + d_, ct, d_); /* get outgated, put W_oc * C_t on igated */ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); - act_cell_d_->ComputeDeprecated(ct, gates + d2_); + act_gate_d_->Compute(gates + d3_, gates + d3_, d_); + act_cell_d_->Compute(ct, gates + d2_, d_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel { } void ComputeH1(T* gates, T* ht) const override { - act_gate_d_->ComputeDeprecated(gates, gates); - act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); + act_gate_d_->Compute(gates, gates, d_); + act_state_d_->Compute(gates + d2_, gates + d2_, d_); vmul_d_->Compute(gates, gates + d2_, ht, d_); } void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { // W: {W_update, W_reset; W_state} - act_gate_d2_->ComputeDeprecated(gates, gates); + act_gate_d2_->Compute(gates, gates, d2_); vmul_d_->Compute(ht_1, gates + d_, ht, d_); } void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { T* y = gates + d2_; - act_state_d_->ComputeDeprecated(y, y); + act_state_d_->Compute(y, y, d_); // out = zt*ht~ + (1-zt)*ht_1 for (int i = 0; i < d_; ++i) { ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 2f9dbc585efb47664803f2da30688bd8aa68300a..5a6f87fe1f7d10d65d03d78c168d61719cec772e 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -181,7 +181,7 @@ TEST(JitKernel, vexp) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - // ker->ComputeDeprecated(x_data, ztgt_data); + // ker->Compute(x_data, ztgt_data); ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -345,8 +345,8 @@ void lstm_ctht_ref( const std::shared_ptr< const paddle::operators::math::jitkernel::VExpKernel>& vexp_1, const int d, float* gates, const float* ct_1, float* ct, float* ht) { - vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); - vtanh_d->ComputeDeprecated(gates, gates); + vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); + vtanh_d->Compute(gates, gates, d); const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -356,7 +356,7 @@ void lstm_ctht_ref( // H_t = act_cell(C_t) * ogated float tmp = ct[k] * 2; tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vexp_1->ComputeDeprecated(&tmp, &tmp); + vexp_1->Compute(&tmp, &tmp, 1); tmp = 2.f / (1.f + tmp) - 1.f; ht[k] = tmp * o[k]; } @@ -374,13 +374,13 @@ void lstm_ctht_better( const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, const int d, float* gates, const float* ct_1, float* ct, float* ht) { int d2 = d * 2; - vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); - vtanh_d->ComputeDeprecated(gates, gates); + vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); + vtanh_d->Compute(gates, gates, d); vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vadd_d->Compute(gates + d, gates + d2, ct, d); /* H_t = act_cell(C_t) * ogated */ - vtanh_d->ComputeDeprecated(ct, gates + d2); + vtanh_d->Compute(ct, gates + d2, d); vmul_d->Compute(gates + d2, gates + d * 3, ht, d); } @@ -737,7 +737,7 @@ void vaddrelu_better( const paddle::operators::math::jitkernel::VReluKernel>& vrelu, const float* x, const float* y, float* z, int d) { vadd->Compute(x, y, z, d); - vrelu->ComputeDeprecated(z, z); + vrelu->Compute(z, z, d); } TEST(JitKernel, vaddrelu) {