diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index a7e5eb6cf4a19bd6a53522b2bb4986651a8cf910..01467e324cc66c01ef9e89465bb3014a94dd9be8 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -272,6 +272,98 @@ void BenchXYNKernel() { } } +// return this function avg time +template /* NOTE(review): the template parameter list is missing here and at the other bare angle-bracket sites in this patch text; it looks stripped by extraction. Confirm against the real file. */ +double BenchLSTMFunc(const typename KernelTuples::func_type tgt, + const paddle::operators::jit::lstm_attr_t* attr, + paddle::operators::jit::lstm_t* step) { /* Calls tgt(step, attr) FLAGS_burning times untimed, then FLAGS_repeat times between two GetCurrentUS() readings. */ + for (int i = 0; i < FLAGS_burning; ++i) { + tgt(step, attr); + } + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeat; ++i) { + tgt(step, attr); + } + auto end = GetCurrentUS(); + return (end - start) / FLAGS_repeat; /* Elapsed time per timed call; presumably microseconds, going by the GetCurrentUS name and the us suffix in the log line (confirm). NOTE(review): FLAGS_repeat is used as a divisor unchecked, so a zero value is not guarded; confirm the allowed range of the flag. */ +} + +template +void BenchLSTMKernel() { /* For each use_peephole in {true, false} and each d in TestSizes(): fills random inputs, times every available implementation, and logs one summary line. */ + namespace jit = paddle::operators::jit; + for (bool use_peephole : {true, false}) { + for (int d : TestSizes()) { + const jit::lstm_attr_t attr(d, jit::vsigmoid, jit::vtanh, jit::vtanh, + use_peephole); + std::vector> infos; + std::vector x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d); /* Buffer lengths: x is 4*d, wp is 3*d, checked is 2*d, the rest are d. Only x, wp and ct_1 are passed to RandomVec below; ct, ht and checked are not. */ + RandomVec(4 * d, x.data(), -2.f, 2.f); + RandomVec(3 * d, wp.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + T* x_data = x.data(); + T* checked_data = checked.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (use_peephole) { /* wp and checked are assigned only in peephole mode; otherwise they keep whatever the default construction of jit::lstm_t gives them (not visible here). */ + step.wp = wp_data; + step.checked = checked_data; + } + + // test refer + auto refer = jit::GetRefer>(); + if (refer) { + auto res = BenchLSTMFunc>(refer, &attr, &step); + infos.push_back(std::make_pair("Refer", res)); + } + // test jitcode + auto jitcode = jit::GetJitCode, PlaceType>(attr); + if (jitcode) { + auto res = BenchLSTMFunc>(jitcode, &attr, &step); + infos.push_back(std::make_pair("JitCode", res)); + } + // test all impls in more + 
jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast>*>( + impl.get()); + if (i && i->UseMe(attr)) { /* Skips entries whose dynamic_cast yields null or whose UseMe(attr) is false. */ + auto more = i->GetFunc(); + auto res = BenchLSTMFunc>(more, &attr, &step); + infos.push_back(std::make_pair("More", res)); /* Every matching implementation is recorded under the same More label, so the log line cannot tell them apart. */ + } + } + } + // Test result from Get function + auto tgt = jit::Get, PlaceType>(attr); + if (!tgt) { + LOG(ERROR) << "Target can not be empty!"; /* NOTE(review): this only logs; control still reaches the BenchLSTMFunc call below, which invokes tgt even when it tested false here. */ + } + auto res = BenchLSTMFunc>(tgt, &attr, &step); + infos.push_back(std::make_pair("Target", res)); + // print + std::ostringstream loginfos; + loginfos << "Kernel Type: " << jit::to_string(KT) + << ", Sigmoid,Tanh,Tanh, " << (use_peephole ? "Peephole_" : "") + << " size " << d << ": "; + for (auto pair : infos) { /* NOTE(review): iterates by value, copying each entry; a const reference would avoid that. */ + loginfos << pair.first << " takes " << pair.second << " us; "; + } + LOG(INFO) << loginfos.str(); + } + } +} + // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] // Options: @@ -294,9 +386,14 @@ int main(int argc, char* argv[]) { BenchAXYNKernel(); BenchAXYNKernel(); + // act BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); + + // lstm and peephole + BenchLSTMKernel(); + BenchLSTMKernel(); }