Commit a53b1b0b authored by tensor-tang

refine and init jitkernel vmul

Parent 2139b9f6
@@ -77,5 +77,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
     SRCS jit_kernel.cc jit_gen.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
-    DEPS cpu_info cblas gflags)
+    DEPS cpu_info cblas gflags enforce)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
@@ -39,8 +39,8 @@ class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  // TODO(TJ): below members should be deprecated.
   int num_{0};
-  // TODO(TJ): below two should be reomved.
   int end_{0};
   int rest_{0};
   DISABLE_COPY_AND_ASSIGN(Kernel);
@@ -65,7 +65,7 @@ class KernelPool {
 template <typename T>
 class VMulKernel : public Kernel {
  public:
-  virtual void Compute(const T *x, const T *y, T *z) const = 0;
+  void (*Compute)(const T *, const T *, T *, int);
 };
 
 template <typename T>
...
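With Compute now a plain function pointer, callers pass the vector length explicitly instead of relying on the kernel's stored num_. A minimal caller-side sketch (the wrapper function name is illustrative; the pool accessors are the ones used in the updated test further down):

#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jitkernel = paddle::operators::math::jitkernel;

// Multiply two float vectors of length d with whatever backend the pool picks.
void VMulExample(const float* x, const float* y, float* z, int d) {
  const auto& ker =
      jitkernel::KernelPool::Instance().Get<jitkernel::VMulKernel<float>>(d);
  ker->Compute(x, y, z, d);  // the length is now an explicit argument
}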
@@ -14,7 +14,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/jit_kernel.h"
 #include <string>
+#include "paddle/fluid/operators/math/jit_gen.h"
 #include "paddle/fluid/operators/math/jit_kernel_macro.h"
+#include "paddle/fluid/platform/enforce.h"
 #ifdef PADDLE_WITH_MKLML
 #include "paddle/fluid/platform/dynload/mklml.h"
 #endif
@@ -28,64 +31,97 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
+namespace jit = platform::jit;  // remove me
+using namespace platform::jit;  // NOLINT
 
 /* VMUL JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
-class VMulKernelImpl : public VMulKernel<T> {
- public:
-  explicit VMulKernelImpl(int d) : VMulKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      z[i] = x[i] * y[i];
-    }
-  }
-};
-
-#ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                           \
-  template <>                                           \
-  void VMulKernelImpl<float, isa, block>::Compute(      \
-      const float* x, const float* y, float* z) const { \
-    platform::dynload::vsMul(this->num_, x, y, z);      \
-  }
-
-#define MKL_DOUBLE(isa, block)                             \
-  template <>                                              \
-  void VMulKernelImpl<double, isa, block>::Compute(        \
-      const double* x, const double* y, double* z) const { \
-    platform::dynload::vdMul(this->num_, x, y, z);         \
-  }
-
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
-#endif
-
-#define INTRI8_FLOAT(isa)                               \
-  template <>                                           \
-  void VMulKernelImpl<float, isa, kEQ8>::Compute(       \
-      const float* x, const float* y, float* z) const { \
-    __m256 tmpx, tmpy;                                  \
-    tmpx = _mm256_loadu_ps(x);                          \
-    tmpy = _mm256_loadu_ps(y);                          \
-    tmpx = _mm256_mul_ps(tmpx, tmpy);                   \
-    _mm256_storeu_ps(z, tmpx);                          \
-  }
-
-// avx > for > mkl
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-#endif
-// TODO(TJ): eq16 test and complete avx512
-#undef INTRI8_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+struct VMulJitCode : public gen::JitCode {
+  DECLARE_JIT_CODE(VMulJitCode);
+  explicit VMulJitCode(size_t code_size = 256 * 1024, void* code_ptr = nullptr)
+      : gen::JitCode(code_size, code_ptr) {}
+  static bool init(int d) {
+    if (MayIUse(avx) || MayIUse(avx2)) {
+      return d % AVX_FLOAT_BLOCK == 0;
+    } else if (MayIUse(avx512f)) {
+      return d % AVX512_FLOAT_BLOCK == 0;
+    } else {
+      return false;
+    }
+  }
+  void generate() override {
+    preCode();
+    postCode();
+  }
+};
+
+template <typename T>
+void VMulRefer(const T* x, const T* y, T* z, int n) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
+}
+
+#ifdef PADDLE_WITH_MKLML
+template <typename T>
+void VMulMKL(const T* x, const T* y, T* z, int n);
+
+template <>
+void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
+  platform::dynload::vsMul(n, x, y, z);
+}
+template <>
+void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
+  platform::dynload::vdMul(n, x, y, z);
+}
+#endif
+
+template <typename T>
+class VMulKernelImpl : public VMulKernel<T> {
+ public:
+  static inline std::string name(int d) {
+    PADDLE_THROW("DType should be either float or double");
+  }
+  static inline bool useJIT(int d) { return false; }
+  static inline bool useMKL(int d) { return false; }
+
+  explicit VMulKernelImpl(int d) : VMulKernel<T>() {
+    if (useJIT(d)) {
+      constexpr size_t sz = 256 * 1024;  // TODO(TJ): should be related with d
+      jitcode_.reset(new VMulJitCode(sz));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
+    }
+#ifdef PADDLE_WITH_MKLML
+    if (useMKL(d)) {
+      this->Compute = VMulMKL<T>;
+      return;
+    }
+#endif
+    this->Compute = VMulRefer<T>;
+  }
+
+ private:
+  std::unique_ptr<VMulJitCode> jitcode_{nullptr};
+};
+
+template <>
+bool VMulKernelImpl<float>::useJIT(int d) {
+  return VMulJitCode::init(d);
+}
+
+template <>
+bool VMulKernelImpl<float>::useMKL(int d) {
+  return jit::MayIUse(jit::avx512f) && d > 512;
+}
+
+template <>
+bool VMulKernelImpl<double>::useMKL(int d) {
+  return true;
+}
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
 
 /* VADD JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
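The constructor above selects one of three backends in a fixed order: JIT-generated code when VMulJitCode::init(d) accepts the size, MKL when useMKL(d) holds, and the plain VMulRefer loop otherwise. A worked example of that selection (a sketch only; it assumes an AVX-capable CPU, AVX_FLOAT_BLOCK == 8, and an MKL-enabled build):

// float:
//   d = 8    -> useJIT: 8 % AVX_FLOAT_BLOCK == 0   -> Compute = JIT code
//   d = 7    -> useJIT false, useMKL false          -> Compute = VMulRefer<float>
//   d = 1000 -> 1000 % AVX_FLOAT_BLOCK == 0         -> still the JIT path; the MKL
//               branch (avx512f && d > 512) only fires when the JIT path is rejected
// double: useJIT is never specialized, so Compute is VMulMKL<double> when MKL
//         is built in and VMulRefer<double> otherwise.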
@@ -465,13 +501,12 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
 #undef INTRI16_FLOAT
 #undef INTRI_COMMON_FLOAT
 
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vscal, VScalKernel);
-REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
-REGISTER_JITKERNEL(vrelu, VReluKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-REGISTER_JITKERNEL(videntity, VIdentityKernel);
+REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel);
+REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
+REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
+REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
+REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel);
+REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
 
 }  // namespace jitkernel
 }  // namespace math
...
@@ -288,7 +288,7 @@ INTRIAVX512_FLOAT(kGT16);
 #undef INIT_ALPHA
 #undef UPDATE_ALPHA
 
-REGISTER_JITKERNEL(crf_decode, CRFDecodeKernel);
+REGISTER_JITKERNEL_DEPRECATED(crf_decode, CRFDecodeKernel);
 
 }  // namespace jitkernel
 }  // namespace math
...
@@ -250,7 +250,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
 #undef MKL_FLOAT
 #undef MKL_DOUBLE
 
-REGISTER_JITKERNEL(vexp, VExpKernel);
+REGISTER_JITKERNEL_DEPRECATED(vexp, VExpKernel);
 
 /* VSigmoid JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>
@@ -396,7 +396,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
 #undef INTRI_GT16_FLOAT
 #undef INTRI_VSIGMOID
 
-REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
+REGISTER_JITKERNEL_DEPRECATED(vsigmoid, VSigmoidKernel);
 
 /* VTanh JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>
@@ -531,7 +531,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
 #undef INTRI_GT16_FLOAT
 #undef INTRI_VTANH
 
-REGISTER_JITKERNEL(vtanh, VTanhKernel);
+REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);
 
 #undef JITKERNEL_NEW_ACT_IMPL
...
@@ -21,8 +21,71 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
+#define JITKERNEL_DEFINE_NAME(ker_key, ker_class)      \
+  template <>                                          \
+  std::string ker_class##Impl<float>::name(int d) {    \
+    std::string key(#ker_key "f");                     \
+    if (useJIT(d)) {                                   \
+      /* only jit code need record d*/                 \
+      return key + "jit" + std::to_string(d);          \
+    } else if (useMKL(d)) {                            \
+      return key + "mkl";                              \
+    } else {                                           \
+      return key + "any";                              \
+    }                                                  \
+  }                                                    \
+  template <>                                          \
+  std::string ker_class##Impl<double>::name(int d) {   \
+    std::string key(#ker_key "d");                     \
+    /* jit code do not support double yet*/            \
+    if (useMKL(d)) {                                   \
+      return key + "mkl";                              \
+    } else {                                           \
+      return key + "any";                              \
+    }                                                  \
+  }
+
+#define JITKERNEL_DECLARE(ker_class, ker_dtype)  \
+  template <>                                    \
+  std::shared_ptr<const ker_class<ker_dtype>>    \
+  KernelPool::Get<ker_class<ker_dtype>, int>(int d)
+
+#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
+  std::string key = ker_class##Impl<ker_dtype>::name(d)
+
+#define JITKERNEL_IMPL(ker_class, ker_dtype)           \
+  p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
+      std::make_shared<ker_class##Impl<ker_dtype>>(d))
+
+#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
+                                      macro_find_key, macro_impl)          \
+  marco_declare(ker_class, ker_dtype) {                                    \
+    macro_find_key(ker_class, ker_dtype);                                  \
+    if (kers_.find(key) == kers_.end()) {                                  \
+      std::shared_ptr<ker_class<ker_dtype>> p;                             \
+      macro_impl(ker_class, ker_dtype);                                    \
+      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});           \
+      return p;                                                            \
+    }                                                                      \
+    return std::dynamic_pointer_cast<const ker_class<ker_dtype>>(          \
+        kers_.at(key));                                                    \
+  }
+
+#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name,     \
+                                marco_declare, macro_find_key, macro_impl) \
+  marco_define_name(ker_key, ker_class);                                   \
+  REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE,       \
+                                JITKERNEL_FIND_KEY, JITKERNEL_IMPL);       \
+  REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE,      \
+                                JITKERNEL_FIND_KEY, JITKERNEL_IMPL)
+
+#define REGISTER_JITKERNEL(ker_key, ker_class)                       \
+  REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
+                          JITKERNEL_DECLARE, JITKERNEL_FIND_KEY,     \
+                          JITKERNEL_IMPL)
+
+namespace jit = platform::jit;
+// TODO(TJ): below defines are deprecated, would be remove recently
 #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
   if (d < AVX_FLOAT_BLOCK) {                  \
     macro_(ker, dtype, isa, kLT8);            \
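JITKERNEL_DEFINE_NAME (introduced above) generates the name() used as the pool's cache key. For the vmul kernel, for instance, the keys come out roughly as follows (illustrative; which branch is taken depends on useJIT/useMKL at runtime):

// float:  "vmulf" + "jit" + d  -> e.g. "vmulfjit8"  when the JIT path is usable
//         "vmulf" + "mkl"      -> "vmulfmkl"        when only MKL applies
//         "vmulf" + "any"      -> "vmulfany"        for the reference fallback
// double: only "vmuldmkl" or "vmuldany" -- JIT code does not support double yet,
//         so the size is never encoded in a double key.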
@@ -47,20 +110,16 @@ namespace jit = platform::jit;
     SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
   }
 
-#define JITKERNEL_DECLARE(ker_class, ker_dtype)  \
-  template <>                                    \
-  std::shared_ptr<const ker_class<ker_dtype>>    \
-  KernelPool::Get<ker_class<ker_dtype>, int>(int d)
-
 #define JITKERNEL_KEY(ker_key, dtype_key) \
   #ker_key #dtype_key + std::to_string(d)
 
-#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k)            \
+#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \
   p = std::dynamic_pointer_cast<ker<dtype>>(              \
       std::make_shared<ker##Impl<dtype, isa, k>>(d))
 
-#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key,       \
-                             marco_declare, macro_key, macro_impl)           \
+#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype,       \
+                                        dtype_key, marco_declare, macro_key, \
+                                        macro_impl)                          \
   marco_declare(ker_class, ker_dtype) {                                      \
     std::string key = macro_key(ker_key, dtype_key);                         \
     if (kers_.find(key) == kers_.end()) {                                    \
@@ -73,18 +132,20 @@ namespace jit = platform::jit;
         kers_.at(key));                                                      \
   }
 
-#define REGISTER_JITKERNEL(ker_key, ker_class)                               \
-  JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE,      \
-                       JITKERNEL_KEY, JITKERNEL_NEW_IMPL);                   \
-  JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE,     \
-                       JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
-
-#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
-                                macro_impl)                                   \
-  JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
-                       macro_impl);                                           \
-  JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare,          \
-                       macro_key, macro_impl)
+#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class)           \
+  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f,     \
+                                  JITKERNEL_DECLARE, JITKERNEL_KEY, \
+                                  JITKERNEL_NEW_IMPL_DEPRECATED);   \
+  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d,    \
+                                  JITKERNEL_DECLARE, JITKERNEL_KEY, \
+                                  JITKERNEL_NEW_IMPL_DEPRECATED)
+
+#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare,  \
+                                           macro_key, macro_impl)              \
+  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
+                                  macro_key, macro_impl);                      \
+  JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d,               \
+                                  marco_declare, macro_key, macro_impl)
 
 #define FOR_EACH_ISA(macro_, block) \
   macro_(jit::avx512f, block);      \
...
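REGISTER_JITKERNEL_WITH_DTYPE ties name(), the declaration, and the impl together inside KernelPool::Get: on a cache miss it builds a new impl and stores it under the computed key, and any later call resolving to the same key returns the cached instance. A small sketch of the resulting behavior (d == 4 is an assumed size for which neither JIT nor MKL applies, so the key is "vmulfany"):

namespace jitkernel = paddle::operators::math::jitkernel;

// First call creates a VMulKernelImpl<float> and caches it under "vmulfany".
const auto& k1 =
    jitkernel::KernelPool::Instance().Get<jitkernel::VMulKernel<float>>(4);
// A second call resolving to the same key is a cache hit.
const auto& k2 =
    jitkernel::KernelPool::Instance().Get<jitkernel::VMulKernel<float>>(4);
// k1 == k2; the key-based lookup in the pool test below relies on exactly this.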
@@ -179,23 +179,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     /* C_t = C_t-1 * fgated + cand_gated * igated */
     act_cand_d_->Compute(gates, gates);
-    vmul_d_->Compute(gates, gates + d_, gates + d_);
-    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
+    vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
+    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
     vadd_d_->Compute(gates + d_, gates + d2_, ct);
 
     /* H_t = act_cell(C_t) * ogated */
     act_cell_d_->Compute(ct, gates + d2_);
-    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
   void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
     /* C_t = igated * cgated*/
     act_gate_d_->Compute(gates + d_, gates + d_);
     act_cand_d_->Compute(gates, gates);
-    vmul_d_->Compute(gates, gates + d_, ct);
+    vmul_d_->Compute(gates, gates + d_, ct, d_);
 
     /* H_t = act_cell(C_t) * ogated */
     act_gate_d_->Compute(gates + d3_, gates + d3_);
     act_cell_d_->Compute(ct, gates + d2_);
-    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
 
  private:
@@ -289,36 +289,36 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
   void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
                    T* checked) const override {
     /* get fgated and igated*/
-    vmul_d_->Compute(wp_data, ct_1, checked);
-    vmul_d_->Compute(wp_data + d_, ct_1, checked + d_);
+    vmul_d_->Compute(wp_data, ct_1, checked, d_);
+    vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
     vadd_d2_->Compute(checked, gates + d_, gates + d_);
     act_gate_d2_->Compute(gates + d_, gates + d_);
 
     /* C_t = C_t-1 * fgated + cand_gated * igated*/
     act_cand_d_->Compute(gates, gates);
-    vmul_d_->Compute(gates, gates + d_, gates + d_);
-    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
+    vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
+    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
     vadd_d_->Compute(gates + d_, gates + d2_, ct);
 
     /* get ogated*/
-    vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
+    vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
     vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
     act_gate_d_->Compute(gates + d3_, gates + d3_);
 
     /* H_t = act_cell(C_t) * ogated */
     act_cell_d_->Compute(ct, gates + d2_);
-    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
   void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
     /* C_t = igated * cgated*/
     act_gate_d_->Compute(gates + d_, gates + d_);
     act_cand_d_->Compute(gates, gates);
-    vmul_d_->Compute(gates, gates + d_, ct);
+    vmul_d_->Compute(gates, gates + d_, ct, d_);
 
     /* get outgated, put W_oc * C_t on igated */
-    vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
+    vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
     vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
 
     /* H_t = act_cell(C_t) * ogated */
     act_gate_d_->Compute(gates + d3_, gates + d3_);
     act_cell_d_->Compute(ct, gates + d2_);
-    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
   }
 
  private:
@@ -352,7 +352,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
                                          act_cell, d)); \
   }
 
-REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
-                        JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
+REGISTER_JITKERNEL_ARGS_DEPRECATED(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
+                                   JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
 
 #undef INTRI8_FLOAT
@@ -378,13 +378,13 @@ class GRUKernelImpl : public GRUKernel<T> {
   void ComputeH1(T* gates, T* ht) const override {
     act_gate_d_->Compute(gates, gates);
     act_state_d_->Compute(gates + d2_, gates + d2_);
-    vmul_d_->Compute(gates, gates + d2_, ht);
+    vmul_d_->Compute(gates, gates + d2_, ht, d_);
   }
   void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
     // W: {W_update, W_reset; W_state}
     act_gate_d2_->Compute(gates, gates);
-    vmul_d_->Compute(ht_1, gates + d_, ht);
+    vmul_d_->Compute(ht_1, gates + d_, ht, d_);
   }
   void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
@@ -472,7 +472,7 @@ INTRI8_FLOAT(jit::avx512f);
   p = std::dynamic_pointer_cast<ker<dtype>>( \
       std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
 
-REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
-                        JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
+REGISTER_JITKERNEL_ARGS_DEPRECATED(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
+                                   JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
 
 #undef INTRI8_FLOAT
...
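The RNN and test hunks are the mechanical fallout of the interface change: every vmul call site now passes the length explicitly, since the kernel no longer carries it in num_. A representative before/after pair taken from the LSTM code above:

// before: the length was implicit in the kernel instance
vmul_d_->Compute(gates, gates + d_, gates + d_);
// after:  the length d_ is an explicit argument
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);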
@@ -369,12 +369,12 @@ void lstm_ctht_better(
   int d2 = d * 2;
   vsigmoid_3d->Compute(gates + d, gates + d);
   vtanh_d->Compute(gates, gates);
-  vmul_d->Compute(gates, gates + d, gates + d);
-  vmul_d->Compute(ct_1, gates + d2, gates + d2);
+  vmul_d->Compute(gates, gates + d, gates + d, d);
+  vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
   vadd_d->Compute(gates + d, gates + d2, ct);
 
   /* H_t = act_cell(C_t) * ogated */
   vtanh_d->Compute(ct, gates + d2);
-  vmul_d->Compute(gates + d2, gates + d * 3, ht);
+  vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
 }
 
 TEST(JitKernel, lstm) {
@@ -578,7 +578,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
 TEST(JitKernel, vmul) {
   namespace jit = paddle::operators::math::jitkernel;
-  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
+  for (int d : {7, 8, 15, 16, 30, 256, 512, 1000, 1024}) {
     std::vector<float> x(d), y(d);
     std::vector<float> zref(d), ztgt(d);
     RandomVec<float>(d, x.data());
@@ -616,7 +616,7 @@ TEST(JitKernel, vmul) {
     auto ttgts = GetCurrentUS();
     for (int i = 0; i < repeat; ++i) {
-      ker->Compute(x_data, y_data, ztgt_data);
+      ker->Compute(x_data, y_data, ztgt_data, d);
     }
     auto ttgte = GetCurrentUS();
@@ -800,8 +800,8 @@ TEST(JitKernel, pool) {
   EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f) !=
               std::dynamic_pointer_cast<const jit::Kernel>(pvmul_d));
 
-  const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4");
+  const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfany");
   EXPECT_EQ(pvmul_f, pvmul_from_key);
-  const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5");
+  const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit");
   EXPECT_TRUE(pvmul_from_key2 == nullptr);
 }