提交 1f0291a5 编写于 作者: T tensor-tang

add comments and follow comments

test=develop
上级 557229bd
...@@ -116,6 +116,7 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT ...@@ -116,6 +116,7 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
return nullptr; return nullptr;
} }
// compute ct and ht
template <typename T> template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates); T* gates = reinterpret_cast<T*>(step->gates);
...@@ -199,6 +200,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) { ...@@ -199,6 +200,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
VMul(gates, gates + d2, ht, d); VMul(gates, gates + d2, ht, d);
} }
// compute the first part of GRU: ht = act_gate(r) * ht_1
template <typename T> template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
...@@ -210,6 +212,8 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { ...@@ -210,6 +212,8 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
VMul(ht_1, gates + attr->d, ht, attr->d); VMul(ht_1, gates + attr->d, ht, attr->d);
} }
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
template <typename T> template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates); T* gates = reinterpret_cast<T*>(step->gates);
......
...@@ -86,7 +86,7 @@ TEST(JitKernel, vrelu) { ...@@ -86,7 +86,7 @@ TEST(JitKernel, vrelu) {
vrelu_intri8(d, x_data, zref_data); vrelu_intri8(d, x_data, zref_data);
} }
auto si1 = GetCurrentUS(); auto si1 = GetCurrentUS();
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us";
} }
#endif #endif
auto ttgts = GetCurrentUS(); auto ttgts = GetCurrentUS();
...@@ -96,7 +96,7 @@ TEST(JitKernel, vrelu) { ...@@ -96,7 +96,7 @@ TEST(JitKernel, vrelu) {
auto ttgte = GetCurrentUS(); auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -129,7 +129,7 @@ TEST(JitKernel, vaddbias) { ...@@ -129,7 +129,7 @@ TEST(JitKernel, vaddbias) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -182,7 +182,7 @@ TEST(JitKernel, vexp) { ...@@ -182,7 +182,7 @@ TEST(JitKernel, vexp) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -238,7 +238,7 @@ TEST(JitKernel, vsigmoid) { ...@@ -238,7 +238,7 @@ TEST(JitKernel, vsigmoid) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -299,7 +299,7 @@ TEST(JitKernel, vtanh) { ...@@ -299,7 +299,7 @@ TEST(JitKernel, vtanh) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -400,7 +400,7 @@ TEST(JitKernel, lstm) { ...@@ -400,7 +400,7 @@ TEST(JitKernel, lstm) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit) takes: " << (tmkle - tmkls) / repeat << " us, better(jit) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat; << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
} }
} }
...@@ -474,7 +474,7 @@ TEST(JitKernel, vscal) { ...@@ -474,7 +474,7 @@ TEST(JitKernel, vscal) {
} }
auto si3 = GetCurrentUS(); auto si3 = GetCurrentUS();
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
<< " us, inplace: " << (si3 - si2) / repeat; << " us, inplace: " << (si3 - si2) / repeat << " us";
} }
#endif #endif
...@@ -498,7 +498,8 @@ TEST(JitKernel, vscal) { ...@@ -498,7 +498,8 @@ TEST(JitKernel, vscal) {
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << "tgt takes: " << (ttgte - ttgts) / repeat
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat
<< " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -573,7 +574,7 @@ TEST(JitKernel, vmul) { ...@@ -573,7 +574,7 @@ TEST(JitKernel, vmul) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -648,7 +649,7 @@ TEST(JitKernel, vadd) { ...@@ -648,7 +649,7 @@ TEST(JitKernel, vadd) {
#else #else
<< " us, " << " us, "
#endif #endif
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
...@@ -701,7 +702,7 @@ TEST(JitKernel, vaddrelu) { ...@@ -701,7 +702,7 @@ TEST(JitKernel, vaddrelu) {
VLOG(30) << "Vec size " << d VLOG(30) << "Vec size " << d
<< ": refer takes: " << (trefe - trefs) / repeat << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, " << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat; << "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) { for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册