diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h index bcb6615df8409f75b7393df734f528cf7e7d9efb..e0b2e3c7fada6b422318c68a42fd6d103c99af5a 100644 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ b/paddle/fluid/operators/math/jit_kernel_refer.h @@ -116,6 +116,7 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT return nullptr; } +// compute ct and ht template void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { T* gates = reinterpret_cast(step->gates); @@ -199,6 +200,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) { VMul(gates, gates + d2, ht, d); } +// compute the first part of GRU: ht = act_gate(r) * ht_1 template void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { // W: {W_update, W_reset; W_state} @@ -210,6 +212,8 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { VMul(ht_1, gates + attr->d, ht, attr->d); } +// compute the second part of GRU: +// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 template void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { T* gates = reinterpret_cast(step->gates); diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index cc8a5d4d862cc1b5198d96edb0bf628bc89e4b88..6476e95dc02967730005e0f2b9d794d146d9039e 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -86,7 +86,7 @@ TEST(JitKernel, vrelu) { vrelu_intri8(d, x_data, zref_data); } auto si1 = GetCurrentUS(); - VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us"; } #endif auto ttgts = GetCurrentUS(); @@ -96,7 +96,7 @@ TEST(JitKernel, vrelu) { auto ttgte = GetCurrentUS(); VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat; + << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -129,7 +129,7 @@ TEST(JitKernel, vaddbias) { VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat; + << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -182,7 +182,7 @@ TEST(JitKernel, vexp) { #else << " us, " #endif - << "tgt takes: " << (ttgte - ttgts) / repeat; + << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -238,7 +238,7 @@ TEST(JitKernel, vsigmoid) { VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat; + << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -299,7 +299,7 @@ TEST(JitKernel, vtanh) { VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat; + << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -400,7 +400,7 @@ TEST(JitKernel, lstm) { VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat << " us, better(jit) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat; + << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; } } @@ -474,7 +474,7 @@ TEST(JitKernel, vscal) { } auto si3 = GetCurrentUS(); VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat - << " us, inplace: " << (si3 - si2) / repeat; + << " us, inplace: " << (si3 - si2) / repeat << " us"; } #endif @@ -498,7 +498,8 @@ TEST(JitKernel, vscal) { << " us, " #endif << "tgt takes: " << (ttgte - ttgts) / repeat - << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; + << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat + << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -573,7 +574,7 @@ TEST(JitKernel, vmul) { #else << " us, " #endif - << "tgt takes: " << (ttgte - ttgts) / repeat; + << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -648,7 +649,7 @@ TEST(JitKernel, vadd) { #else << " us, " #endif - << "tgt takes: " << (ttgte - ttgts) / repeat; + << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); } @@ -701,7 +702,7 @@ TEST(JitKernel, vaddrelu) { VLOG(30) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat << " us, better takes: " << (tmkle - tmkls) / repeat << " us, " - << "tgt takes: " << (ttgte - ttgts) / repeat; + << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); }