diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index d2de4545cec4e3035fea5e74a86026cc4e524a19..d65a3299c5bc4e87f50a58042221b52553533f0c 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -328,6 +328,123 @@ TEST(JitKernel, vtanh) { } } +void lstm_ctht_ref( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VSigmoidKernel>& + vsigmoid_3d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VExpKernel>& vexp_1, + const int d, float* gates, const float* ct_1, float* ct, float* ht) { + vsigmoid_3d->Compute(gates + d, gates + d); + vtanh_d->Compute(gates, gates); + const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + for (int k = 0; k < d; ++k) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[k] = ct_1[k] * f[k] + gates[k] * i[k]; + // H_t = act_cell(C_t) * ogated + float tmp = ct[k] * 2; + tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vexp_1->Compute(&tmp, &tmp); + tmp = 2.f / (1.f + tmp) - 1.f; + ht[k] = tmp * o[k]; + } +} + +void lstm_ctht_better( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VSigmoidKernel>& + vsigmoid_3d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VMulKernel>& vmul_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, + const int d, float* gates, const float* ct_1, float* ct, float* ht) { + int d2 = d * 2; + vsigmoid_3d->Compute(gates + d, gates + d); + vtanh_d->Compute(gates, gates); + vmul_d->Compute(gates, gates + d, gates + d); + vmul_d->Compute(ct_1, gates + d2, gates + d2); + vadd_d->Compute(gates + d, gates + d2, ct); + /* H_t = act_cell(C_t) * ogated */ + vtanh_d->Compute(ct, gates + d2); + vmul_d->Compute(gates + d2, gates + d * 3, ht); +} + +TEST(JitKernel, lstm) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) { + int d4 = d * 4; + int d3 = d * 3; + std::vector x(d4), xref(d4); + std::vector ct_1(d), ct_tgt(d), ht_tgt(d); + std::vector ct_ref(d), ht_ref(d); + RandomVec(d4, x.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + memcpy(xref.data(), x.data(), sizeof(float) * d4); + std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; + const auto& ker = + jit::KernelPool::Instance() + .template Get, int, const std::string&, + const std::string&, const std::string&>( + d, act_gate, act_cand, act_cell); + // below kernels are used to compute refer + const auto& vsigmoid_3d = + jit::KernelPool::Instance().template Get>( + d3); + const auto& vtanh_d = + jit::KernelPool::Instance().template Get>(d); + const auto& vexp_1 = + jit::KernelPool::Instance().template Get>(1); + const auto& vmul_d = + jit::KernelPool::Instance().template Get>(d); + const auto& vadd_d = + jit::KernelPool::Instance().template Get>(d); + + float* x_data = x.data(); + float* xref_data = xref.data(); + const float* ct_1_data = ct_1.data(); + float* ct_tgt_data = ct_tgt.data(); + float* ht_tgt_data = ht_tgt.data(); + float* ct_ref_data = ct_ref.data(); + float* ht_ref_data = ht_ref.data(); + // compute once to check correctness + lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, + ct_ref_data, ht_ref_data); + ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); + EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); + } + + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data, + ct_1_data, ct_ref_data, ht_ref_data); + } + auto tmkle = GetCurrentUS(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, + ct_ref_data, ht_ref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + } + auto ttgte = GetCurrentUS(); + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, better(jit) takes: " << (tmkle - tmkls) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + } +} + void vscal_ref(const int n, const float a, const float* x, float* y) { for (int i = 0; i < n; ++i) { y[i] = a * x[i];