/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include "gflags/gflags.h" #include "glog/logging.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/platform/device_tracer.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/port.h" DEFINE_int32(burning, 10, "Burning times."); DEFINE_int32(repeat, 3000, "Repeat times."); DEFINE_int32(max_size, 1000, "The Max size would be tested."); template void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), const T upper = static_cast(20.f), unsigned int seed = 100) { std::mt19937 rng(seed); std::uniform_real_distribution uniform_dist(0, 1); for (int i = 0; i < n; ++i) { a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } } std::vector TestSizes() { std::vector s; for (int i = 1; i <= FLAGS_max_size; ++i) { s.push_back(i); } return s; } template struct BenchFunc { // return this function avg time double operator()(const typename KernelTuples::func_type tgt, Args... args) { for (int i = 0; i < FLAGS_burning; ++i) { tgt(args...); } auto start = paddle::platform::PosixInNsec() / 1e-3; for (int i = 0; i < FLAGS_repeat; ++i) { tgt(args...); } auto end = paddle::platform::PosixInNsec() / 1e-3; return static_cast(end - start) / FLAGS_repeat; } }; namespace jit = paddle::operators::jit; template void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { BenchFunc benchmark; std::vector> infos; // test refer auto refer = jit::GetRefer(); if (!refer) { LOG(FATAL) << "Refer can not be empty!"; } infos.push_back(std::make_pair("Refer", benchmark(refer, args...))); // test jitcode auto jitcode = jit::GetJitCode(attr); if (jitcode) { infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...))); } // test all impls in more jit::KernelKey kkey(KT, PlaceType()); auto& pool = jit::KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); infos.push_back( std::make_pair(i->ImplType(), benchmark(more, args...))); } } } // Test result from Get function auto tgt = jit::Get(attr); if (!tgt) { LOG(FATAL) << "Target can not be empty!"; } infos.push_back(std::make_pair("Target", benchmark(tgt, args...))); // print std::ostringstream loginfos; loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; for (auto pair : infos) { loginfos << pair.first << " takes " << pair.second << " us; "; } LOG(INFO) << loginfos.str(); } template void BenchXYZNKernel() { for (int d : TestSizes()) { std::vector x(d), y(d), z(d); RandomVec(d, x.data()); RandomVec(d, y.data()); BenchAllImpls, PlaceType>(d, x.data(), y.data(), z.data(), d); } } template void BenchAXYNKernel() { for (int d : TestSizes()) { const T a = static_cast(3); std::vector x(d), y(d); RandomVec(d, x.data()); BenchAllImpls, PlaceType>(d, &a, x.data(), y.data(), d); } } template void BenchXYNKernel() { for (int d : TestSizes()) { std::vector x(d), y(d); RandomVec(d, x.data()); BenchAllImpls, PlaceType>(d, x.data(), y.data(), d); } } template void BenchLSTMKernel() { for (bool use_peephole : {true, false}) { for (int d : TestSizes()) { const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, use_peephole); std::vector x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d); RandomVec(4 * d, x.data(), -2.f, 2.f); RandomVec(3 * d, wp.data(), -2.f, 2.f); RandomVec(d, ct_1.data(), -2.f, 2.f); const T* ct_1_data = ct_1.data(); const T* wp_data = wp.data(); T* x_data = x.data(); T* checked_data = checked.data(); T* ct_data = ct.data(); T* ht_data = ht.data(); jit::lstm_t step; step.gates = x_data; step.ct_1 = ct_1_data; step.ct = ct_data; step.ht = ht_data; if (use_peephole) { step.wp = wp_data; step.checked = checked_data; } BenchAllImpls, PlaceType>(attr, &step, &attr); } } } template void BenchGRUKernel() { for (int d : TestSizes()) { const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); std::vector x(3 * d), ht_1(d), ht(d); RandomVec(3 * d, x.data(), -2.f, 2.f); RandomVec(d, ht_1.data(), -2.f, 2.f); const T* ht_1_data = ht_1.data(); T* x_data = x.data(); T* ht_data = ht.data(); jit::gru_t step; step.gates = x_data; step.ht_1 = ht_1_data; step.ht = ht_data; BenchAllImpls, PlaceType>(attr, &step, &attr); } } template void BenchSeqPoolKernel() { std::vector pool_types = { jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; for (auto type : pool_types) { for (int w : TestSizes()) { jit::seq_pool_attr_t attr(w, type); for (int h : TestSizes()) { attr.h = h; std::vector x(h * w), y(w); RandomVec(h * w, x.data(), -2.f, 2.f); const T* x_data = x.data(); T* y_data = y.data(); BenchAllImpls, PlaceType>(attr, x_data, y_data, &attr); } } } } // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] // Options: // --burning: the burning time before count // --repeat: the repeat times // --max_size: the max size would be tested int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat << " times."; using T = float; using PlaceType = paddle::platform::CPUPlace; // xyzn BenchXYZNKernel(); BenchXYZNKernel(); BenchXYZNKernel(); BenchXYZNKernel(); // axyn BenchAXYNKernel(); BenchAXYNKernel(); // xyn BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); // lstm and peephole BenchLSTMKernel(); BenchLSTMKernel(); // gru functions BenchGRUKernel(); BenchGRUKernel(); BenchGRUKernel(); // seq pool function BenchSeqPoolKernel(); }