From bf9302f95015db6cadf3e814cfc4f21ef8434a3d Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 13 Dec 2018 10:18:22 +0000 Subject: [PATCH] add lstm, peephole refer and test --- paddle/fluid/operators/jit/gen_base.cc | 5 - paddle/fluid/operators/jit/gen_base.h | 4 - paddle/fluid/operators/jit/helper.cc | 20 +++ paddle/fluid/operators/jit/helper.h | 4 +- paddle/fluid/operators/jit/kernel_base.h | 54 ++++++- paddle/fluid/operators/jit/kernel_key.cc | 38 +++++ paddle/fluid/operators/jit/kernel_key.h | 4 + .../fluid/operators/jit/refer/CMakeLists.txt | 2 + paddle/fluid/operators/jit/refer/refer.cc | 3 + paddle/fluid/operators/jit/refer/refer.h | 89 ++++++++++++ paddle/fluid/operators/jit/test.cc | 137 ++++++++++++++++++ paddle/fluid/operators/math/jit_kernel_impl.h | 39 ----- .../fluid/operators/math/jit_kernel_refer.h | 85 ----------- 13 files changed, 346 insertions(+), 138 deletions(-) create mode 100644 paddle/fluid/operators/jit/kernel_key.cc diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc index a8bf90296..310da0c76 100644 --- a/paddle/fluid/operators/jit/gen_base.cc +++ b/paddle/fluid/operators/jit/gen_base.cc @@ -23,11 +23,6 @@ namespace paddle { namespace operators { namespace jit { -template <> -size_t JitCodeKey(int d) { - return d; -} - // refer do not need useme, it would be the last one. void GenBase::dumpCode(const unsigned char* code) const { if (code) { diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h index 586f4389c..48855abd2 100644 --- a/paddle/fluid/operators/jit/gen_base.h +++ b/paddle/fluid/operators/jit/gen_base.h @@ -43,10 +43,6 @@ class GenBase : public Kernel { void dumpCode(const unsigned char* code) const; }; -// Every JitCode should have a method to get the key from attribution -template -size_t JitCodeKey(Attr attr); - // Creator is used to creat the jitcode and save in pool. // Every JitCode should have one creator. 
class GenCreator { diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index c010b64c9..d6fa4891e 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/helper.h" +#include // tolower #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -36,6 +37,8 @@ const char* to_string(KernelType kt) { ONE_CASE(vexp); ONE_CASE(vsigmoid); ONE_CASE(vtanh); + ONE_CASE(lstmctht); + ONE_CASE(lstmc1h1); default: PADDLE_THROW("Not support type: %d", kt); return "NOT JITKernel"; @@ -44,6 +47,23 @@ const char* to_string(KernelType kt) { } #undef ONE_CASE +KernelType to_kerneltype(const std::string& act) { + std::string lower = act; + std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + if (lower == "relu" || lower == "vrelu") { + return vrelu; + } else if (lower == "identity" || lower == "videntity" || lower == "") { + return videntity; + } else if (lower == "exp" || lower == "vexp") { + return vexp; + } else if (lower == "sigmoid" || lower == "vsigmoid") { + return vsigmoid; + } else if (lower == "tanh" || lower == "vtanh") { + return vtanh; + } + return non_kernel; +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 053e5ed07..302e70caa 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -14,9 +14,7 @@ #pragma once -#include // for unique_ptr #include -#include #include #include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/kernel_base.h" @@ -124,6 +122,8 @@ typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) { const char* to_string(KernelType kt); +KernelType to_kerneltype(const std::string& act); + } // namespace jit } // namespace operators } // namespace paddle diff --git 
a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 29b881b75..3ab0194ce 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -20,8 +20,9 @@ namespace operators { namespace jit { typedef enum { - vmul = 0, - vadd = 1, + non_kernel = 0, + vmul = 1, + vadd = 2, vaddrelu, vsub, vscal, @@ -30,7 +31,9 @@ typedef enum { videntity, vexp, vsigmoid, - vtanh + vtanh, + lstmctht, + lstmc1h1 } KernelType; template @@ -50,6 +53,51 @@ struct XYNTuples { typedef void (*func_type)(const T*, T*, int); }; +typedef struct { + void* gates; // gates: x_ch, x_ih, x_fh, x_oh + const void* ct_1; + void* ct; + void* ht; + /* weight_peephole and checked data are only used in peephole*/ + const void* wp{nullptr}; // W_ic, W_fc, W_oc + void* checked{nullptr}; // size: 2 * d +} lstm_t; + +typedef struct { + void* gates; // gates: {x_update, x_reset; x_state} + const void* ht_1; + void* ht; +} gru_t; + +struct rnn_attr_s { + int d; + KernelType act_gate, act_cand; + rnn_attr_s() = default; + rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand) + : d(_d), act_gate(_act_gate), act_cand(_act_cand) {} +}; + +struct lstm_attr_s : public rnn_attr_s { + bool use_peephole; + KernelType act_cell; + lstm_attr_s() = default; + lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand, + KernelType _act_cell, bool _use_peephole = false) + : rnn_attr_s(_d, _act_gate, _act_cand), + use_peephole(_use_peephole), + act_cell(_act_cell) {} +}; + +typedef struct rnn_attr_s gru_attr_t; +typedef struct lstm_attr_s lstm_attr_t; + +template +struct LSTMTuples { + typedef T data_type; + typedef lstm_attr_t attr_type; + typedef void (*func_type)(lstm_t*, const lstm_attr_t*); +}; + // Just for adding to kernel pool without template class Kernel { public: diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc new file mode 100644 index 000000000..7a9ae81f8 --- /dev/null +++ 
b/paddle/fluid/operators/jit/kernel_key.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/kernel_key.h" + +namespace paddle { +namespace operators { +namespace jit { + +template <> +size_t JitCodeKey(const int& d) { + return d; +} + +template <> +size_t JitCodeKey(const lstm_attr_t& attr) { + constexpr int act_type_shift = 3; // support 2^3 act types + size_t key = attr.d; + int gate_key = static_cast(attr.act_gate) << 1; + int cand_key = static_cast(attr.act_cand) << (1 + act_type_shift); + int cell_key = static_cast(attr.act_cell) << (1 + act_type_shift * 2); + return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + + attr.use_peephole; +} +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_key.h b/paddle/fluid/operators/jit/kernel_key.h index af9df7733..611a0210d 100644 --- a/paddle/fluid/operators/jit/kernel_key.h +++ b/paddle/fluid/operators/jit/kernel_key.h @@ -44,6 +44,10 @@ struct KernelKey { bool operator!=(const KernelKey& o) const { return !(*this == o); } }; +// Every JitCode should have a method to get the key from its attributes +template +size_t JitCodeKey(const Attr& attr); + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt
b/paddle/fluid/operators/jit/refer/CMakeLists.txt index dc07ddb91..e30923c4f 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -18,3 +18,5 @@ USE_JITKERNEL_REFER(videntity) USE_JITKERNEL_REFER(vexp) USE_JITKERNEL_REFER(vsigmoid) USE_JITKERNEL_REFER(vtanh) +USE_JITKERNEL_REFER(lstmctht) +USE_JITKERNEL_REFER(lstmc1h1) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index f716ca89c..59b3ce524 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -35,4 +35,7 @@ REGISTER_REFER_KERNEL(vexp, VExp); REGISTER_REFER_KERNEL(vsigmoid, VSigmoid); REGISTER_REFER_KERNEL(vtanh, VTanh); +REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt); +REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 7ef60a2d5..a93123df9 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -110,6 +110,91 @@ void VTanh(const T* x, T* y, int n) { } } +template +void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT + if (type == vsigmoid) { + return VSigmoid; + } else if (type == vrelu) { + return VRelu; + } else if (type == vtanh) { + return VTanh; + } else if (type == videntity) { + return VIdentity; + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} + +// compute ct and ht +template +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + const T* ct_1 = reinterpret_cast(step->ct_1); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + const T* wp = reinterpret_cast(step->wp); + T* checked = reinterpret_cast(step->checked); + auto act_gate = getActFunc(attr->act_gate); + auto act_cand = getActFunc(attr->act_cand); + auto act_cell = getActFunc(attr->act_cell); + int d = 
attr->d; + int d2 = d * 2; + int d3 = d * 3; + // gates: W_ch, W_ih, W_fh, W_oh + if (attr->use_peephole) { + VMul(wp, ct_1, checked, d); + VMul(wp + d, ct_1, checked + d, d); + VAdd(checked, gates + d, gates + d, d2); + act_gate(gates + d, gates + d, d2); + } else { + act_gate(gates + d, gates + d, d3); + } + + // C_t = C_t-1 * fgated + cand_gated * igated + act_cand(gates, gates, d); + VMul(gates, gates + d, gates + d, d); + VMul(ct_1, gates + d2, gates + d2, d); + VAdd(gates + d, gates + d2, ct, d); + + if (attr->use_peephole) { + // get ogated + VMul(wp + d2, ct, gates + d, d); + VAdd(gates + d, gates + d3, gates + d3, d); + act_gate(gates + d3, gates + d3, d); + } + // H_t = act_cell(C_t) * ogated + act_cell(ct, gates + d2, d); + VMul(gates + d2, gates + d3, ht, d); +} + +// compute c1 and h1 without c0 or h0 +template +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + auto act_gate = getActFunc(attr->act_gate); + auto act_cand = getActFunc(attr->act_cand); + auto act_cell = getActFunc(attr->act_cell); + int d = attr->d; + int d2 = d * 2; + int d3 = d * 3; + /* C_t = igated * cgated*/ + act_gate(gates + d, gates + d, d); + act_cand(gates, gates, d); + VMul(gates, gates + d, ct, d); + if (attr->use_peephole) { + // get outgated, put W_oc * C_t on igated + const T* wp = reinterpret_cast(step->wp); + VMul(wp + d2, ct, gates + d, d); + VAdd(gates + d, gates + d3, gates + d3, d); + } + /* H_t = act_cell(C_t) * ogated */ + act_gate(gates + d3, gates + d3, d); + act_cell(ct, gates + d2, d); + VMul(gates + d2, gates + d3, ht, d); +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -134,6 +219,10 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VTanh, XYNTuples); +// lstm_t* , const lstm_attr_t* +DECLARE_REFER_KERNEL(LSTMCtHt, 
LSTMTuples); +DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 4c9b853b6..03e56416b 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -350,6 +350,143 @@ TEST(JITKernel, vtanh) { TestXYNKernel(); } +template +void TestLSTMFunc(const typename KernelTuples::func_type tgt, + const std::vector& xsrc, const std::vector& wp, + const std::vector& ct_1, const std::vector& ct_ref, + const std::vector& ht_ref, + const paddle::operators::jit::lstm_attr_t& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ct_ref.size(), ht_ref.size()); + EXPECT_EQ(ct_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); + EXPECT_EQ(wp.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); + std::vector checked(2 * d); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + const T* ct_ref_data = ct_ref.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + T* checked_data = checked.data(); + + paddle::operators::jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (attr.use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + + tgt(&step, &attr); + ExpectEQ(ct_data, ct_ref_data, d); + ExpectEQ(ht_data, ht_ref_data, d); +} + +template +void TestLSTMKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; + for (int d : TestSizes()) { + for (bool use_peephole : {true, false}) { + for (auto& act_gate : all_acts) { + for (auto& act_cand : all_acts) { 
+ for (auto& act_cell : all_acts) { + std::string info = act_gate + act_cand + act_cell + + (use_peephole ? "peephole_" : "") + "size_" + + std::to_string(d); + const jit::lstm_attr_t attr( + d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand), + jit::to_kerneltype(act_cell), use_peephole); + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector xsrc(4 * d), wp(3 * d), ct_1(d); + std::vector ct_ref(d), ht_ref(d), checked(2 * d); + RandomVec(4 * d, xsrc.data(), -2.f, 2.f); + RandomVec(3 * d, wp.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + // x could be changed after compute, so copy to save src + std::vector x(xsrc.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + T* x_data = x.data(); + T* checked_data = checked.data(); + T* ct_ref_data = ct_ref.data(); + T* ht_ref_data = ht_ref.data(); + jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_ref_data; + step.ht = ht_ref_data; + if (use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + ref(&step, &attr); + + // test jitcode + auto jitcode = + jit::GetJitCode, PlaceType>(attr); + if (jitcode) { + VLOG(10) << "Test Jitcode Kernel " << info; + TestLSTMFunc>(jitcode, xsrc, wp, ct_1, + ct_ref, ht_ref, attr); + } + + // test all impls in more + jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = + dynamic_cast>*>( + impl.get()); + if (i && i->UseMe(attr)) { + auto more = i->GetFunc(); + VLOG(10) << "Test More Kernel " << info; + TestLSTMFunc>(more, xsrc, wp, ct_1, + ct_ref, ht_ref, attr); + } + } + } + // Test result from Get function + auto tgt = jit::Get, PlaceType>(attr); + TestLSTMFunc>(tgt, xsrc, wp, ct_1, ct_ref, + ht_ref, attr); + } + } + } + } + } +} + +TEST(JITKernel, 
lstmctht) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +TEST(JITKernel, lstmc1h1) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +// TODO(TJ): refine the tests template + TEST(JITKernel, pool) { // TODO(TJ): add some test } diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h index ba5f20e53..025343dfa 100644 --- a/paddle/fluid/operators/math/jit_kernel_impl.h +++ b/paddle/fluid/operators/math/jit_kernel_impl.h @@ -28,45 +28,6 @@ namespace jitkernel { #define YMM_FLOAT_BLOCK 8 #define ZMM_FLOAT_BLOCK 16 -typedef struct { - void* gates; // gates: W_ch, W_ih, W_fh, W_oh - const void* ct_1; - void* ct; - void* ht; - /* weight_peephole and checked data are only used in peephole*/ - const void* wp{nullptr}; - void* checked{nullptr}; -} lstm_t; - -typedef struct { - void* gates; // gates: {W_update, W_reset; W_state} - const void* ht_1; - void* ht; -} gru_t; - -struct rnn_attr_s { - int d; - std::string act_gate, act_cand; - rnn_attr_s() = default; - rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand) - : d(_d), act_gate(_act_gate), act_cand(_act_cand) {} -}; - -struct lstm_attr_s : public rnn_attr_s { - bool use_peephole; - std::string act_cell; - lstm_attr_s() = default; - lstm_attr_s(int _d, const std::string& _act_gate, - const std::string& _act_cand, const std::string& _act_cell, - bool _use_peephole = false) - : rnn_attr_s(_d, _act_gate, _act_cand), - use_peephole(_use_peephole), - act_cell(_act_cell) {} -}; - -typedef struct rnn_attr_s gru_attr_t; -typedef struct lstm_attr_s lstm_attr_t; - } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h index a03e851de..122cbcb0d 100644 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ 
b/paddle/fluid/operators/math/jit_kernel_refer.h @@ -24,91 +24,6 @@ namespace math { namespace jitkernel { namespace refer { -template -void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT - if (type == "sigmoid") { - return VSigmoid; - } else if (type == "relu") { - return VRelu; - } else if (type == "tanh") { - return VTanh; - } else if (type == "identity" || type == "") { - return VIdentity; - } - PADDLE_THROW("Not support type: %s", type); - return nullptr; -} - -// compute ct and ht -template -void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { - T* gates = reinterpret_cast(step->gates); - const T* ct_1 = reinterpret_cast(step->ct_1); - T* ct = reinterpret_cast(step->ct); - T* ht = reinterpret_cast(step->ht); - const T* wp = reinterpret_cast(step->wp); - T* checked = reinterpret_cast(step->checked); - auto act_gate = getActFunc(attr->act_gate); - auto act_cand = getActFunc(attr->act_cand); - auto act_cell = getActFunc(attr->act_cell); - int d = attr->d; - int d2 = d * 2; - int d3 = d * 3; - // gates: W_ch, W_ih, W_fh, W_oh - if (attr->use_peephole) { - VMul(wp, ct_1, checked, d); - VMul(wp + d, ct_1, checked + d, d); - VAdd(checked, gates + d, gates + d, d2); - act_gate(gates + d, gates + d, d2); - } else { - act_gate(gates + d, gates + d, d3); - } - - // C_t = C_t-1 * fgated + cand_gated * igated - act_cand(gates, gates, d); - VMul(gates, gates + d, gates + d, d); - VMul(ct_1, gates + d2, gates + d2, d); - VAdd(gates + d, gates + d2, ct, d); - - if (attr->use_peephole) { - // get ogated - VMul(wp + d2, ct, gates + d, d); - VAdd(gates + d, gates + d3, gates + d3, d); - act_gate(gates + d3, gates + d3, d); - } - // H_t = act_cell(C_t) * ogated - act_cell(ct, gates + d2, d); - VMul(gates + d2, gates + d3, ht, d); -} - -// compute c1 and h1 without c0 or h0 -template -void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { - T* gates = reinterpret_cast(step->gates); - T* ct = reinterpret_cast(step->ct); - T* ht = 
reinterpret_cast(step->ht); - auto act_gate = getActFunc(attr->act_gate); - auto act_cand = getActFunc(attr->act_cand); - auto act_cell = getActFunc(attr->act_cell); - int d = attr->d; - int d2 = d * 2; - int d3 = d * 3; - /* C_t = igated * cgated*/ - act_gate(gates + d, gates + d, d); - act_cand(gates, gates, d); - VMul(gates, gates + d, ct, d); - if (attr->use_peephole) { - // get outgated, put W_oc * C_t on igated - const T* wp = reinterpret_cast(step->wp); - VMul(wp + d2, ct, gates + d, d); - VAdd(gates + d, gates + d3, gates + d3, d); - } - /* H_t = act_cell(C_t) * ogated */ - act_gate(gates + d3, gates + d3, d); - act_cell(ct, gates + d2, d); - VMul(gates + d2, gates + d3, ht, d); -} - // compute h1 without h0 template void GRUH1(gru_t* step, const gru_attr_t* attr) { -- GitLab