Commit bf9302f9 authored by: T tensor-tang

add lstm, peephole refer and test

Parent bf951fa7
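This commit adds reference (refer) implementations and tests for two LSTM JIT kernels: lstmctht (a full step consuming c_{t-1}) and lstmc1h1 (the first step, which has no c_0/h_0), both with optional peephole connections. It also moves JitCodeKey into a dedicated kernel_key.{h,cc}, adds to_kerneltype for mapping activation names to KernelType, and introduces the lstm_t/lstm_attr_t descriptors in kernel_base.h. As an orientation sketch (standard peephole LSTM; x_c, x_i, x_f, x_o denote the four pre-activation gate inputs stored in gates; the activations are the configurable act_cand/act_gate/act_cell; the W_*c ⊙ c peephole terms drop out when use_peephole is false):

i_t = act_gate(x_i + W_ic ⊙ c_{t-1})
f_t = act_gate(x_f + W_fc ⊙ c_{t-1})
c_t = f_t ⊙ c_{t-1} + i_t ⊙ act_cand(x_c)
o_t = act_gate(x_o + W_oc ⊙ c_t)
h_t = o_t ⊙ act_cell(c_t)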
@@ -23,11 +23,6 @@ namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(int d) {
return d;
}
// refer do not need useme, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
...
@@ -43,10 +43,6 @@ class GenBase : public Kernel {
void dumpCode(const unsigned char* code) const;
};
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t JitCodeKey(Attr attr);
// Creator is used to create the jitcode and save in pool.
// Every JitCode should have one creator.
class GenCreator {
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -36,6 +37,8 @@ const char* to_string(KernelType kt) { ...@@ -36,6 +37,8 @@ const char* to_string(KernelType kt) {
ONE_CASE(vexp); ONE_CASE(vexp);
ONE_CASE(vsigmoid); ONE_CASE(vsigmoid);
ONE_CASE(vtanh); ONE_CASE(vtanh);
ONE_CASE(lstmctht);
ONE_CASE(lstmc1h1);
default:
PADDLE_THROW("Not support type: %d", kt);
return "NOT JITKernel";
@@ -44,6 +47,23 @@ const char* to_string(KernelType kt) {
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
std::string lower = act;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
if (lower == "relu" || lower == "vrelu") {
return vrelu;
} else if (lower == "identity" || lower == "videntity" || lower == "") {
return videntity;
} else if (lower == "exp" || lower == "vexp") {
return vexp;
} else if (lower == "sigmoid" || lower == "vsigmoid") {
return vsigmoid;
} else if (lower == "tanh" || lower == "vtanh") {
return vtanh;
}
return non_kernel;
}
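A quick usage sketch of the mapping above (hypothetical call sites; matching is case-insensitive and anything unrecognized falls through to non_kernel):

// to_kerneltype("Tanh")     -> vtanh
// to_kerneltype("VSIGMOID") -> vsigmoid
// to_kerneltype("")         -> videntity
// to_kerneltype("gelu")     -> non_kernel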
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -14,9 +14,7 @@
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
@@ -124,6 +122,8 @@ typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
const char* to_string(KernelType kt);
KernelType to_kerneltype(const std::string& act);
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -20,8 +20,9 @@ namespace operators {
namespace jit {
typedef enum {
vmul = 0,
vadd = 1,
non_kernel = 0,
vmul = 1,
vadd = 2,
vaddrelu,
vsub,
vscal,
@@ -30,7 +31,9 @@ typedef enum {
videntity,
vexp,
vsigmoid,
vtanh
vtanh,
lstmctht,
lstmc1h1
} KernelType;
template <typename T>
@@ -50,6 +53,51 @@ struct XYNTuples {
typedef void (*func_type)(const T*, T*, int);
};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr}; // W_ic, W_fc, W_oc
void* checked{nullptr}; // size: 2 * d
} lstm_t;
typedef struct {
void* gates; // gates: {x_update, x_reset; x_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
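A minimal usage sketch of these descriptors (not part of the diff; the include paths, the jit::Get signature, and CPUPlace follow the test file further below, and the buffer sizes follow the refer implementation):

#include <vector>
#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/platform/place.h"

namespace jit = paddle::operators::jit;

void lstm_step_example() {
  const int d = 8;  // hidden size
  // gates holds the four pre-activation segments x_c, x_i, x_f, x_o (each of
  // size d); fill it with real inputs before calling the kernel.
  std::vector<float> gates(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);

  const jit::lstm_attr_t attr(d, jit::vsigmoid, jit::vtanh, jit::vtanh,
                              /*use_peephole=*/true);
  jit::lstm_t step;
  step.gates = gates.data();
  step.ct_1 = ct_1.data();
  step.ct = ct.data();
  step.ht = ht.data();
  step.wp = wp.data();            // W_ic, W_fc, W_oc; only read with peephole
  step.checked = checked.data();  // 2 * d scratch; only used with peephole

  auto lstm = jit::Get<jit::lstmctht, jit::LSTMTuples<float>,
                       paddle::platform::CPUPlace>(attr);
  lstm(&step, &attr);             // writes ct and ht in place
}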
// Just for adding to kernel pool without template
class Kernel {
public:
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
return d;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
constexpr int act_type_shift = 3; // support 2^3 act types
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
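Illustration of the resulting bit layout (an editorial sketch; it assumes each activation's KernelType value fits in the 3-bit slot that act_type_shift reserves):

// bit 0        : use_peephole
// bits 1  .. 3 : act_gate
// bits 4  .. 6 : act_cand
// bits 7  .. 9 : act_cell
// bits 10 ..   : d (the vector width)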
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -44,6 +44,10 @@ struct KernelKey {
bool operator!=(const KernelKey& o) const { return !(*this == o); }
};
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -18,3 +18,5 @@ USE_JITKERNEL_REFER(videntity)
USE_JITKERNEL_REFER(vexp)
USE_JITKERNEL_REFER(vsigmoid)
USE_JITKERNEL_REFER(vtanh)
USE_JITKERNEL_REFER(lstmctht)
USE_JITKERNEL_REFER(lstmc1h1)
@@ -35,4 +35,7 @@ REGISTER_REFER_KERNEL(vexp, VExp);
REGISTER_REFER_KERNEL(vsigmoid, VSigmoid);
REGISTER_REFER_KERNEL(vtanh, VTanh);
REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt);
REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1);
#undef REGISTER_REFER_KERNEL
@@ -110,6 +110,91 @@ void VTanh(const T* x, T* y, int n) {
}
}
template <typename T>
void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
if (type == vsigmoid) {
return VSigmoid<T>;
} else if (type == vrelu) {
return VRelu<T>;
} else if (type == vtanh) {
return VTanh<T>;
} else if (type == videntity) {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
// compute ct and ht
template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh
if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
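For reference, the memory layout this function assumes (an editorial note inferred from the arithmetic above; offsets are in elements of T):

// gates[0,  d)  : x_c, candidate input, run through act_cand
// gates[d, 2d)  : x_i, input gate,      run through act_gate
// gates[2d, 3d) : x_f, forget gate,     run through act_gate
// gates[3d, 4d) : x_o, output gate,     run through act_gate
// wp[0, d) = W_ic, wp[d, 2d) = W_fc, wp[2d, 3d) = W_oc; checked is 2*d scratch.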
// compute c1 and h1 without c0 or h0
template <typename T>
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
/* C_t = igated * cgated*/
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
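In equation form, the first-step kernel above computes (a sketch; ⊙ is elementwise multiply, and the W_oc term is only added when use_peephole is set):

c_1 = act_gate(x_i) ⊙ act_cand(x_c)
h_1 = act_gate(x_o + W_oc ⊙ c_1) ⊙ act_cell(c_1)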
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -134,6 +219,10 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
// lstm_t* , const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
...
@@ -350,6 +350,143 @@ TEST(JITKernel, vtanh) {
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
}
template <typename T, typename KernelTuples>
void TestLSTMFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& wp,
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
const std::vector<T>& ht_ref,
const paddle::operators::jit::lstm_attr_t& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ct_ref.size(), ht_ref.size());
EXPECT_EQ(ct_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
EXPECT_EQ(wp.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
std::vector<T> checked(2 * d);
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
const T* ct_ref_data = ct_ref.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
T* checked_data = checked.data();
paddle::operators::jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (attr.use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
tgt(&step, &attr);
ExpectEQ<T>(ct_data, ct_ref_data, d);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
for (bool use_peephole : {true, false}) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
for (auto& act_cell : all_acts) {
std::string info = act_gate + act_cand + act_cell +
(use_peephole ? "peephole_" : "") + "size_" +
std::to_string(d);
const jit::lstm_attr_t attr(
d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
jit::to_kerneltype(act_cell), use_peephole);
auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
// x could be changed after compute, so copy to save src
std::vector<T> x(xsrc.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_ref_data = ct_ref.data();
T* ht_ref_data = ht_ref.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
ref(&step, &attr);
// test jitcode
auto jitcode =
jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel " << info;
TestLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, xsrc, wp, ct_1,
ct_ref, ht_ref, attr);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel " << info;
TestLSTMFunc<T, jit::LSTMTuples<T>>(more, xsrc, wp, ct_1,
ct_ref, ht_ref, attr);
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
TestLSTMFunc<T, jit::LSTMTuples<T>>(tgt, xsrc, wp, ct_1, ct_ref,
ht_ref, attr);
}
}
}
}
}
}
TEST(JITKernel, lstmctht) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, lstmc1h1) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
}
// TODO(TJ): refine the tests template
TEST(JITKernel, pool) {
// TODO(TJ): add some test
}
@@ -28,45 +28,6 @@ namespace jitkernel {
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef struct {
void* gates; // gates: W_ch, W_ih, W_fh, W_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr};
void* checked{nullptr};
} lstm_t;
typedef struct {
void* gates; // gates: {W_update, W_reset; W_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
std::string act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
std::string act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
} // namespace jitkernel
} // namespace math
} // namespace operators
...
@@ -24,91 +24,6 @@ namespace math {
namespace jitkernel {
namespace refer {
template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") {
return VSigmoid<T>;
} else if (type == "relu") {
return VRelu<T>;
} else if (type == "tanh") {
return VTanh<T>;
} else if (type == "identity" || type == "") {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
// compute ct and ht
template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
// gates: W_ch, W_ih, W_fh, W_oh
if (attr->use_peephole) {
VMul(wp, ct_1, checked, d);
VMul(wp + d, ct_1, checked + d, d);
VAdd(checked, gates + d, gates + d, d2);
act_gate(gates + d, gates + d, d2);
} else {
act_gate(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand(gates, gates, d);
VMul(gates, gates + d, gates + d, d);
VMul(ct_1, gates + d2, gates + d2, d);
VAdd(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
act_gate(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute c1 and h1 without c0 or h0
template <typename T>
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
auto act_gate = getActFunc<T>(attr->act_gate);
auto act_cand = getActFunc<T>(attr->act_cand);
auto act_cell = getActFunc<T>(attr->act_cell);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
/* C_t = igated * cgated*/
act_gate(gates + d, gates + d, d);
act_cand(gates, gates, d);
VMul(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
VMul(wp + d2, ct, gates + d, d);
VAdd(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate(gates + d3, gates + d3, d);
act_cell(ct, gates + d2, d);
VMul(gates + d2, gates + d3, ht, d);
}
// compute h1 without h0
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
...