diff --git a/paddle/fluid/operators/jit/README.md b/paddle/fluid/operators/jit/README.md
index 2d72aa4d569aa717b16f5d29f2df28fe3c66a719..6b2f2b2848e47fa5fadf0e5874cbbc80ffb2c1a7 100644
--- a/paddle/fluid/operators/jit/README.md
+++ b/paddle/fluid/operators/jit/README.md
@@ -45,4 +45,6 @@ PaddlePaddle/Paddle/paddle/fluid/
 - Add `your_key` to `KernelType`.
 - Implement the Reference logic. A Reference implementation is required for every jitkernel and must not depend on any third-party library. Then register it with `USE_JITKERNEL_REFER(your_key)` in `refer/CMakeLists.txt`.
-- Add a new `KernelTuples` when necessary; `XYZNTuples` can be used as a reference.
+- Add a new `KernelTuples` when necessary; `XYZNTuples` can be used as a reference. A newly added attr type needs a specialization of the `JitCodeKey` method.
+- Add unit tests, covering both float and double.
+- Add a benchmark to make sure the implementation obtained from `Get` is the fastest.
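For orientation, a minimal consumer-side sketch (not part of the patch) of how a kernel added through this checklist is used; it mirrors the `jit::Get` calls in the tests and benchmark below. The umbrella header path is an assumption based on this directory's layout:

```cpp
#include "paddle/fluid/operators/jit/kernels.h"  // assumed umbrella header
#include "paddle/fluid/platform/place.h"

namespace jit = paddle::operators::jit;

// Runs one gruh1 step on prepared buffers; gates holds 3 * d pre-activation
// values and ht receives d outputs. Get() picks the fastest implementation
// registered for this key (jitcode, then "more", then refer).
void RunGRUH1(float* gates, float* ht, int d) {
  const jit::gru_attr_t attr(d, jit::vsigmoid, jit::vtanh);
  auto f = jit::Get<jit::gruh1, jit::GRUTuples<float>,
                    paddle::platform::CPUPlace>(attr);
  jit::gru_t step;
  step.gates = gates;
  step.ht = ht;
  f(&step, &attr);
}
```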
diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 01467e324cc66c01ef9e89465bb3014a94dd9be8..ca636b020c222fa54f36d04cddf005ff02b14323 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -364,6 +364,85 @@ void BenchLSTMKernel() {
   }
 }
 
+// return this function's average time
+template <typename T, typename KernelTuples>
+double BenchGRUFunc(const typename KernelTuples::func_type tgt,
+                    const paddle::operators::jit::gru_attr_t* attr,
+                    paddle::operators::jit::gru_t* step) {
+  for (int i = 0; i < FLAGS_burning; ++i) {
+    tgt(step, attr);
+  }
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeat; ++i) {
+    tgt(step, attr);
+  }
+  auto end = GetCurrentUS();
+  return (end - start) / FLAGS_repeat;
+}
+
+template <paddle::operators::jit::KernelType KT, typename T,
+          typename PlaceType>
+void BenchGRUKernel() {
+  namespace jit = paddle::operators::jit;
+  for (int d : TestSizes()) {
+    const jit::gru_attr_t attr(d, jit::vsigmoid, jit::vtanh);
+    std::vector<std::pair<std::string, double>> infos;
+    std::vector<T> x(3 * d), ht_1(d), ht(d);
+    RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
+    RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
+    const T* ht_1_data = ht_1.data();
+    T* x_data = x.data();
+    T* ht_data = ht.data();
+    jit::gru_t step;
+    step.gates = x_data;
+    step.ht_1 = ht_1_data;
+    step.ht = ht_data;
+
+    // test refer
+    auto refer = jit::GetRefer<KT, jit::GRUTuples<T>>();
+    if (refer) {
+      auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(refer, &attr, &step);
+      infos.push_back(std::make_pair("Refer", res));
+    }
+    // test jitcode
+    auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
+    if (jitcode) {
+      auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(jitcode, &attr, &step);
+      infos.push_back(std::make_pair("JitCode", res));
+    }
+    // test all impls in more
+    jit::KernelKey kkey(KT, PlaceType());
+    auto& pool = jit::KernelPool().Instance().AllKernels();
+    auto iter = pool.find(kkey);
+    if (iter != pool.end()) {
+      auto& impls = iter->second;
+      for (auto& impl : impls) {
+        auto i = dynamic_cast<const jit::KernelMore<jit::GRUTuples<T>>*>(
+            impl.get());
+        if (i && i->UseMe(attr)) {
+          auto more = i->GetFunc();
+          auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(more, &attr, &step);
+          infos.push_back(std::make_pair("More", res));
+        }
+      }
+    }
+    // Test result from Get function
+    auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
+    if (!tgt) {
+      LOG(ERROR) << "Target can not be empty!";
+    }
+    auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(tgt, &attr, &step);
+    infos.push_back(std::make_pair("Target", res));
+    // print
+    std::ostringstream loginfos;
+    loginfos << "Kernel Type: " << jit::to_string(KT)
+             << ", Sigmoid,Tanh, size " << d << ": ";
+    for (auto pair : infos) {
+      loginfos << pair.first << " takes " << pair.second << " us; ";
+    }
+    LOG(INFO) << loginfos.str();
+  }
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
@@ -396,4 +475,9 @@ int main(int argc, char* argv[]) {
   // lstm and peephole
   BenchLSTMKernel<jit::lstmctht, T, PlaceType>();
   BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>();
+
+  // gru functions
+  BenchGRUKernel<jit::gruh1, T, PlaceType>();
+  BenchGRUKernel<jit::gruhtpart1, T, PlaceType>();
+  BenchGRUKernel<jit::gruhtpart2, T, PlaceType>();
 }
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index d6fa4891e38bf35cc361a3415f1c1c2abe3cbb1b..0543b0743c0475adf4148fd5e2e5ada95a074be7 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -39,6 +39,9 @@ const char* to_string(KernelType kt) {
     ONE_CASE(vtanh);
     ONE_CASE(lstmctht);
     ONE_CASE(lstmc1h1);
+    ONE_CASE(gruh1);
+    ONE_CASE(gruhtpart1);
+    ONE_CASE(gruhtpart2);
     default:
       PADDLE_THROW("Not support type: %d", kt);
       return "NOT JITKernel";
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 3ab0194ce2b9d602d0d2ba3fc1069b79b9558b1e..00d583c60bf73582dab7df75ec8feac1b8f3c3c9 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -33,7 +33,10 @@ typedef enum {
   vsigmoid,
   vtanh,
   lstmctht,
-  lstmc1h1
+  lstmc1h1,
+  gruh1,
+  gruhtpart1,
+  gruhtpart2
 } KernelType;
 
 template <typename T>
@@ -98,6 +101,13 @@ struct LSTMTuples {
   typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
 };
 
+template <typename T>
+struct GRUTuples {
+  typedef T data_type;
+  typedef gru_attr_t attr_type;
+  typedef void (*func_type)(gru_t*, const gru_attr_t*);
+};
+
 // Just for adding to kernel pool without template
 class Kernel {
  public:
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 7a9ae81f89f1253b2807d18d1554cc9fb1d92b98..4e6a19f04fd425b920aeea49b63001941d800a73 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -23,9 +23,10 @@ size_t JitCodeKey<int>(const int& d) {
   return d;
 }
 
+constexpr int act_type_shift = 3;  // support 2^3 act types
+
 template <>
 size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
-  constexpr int act_type_shift = 3;  // suppot 2^3 act types
   size_t key = attr.d;
   int gate_key = static_cast<int>(attr.act_gate) << 1;
   int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
@@ -33,6 +34,14 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
   return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
          attr.use_peephole;
 }
+
+template <>
+size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
+  size_t key = attr.d;
+  return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
+         (static_cast<int>(attr.act_cand) << act_type_shift);
+}
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
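The new `JitCodeKey<gru_attr_t>` packs the two activation kinds and the dimension into a single `size_t`: the low `act_type_shift` bits hold `act_gate`, the next three bits hold `act_cand`, and `d` fills the high bits. A standalone sketch of the same arithmetic, with hypothetical 3-bit activation codes chosen only for illustration:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  constexpr int act_type_shift = 3;  // room for 2^3 distinct activation codes
  const int d = 8;
  const int act_gate = 1, act_cand = 2;  // hypothetical 3-bit codes
  // Same layout as JitCodeKey<gru_attr_t>: [ d | act_cand | act_gate ]
  const std::size_t key = (static_cast<std::size_t>(d) << (act_type_shift * 2)) +
                          act_gate + (act_cand << act_type_shift);
  std::printf("key = %zu\n", key);  // (8 << 6) + 1 + (2 << 3) = 529
  return 0;
}
```

Two attrs that differ in any field therefore map to different cache keys, which is what lets the jitcode pool reuse generated code per (size, activation) combination.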
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index e30923c4fd76fbe137eff66783f7c74c9f93c06c..78d1cb8f9a7031b09302f50fecfa0f066c27b8cd 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -20,3 +20,6 @@ USE_JITKERNEL_REFER(vsigmoid)
 USE_JITKERNEL_REFER(vtanh)
 USE_JITKERNEL_REFER(lstmctht)
 USE_JITKERNEL_REFER(lstmc1h1)
+USE_JITKERNEL_REFER(gruh1)
+USE_JITKERNEL_REFER(gruhtpart1)
+USE_JITKERNEL_REFER(gruhtpart2)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index 59b3ce524864df6ef515939b3675e25257035766..c99174a66f3f07eba2905dcc2efb8266cae7b7b2 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -38,4 +38,8 @@ REGISTER_REFER_KERNEL(vtanh, VTanh);
 REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt);
 REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1);
+REGISTER_REFER_KERNEL(gruh1, GRUH1);
+REGISTER_REFER_KERNEL(gruhtpart1, GRUHtPart1);
+REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2);
+
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index a93123df9d81f837b69f3f26daaf54ef9a62ff02..a9a6ffbccd8a3919ed0c71efb96ae260784ce837 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -125,6 +125,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) {  // NOLINT
   return nullptr;
 }
 
+// TODO(TJ): add a refer gemm and combine the LSTM kernels the same way as the GRU kernels
+
 // compute ct and ht
 template <typename T>
 void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
@@ -195,6 +197,51 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
   VMul(gates + d2, gates + d3, ht, d);
 }
 
+// compute h1 without h0
+template <typename T>
+void GRUH1(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  int d = attr->d;
+  int d2 = d * 2;
+  act_gate(gates, gates, d);
+  act_cand(gates + d2, gates + d2, d);
+  VMul(gates, gates + d2, ht, d);
+}
+
+// compute the first part of GRU: ht = act_gate(r) * ht_1
+template <typename T>
+void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
+  // W: {W_update, W_reset; W_state}
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  act_gate(gates + attr->d, gates + attr->d, attr->d);
+  VMul(ht_1, gates + attr->d, ht, attr->d);
+}
+
+// compute the second part of GRU:
+// ht = act_gate(u) * act_cand(s) + (1 - act_gate(u)) * ht_1
+template <typename T>
+void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  int d = attr->d;
+  T* y = gates + d * 2;
+  act_gate(gates, gates, d);
+  act_cand(y, y, d);
+  // out = zt * ht~ + (1 - zt) * ht_1
+  for (int i = 0; i < d; ++i) {
+    ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
+  }
+}
+
 #define DECLARE_REFER_KERNEL(name, tuples)             \
   template <typename T>                                \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -219,10 +266,15 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
 DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
 DECLARE_REFER_KERNEL(VTanh, XYNTuples);
 
-// lstm_t* , const lstm_attr_t*
+// lstm_t*, const lstm_attr_t*
 DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
 DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
 
+// gru_t*, const gru_attr_t*
+DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
+DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
+DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
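The three refer kernels above are pieces of one recurrent step; the part1/part2 split leaves room for the state gemm that the TODO above still owes. A rough, assumed driver sketch showing the intended sequencing and the `{u, r, c}` layout of the `3 * d` gates buffer (the pointer rewiring between the two parts is elided):

```cpp
#include "paddle/fluid/operators/jit/refer/refer.h"

namespace jit = paddle::operators::jit;

// gates holds 3 * d values, laid out as:
//   [0, d)   : update gate u     (activated by act_gate)
//   [d, 2d)  : reset gate r      (activated by act_gate)
//   [2d, 3d) : candidate state c (activated by act_cand)
template <typename T>
void GruStepSketch(jit::gru_t* step, const jit::gru_attr_t* attr,
                   bool is_first_step) {
  if (is_first_step) {
    // No h0 term: h1 = act_gate(u) * act_cand(c)
    jit::refer::GRUH1<T>(step, attr);
    return;
  }
  // Part 1 writes act_gate(r) * ht_1 into step->ht ...
  jit::refer::GRUHtPart1<T>(step, attr);
  // ... a gemm (the refer gemm is still a TODO in this patch) would then
  // accumulate that product times the state weights into the candidate
  // slice, and the caller re-points step->ht / step->ht_1 accordingly ...
  // Part 2: ht = act_gate(u) * act_cand(c) + (1 - act_gate(u)) * ht_1
  jit::refer::GRUHtPart2<T>(step, attr);
}
```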
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 03e56416b2f95f61358d4e87748cf7070f20291f..d994a11f97d9aa9d9e41d3e7442ceedb1285edcb 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -485,6 +485,108 @@ TEST(JITKernel, lstmc1h1) {
   TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
 }
 
+template <typename T, typename KernelTuples>
+void TestGRUFunc(const typename KernelTuples::func_type tgt,
+                 const std::vector<T>& xsrc, const std::vector<T>& ht_1,
+                 const std::vector<T>& ht_ref,
+                 const paddle::operators::jit::gru_attr_t& attr) {
+  EXPECT_TRUE(tgt != nullptr);
+  EXPECT_EQ(ht_1.size(), ht_ref.size());
+  EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
+
+  // x could be changed after compute, so copy to save src
+  int d = ht_ref.size();
+  std::vector<T> x(xsrc.size()), ht(ht_ref.size());
+  std::copy(xsrc.begin(), xsrc.end(), x.begin());
+  const T* ht_1_data = ht_1.data();
+  const T* ht_ref_data = ht_ref.data();
+  T* x_data = x.data();
+  T* ht_data = ht.data();
+  paddle::operators::jit::gru_t step;
+  step.gates = x_data;
+  step.ht_1 = ht_1_data;
+  step.ht = ht_data;
+  tgt(&step, &attr);
+  ExpectEQ<T>(ht_data, ht_ref_data, d);
+}
+
+template <paddle::operators::jit::KernelType KT, typename T,
+          typename PlaceType>
+void TestGRUKernel() {
+  namespace jit = paddle::operators::jit;
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
+  for (int d : TestSizes()) {
+    for (auto& act_gate : all_acts) {
+      for (auto& act_cand : all_acts) {
+        std::string info = act_gate + act_cand + "size_" + std::to_string(d);
+        const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
+                                   jit::to_kerneltype(act_cand));
+        auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
+        EXPECT_TRUE(ref != nullptr);
+        std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
+        RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
+        RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
+        // x could be changed after compute, so copy to save src
+        std::vector<T> x(xsrc.size());
+        std::copy(xsrc.begin(), xsrc.end(), x.begin());
+        const T* ht_1_data = ht_1.data();
+        T* x_data = x.data();
+        T* ht_ref_data = ht_ref.data();
+        jit::gru_t step;
+        step.gates = x_data;
+        step.ht_1 = ht_1_data;
+        step.ht = ht_ref_data;
+        ref(&step, &attr);
+
+        // test jitcode
+        auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
+        if (jitcode) {
+          VLOG(10) << "Test Jitcode Kernel " << info;
+          TestGRUFunc<T, jit::GRUTuples<T>>(jitcode, xsrc, ht_1, ht_ref, attr);
+        }
+
+        // test all impls in more
+        jit::KernelKey kkey(KT, PlaceType());
+        auto& pool = jit::KernelPool().Instance().AllKernels();
+        auto iter = pool.find(kkey);
+        if (iter != pool.end()) {
+          auto& impls = iter->second;
+          for (auto& impl : impls) {
+            auto i = dynamic_cast<const jit::KernelMore<jit::GRUTuples<T>>*>(
+                impl.get());
+            if (i && i->UseMe(attr)) {
+              auto more = i->GetFunc();
+              VLOG(10) << "Test More Kernel " << info;
+              TestGRUFunc<T, jit::GRUTuples<T>>(more, xsrc, ht_1, ht_ref,
+                                                attr);
+            }
+          }
+        }
+        // Test result from Get function
+        auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
+        TestGRUFunc<T, jit::GRUTuples<T>>(tgt, xsrc, ht_1, ht_ref, attr);
+      }
+    }
+  }
+}
+
+TEST(JITKernel, gruh1) {
+  namespace jit = paddle::operators::jit;
+  TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
+  TestGRUKernel<jit::gruh1, double, paddle::platform::CPUPlace>();
+}
+
+TEST(JITKernel, gruhtpart1) {
+  namespace jit = paddle::operators::jit;
+  TestGRUKernel<jit::gruhtpart1, float, paddle::platform::CPUPlace>();
+  TestGRUKernel<jit::gruhtpart1, double, paddle::platform::CPUPlace>();
+}
+
+TEST(JITKernel, gruhtpart2) {
+  namespace jit = paddle::operators::jit;
+  TestGRUKernel<jit::gruhtpart2, float, paddle::platform::CPUPlace>();
+  TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
+}
+
 // TODO(TJ): refine the tests template
 TEST(JITKernel, pool) {
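As a hand-checkable instance of the part-2 interpolation these tests sweep: with `identity` for both activations (one of the entries in `all_acts`) and `d = 1`, the formula reduces to plain arithmetic. A tiny self-contained check, with values chosen only for illustration:

```cpp
#include <cassert>
#include <cmath>

int main() {
  // d = 1, identity activations, gates = {u, r, c} = {0.5, 0.0, 0.8}
  const float u = 0.5f, c = 0.8f, ht_1 = 0.2f;
  // GRUHtPart2: ht = u * c + (1 - u) * ht_1 = 0.4 + 0.1 = 0.5
  const float ht = u * c + (1.0f - u) * ht_1;
  assert(std::fabs(ht - 0.5f) < 1e-6f);
  return 0;
}
```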
diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h
index 122cbcb0d6ffdd5c15b8d440dc64cc62583a9094..d49fc935dc530d75130bef35ff013df30df72896 100644
--- a/paddle/fluid/operators/math/jit_kernel_refer.h
+++ b/paddle/fluid/operators/math/jit_kernel_refer.h
@@ -22,54 +22,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-namespace refer {
-
-// compute h1 without h0
-template <typename T>
-void GRUH1(gru_t* step, const gru_attr_t* attr) {
-  T* gates = reinterpret_cast<T*>(step->gates);
-  T* ht = reinterpret_cast<T*>(step->ht);
-  auto act_gate = getActFunc<T>(attr->act_gate);
-  auto act_cand = getActFunc<T>(attr->act_cand);
-  int d = attr->d;
-  int d2 = d * 2;
-  act_gate(gates, gates, d);
-  act_cand(gates + d2, gates + d2, d);
-  VMul(gates, gates + d2, ht, d);
-}
-
-// compute the first part of GRU: ht = act_gate(r) * ht_1
-template <typename T>
-void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
-  // W: {W_update, W_reset; W_state}
-  T* gates = reinterpret_cast<T*>(step->gates);
-  T* ht = reinterpret_cast<T*>(step->ht);
-  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
-  auto act_gate = getActFunc<T>(attr->act_gate);
-  act_gate(gates + attr->d, gates + attr->d, attr->d);
-  VMul(ht_1, gates + attr->d, ht, attr->d);
-}
-
-// compute the second part of GRU:
-// ht = act_gate(u) * act_cand(s) + (1 - act_gate(u)) * ht_1
-template <typename T>
-void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
-  T* gates = reinterpret_cast<T*>(step->gates);
-  T* ht = reinterpret_cast<T*>(step->ht);
-  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
-  auto act_gate = getActFunc<T>(attr->act_gate);
-  auto act_cand = getActFunc<T>(attr->act_cand);
-  int d = attr->d;
-  T* y = gates + d * 2;
-  act_gate(gates, gates, d);
-  act_cand(y, y, d);
-  // out = zt * ht~ + (1 - zt) * ht_1
-  for (int i = 0; i < d; ++i) {
-    ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
-  }
-}
-
-}  // namespace refer
+namespace refer {}  // namespace refer
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators