提交 6eec4617 编写于 作者: T tensor-tang

add lstm peephole benchmark

上级 bf9302f9
...@@ -272,6 +272,98 @@ void BenchXYNKernel() { ...@@ -272,6 +272,98 @@ void BenchXYNKernel() {
} }
} }
// return this function avg time
template <typename T, typename KernelTuples>
double BenchLSTMFunc(const typename KernelTuples::func_type tgt,
const paddle::operators::jit::lstm_attr_t* attr,
paddle::operators::jit::lstm_t* step) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(step, attr);
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(step, attr);
}
auto end = GetCurrentUS();
return (end - start) / FLAGS_repeat;
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
namespace jit = paddle::operators::jit;
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::vsigmoid, jit::vtanh, jit::vtanh,
use_peephole);
std::vector<std::pair<std::string, double>> infos;
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
// test refer
auto refer = jit::GetRefer<KT, jit::LSTMTuples<T>>();
if (refer) {
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(refer, &attr, &step);
infos.push_back(std::make_pair("Refer", res));
}
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
if (jitcode) {
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, &attr, &step);
infos.push_back(std::make_pair("JitCode", res));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(more, &attr, &step);
infos.push_back(std::make_pair("More", res));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
if (!tgt) {
LOG(ERROR) << "Target can not be empty!";
}
auto res = BenchLSTMFunc<T, jit::LSTMTuples<T>>(tgt, &attr, &step);
infos.push_back(std::make_pair("Target", res));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type: " << jit::to_string(KT)
<< ", Sigmoid,Tanh,Tanh, " << (use_peephole ? "Peephole_" : "")
<< " size " << d << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
// Options: // Options:
...@@ -294,9 +386,14 @@ int main(int argc, char* argv[]) { ...@@ -294,9 +386,14 @@ int main(int argc, char* argv[]) {
BenchAXYNKernel<jit::vscal, T, PlaceType>(); BenchAXYNKernel<jit::vscal, T, PlaceType>();
BenchAXYNKernel<jit::vaddbias, T, PlaceType>(); BenchAXYNKernel<jit::vaddbias, T, PlaceType>();
// act
BenchXYNKernel<jit::vrelu, T, PlaceType>(); BenchXYNKernel<jit::vrelu, T, PlaceType>();
BenchXYNKernel<jit::videntity, T, PlaceType>(); BenchXYNKernel<jit::videntity, T, PlaceType>();
BenchXYNKernel<jit::vexp, T, PlaceType>(); BenchXYNKernel<jit::vexp, T, PlaceType>();
BenchXYNKernel<jit::vsigmoid, T, PlaceType>(); BenchXYNKernel<jit::vsigmoid, T, PlaceType>();
BenchXYNKernel<jit::vtanh, T, PlaceType>(); BenchXYNKernel<jit::vtanh, T, PlaceType>();
// lstm and peephole
BenchLSTMKernel<jit::lstmctht, T, PlaceType>();
BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册