diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index bdc9c1250e82718ce07e6eabfac62531d452b18b..e8b73bd83cd37581bb07799c685739cada81f74c 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -147,12 +147,12 @@ template class KernelImpl : public Kernel { // TODO(TJ): rename KernelImpl to KernelMore which seems only used in more // and add name interface for more implements easy for debug + public: using T = typename KernelTuples::data_type; using Func = typename KernelTuples::func_type; using Attr = typename KernelTuples::attr_type; - - public: virtual Func GetFunc() const { return func; } + // TODO(TJ): const &attr virtual bool UseMe(Attr attr) const = 0; protected: diff --git a/paddle/fluid/operators/jit/more/mix/CMakeLists.txt b/paddle/fluid/operators/jit/more/mix/CMakeLists.txt index 3b1c67a6db0ec3b619972c57dbe4ab28cecbb31c..56765f874a6586989f6e320ceb809a905ecd43a7 100644 --- a/paddle/fluid/operators/jit/more/mix/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mix/CMakeLists.txt @@ -7,3 +7,8 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE) USE_JITKERNEL_MORE(vsigmoid, mix) USE_JITKERNEL_MORE(vtanh, mix) +USE_JITKERNEL_MORE(lstmctht, mix) +USE_JITKERNEL_MORE(lstmc1h1, mix) +USE_JITKERNEL_MORE(gruh1, mix) +USE_JITKERNEL_MORE(gruhtpart1, mix) +USE_JITKERNEL_MORE(gruhtpart2, mix) diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index 708e22549ad02ba36962a06c10179b053b5b16e8..d8d5e30d0105b37049974a9d565121f8a17de953 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -23,7 +23,6 @@ namespace jit { namespace more { namespace mix { -template void VSigmoid(const T* x, T* y, int n) { const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -38,7 +37,6 @@ void VSigmoid(const T* x, T* y, int n) { } } -template void VTanh(const T* x, T* y, int n) { const T a = 2, b = -1; auto compute_scal = Get, platform::CPUPlace>(n); @@ -50,26 +48,151 @@ void VTanh(const T* x, T* y, int n) { compute_addbias(&b, y, y, n); } -template <> -bool VSigmoidKernel::UseMe(int d) const { - return true; +void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT + if (type == vsigmoid) { + return Get, platform::CPUPlace>(d); + } else if (type == vrelu) { + return Get, platform::CPUPlace>(d); + } else if (type == vtanh) { + return Get, platform::CPUPlace>(d); + } else if (type == videntity) { + return Get, platform::CPUPlace>(d); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; } -template <> -bool VTanhKernel::UseMe(int d) const { - return true; +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + const T* ct_1 = reinterpret_cast(step->ct_1); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + const T* wp = reinterpret_cast(step->wp); + T* checked = reinterpret_cast(step->checked); + const int d = attr->d; + const int d2 = d * 2; + const int d3 = d * 3; + auto vmul_d = Get, platform::CPUPlace>(d); + auto vadd_d = Get, platform::CPUPlace>(d); + auto vadd_d2 = Get, platform::CPUPlace>(d2); + auto act_gate_d = getActFunc(attr->act_gate, d); + auto act_gate_d2 = getActFunc(attr->act_gate, d2); + auto act_gate_d3 = getActFunc(attr->act_gate, d2); + auto act_cand_d = getActFunc(attr->act_cand, d); + auto act_cell_d = getActFunc(attr->act_cell, d); + + if (attr->use_peephole) { + vmul_d(wp, ct_1, checked, d); + vmul_d(wp + d, ct_1, checked + d, d); + vadd_d2(checked, gates + d, gates + d, d2); + act_gate_d2(gates + d, gates + d, d2); + } else { + act_gate_d3(gates + d, gates + d, d3); + } + + // C_t = C_t-1 * fgated + cand_gated * igated + act_cand_d(gates, gates, d); + vmul_d(gates, gates + d, gates + d, d); + vmul_d(ct_1, gates + d2, gates + d2, d); + vadd_d(gates + d, gates + d2, ct, d); + + if (attr->use_peephole) { + // get ogated + vmul_d(wp + d2, ct, gates + d, d); + vadd_d(gates + d, gates + d3, gates + d3, d); + act_gate_d(gates + d3, gates + d3, d); + } + // H_t = act_cell(C_t) * ogated + act_cell_d(ct, gates + d2, d); + vmul_d(gates + d2, gates + d3, ht, d); } -#define AWALYS_USE_ME_WITH_DOUBLE(func) \ - template <> \ - bool func##Kernel::UseMe(int d) const { \ - return true; \ +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + int d = attr->d; + int d2 = d * 2; + int d3 = d * 3; + auto vmul_d = Get, platform::CPUPlace>(d); + auto vadd_d = Get, platform::CPUPlace>(d); + auto act_gate_d = getActFunc(attr->act_gate, d); + auto act_cand_d = getActFunc(attr->act_cand, d); + auto act_cell_d = getActFunc(attr->act_cell, d); + /* C_t = igated * cgated*/ + act_gate_d(gates + d, gates + d, d); + act_cand_d(gates, gates, d); + vmul_d(gates, gates + d, ct, d); + if (attr->use_peephole) { + // get outgated, put W_oc * C_t on igated + const T* wp = reinterpret_cast(step->wp); + vmul_d(wp + d2, ct, gates + d, d); + vadd_d(gates + d, gates + d3, gates + d3, d); } + /* H_t = act_cell(C_t) * ogated */ + act_gate_d(gates + d3, gates + d3, d); + act_cell_d(ct, gates + d2, d); + vmul_d(gates + d2, gates + d3, ht, d); +} + +// compute h1 without h0 +void GRUH1(gru_t* step, const gru_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + int d = attr->d; + int d2 = d * 2; + auto act_gate = getActFunc(attr->act_gate, d); + auto act_cand = getActFunc(attr->act_cand, d); + auto vmul_d = Get, platform::CPUPlace>(d); + act_gate(gates, gates, d); + act_cand(gates + d2, gates + d2, d); + vmul_d(gates, gates + d2, ht, d); +} + +// compute the first part of GRU: ht = act_gate(r) * ht_1 +void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { + // W: {W_update, W_reset; W_state} + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + const T* ht_1 = reinterpret_cast(step->ht_1); + auto act_gate = getActFunc(attr->act_gate, attr->d); + auto vmul_d = Get, platform::CPUPlace>(attr->d); + act_gate(gates + attr->d, gates + attr->d, attr->d); + vmul_d(ht_1, gates + attr->d, ht, attr->d); +} + +// compute the second part of GRU: +// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 +void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + const T* ht_1 = reinterpret_cast(step->ht_1); + int d = attr->d; + auto act_gate = getActFunc(attr->act_gate, d); + auto act_cand = getActFunc(attr->act_cand, d); + T* y = gates + d * 2; + act_gate(gates, gates, d); + act_cand(y, y, d); + // out = zt*ht~ + (1-zt)*ht_1 + for (int i = 0; i < d; ++i) { + ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; + } +} + +// TODO(TJ): tuning me +bool VSigmoidKernel::UseMe(int d) const { return true; } + +bool VTanhKernel::UseMe(int d) const { return true; } + +bool LSTMCtHtKernel::UseMe(lstm_attr_t attr) const { return true; } + +bool LSTMC1H1Kernel::UseMe(lstm_attr_t attr) const { return true; } + +bool GRUH1Kernel::UseMe(gru_attr_t attr) const { return true; } -AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); -AWALYS_USE_ME_WITH_DOUBLE(VTanh); +bool GRUHtPart1Kernel::UseMe(gru_attr_t attr) const { return true; } -#undef AWALYS_USE_ME_WITH_DOUBLE +bool GRUHtPart2Kernel::UseMe(gru_attr_t attr) const { return true; } } // namespace mix } // namespace more @@ -79,11 +202,15 @@ AWALYS_USE_ME_WITH_DOUBLE(VTanh); namespace mix = paddle::operators::jit::more::mix; -#define REGISTER_MORE_KERNEL(key, func) \ - REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel, \ - mix::func##Kernel) +#define REGISTER_MORE_KERNEL(key, func) \ + REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) REGISTER_MORE_KERNEL(vsigmoid, VSigmoid); REGISTER_MORE_KERNEL(vtanh, VTanh); +REGISTER_MORE_KERNEL(lstmctht, LSTMCtHt); +REGISTER_MORE_KERNEL(lstmc1h1, LSTMC1H1); +REGISTER_MORE_KERNEL(gruh1, GRUH1); +REGISTER_MORE_KERNEL(gruhtpart1, GRUHtPart1); +REGISTER_MORE_KERNEL(gruhtpart2, GRUHtPart2); #undef REGISTER_MORE_KERNEL diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h index 38b738a8b1125264d099d4131cdc59b7809d443e..85c8fd4c321fe8d22eb8f288c3054069ed44d5f2 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.h +++ b/paddle/fluid/operators/jit/more/mix/mix.h @@ -22,18 +22,21 @@ namespace operators { namespace jit { namespace more { namespace mix { +using T = float; -template void VSigmoid(const T* x, T* y, int n); - -template void VTanh(const T* x, T* y, int n); +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr); +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr); +void GRUH1(gru_t* step, const gru_attr_t* attr); +void GRUHtPart1(gru_t* step, const gru_attr_t* attr); +void GRUHtPart2(gru_t* step, const gru_attr_t* attr); + #define DECLARE_MORE_KERNEL(name, tuples) \ - template \ class name##Kernel : public KernelImpl> { \ public: \ - name##Kernel() { this->func = name; } \ + name##Kernel() { this->func = name; } \ bool UseMe(typename tuples::attr_type) const override; \ } @@ -41,6 +44,13 @@ void VTanh(const T* x, T* y, int n); DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); DECLARE_MORE_KERNEL(VTanh, XYNTuples); +DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); +DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); + +DECLARE_MORE_KERNEL(GRUH1, GRUTuples); +DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); +DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); + #undef DECLARE_MORE_KERNEL } // namespace mix