提交 00d3afbc 编写于 作者: T tensor-tang

add gru refer functions, test and benchmark

上级 6eec4617
...@@ -45,4 +45,6 @@ PaddlePaddle/Paddle/paddle/fluid/ ...@@ -45,4 +45,6 @@ PaddlePaddle/Paddle/paddle/fluid/
- 在`KernelType`中添加`your_key`。
- 实现Reference的逻辑,每个jitkernel的Reference实现是必须的。不要依赖任何第三方库。并在`refer/CMakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`。
- 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。
- 添加unit test,需要测试float和double
- 添加benchmark确保get得到的速度是最快。
...@@ -364,6 +364,85 @@ void BenchLSTMKernel() { ...@@ -364,6 +364,85 @@ void BenchLSTMKernel() {
} }
} }
// Measure one GRU kernel: warm it up, then return the average time (us)
// of a single call over FLAGS_repeat timed invocations.
template <typename T, typename KernelTuples>
double BenchGRUFunc(const typename KernelTuples::func_type tgt,
                    const paddle::operators::jit::gru_attr_t* attr,
                    paddle::operators::jit::gru_t* step) {
  // Untimed warm-up iterations so caches/JIT state are hot before timing.
  for (int w = 0; w < FLAGS_burning; ++w) {
    tgt(step, attr);
  }
  const auto tick = GetCurrentUS();
  for (int r = 0; r < FLAGS_repeat; ++r) {
    tgt(step, attr);
  }
  const auto tock = GetCurrentUS();
  return (tock - tick) / FLAGS_repeat;
}
// Benchmark one GRU kernel type across all test sizes: times the refer,
// jitcode, every applicable "more" implementation, and the Get-selected
// target, then logs a one-line summary per size.
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
  namespace jit = paddle::operators::jit;
  for (int d : TestSizes()) {
    const jit::gru_attr_t attr(d, jit::vsigmoid, jit::vtanh);
    std::vector<std::pair<std::string, double>> infos;
    std::vector<T> x(3 * d), ht_1(d), ht(d);
    RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
    RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
    const T* ht_1_data = ht_1.data();
    T* x_data = x.data();
    T* ht_data = ht.data();
    jit::gru_t step;
    step.gates = x_data;
    step.ht_1 = ht_1_data;
    step.ht = ht_data;
    // benchmark refer
    auto refer = jit::GetRefer<KT, jit::GRUTuples<T>>();
    if (refer) {
      auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(refer, &attr, &step);
      infos.push_back(std::make_pair("Refer", res));
    }
    // benchmark jitcode
    auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
    if (jitcode) {
      auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(jitcode, &attr, &step);
      infos.push_back(std::make_pair("JitCode", res));
    }
    // benchmark all impls in more (call the static Instance() directly
    // instead of constructing a temporary KernelPool first)
    jit::KernelKey kkey(KT, PlaceType());
    auto& pool = jit::KernelPool::Instance().AllKernels();
    auto iter = pool.find(kkey);
    if (iter != pool.end()) {
      auto& impls = iter->second;
      for (auto& impl : impls) {
        auto i =
            dynamic_cast<const jit::KernelImpl<jit::GRUTuples<T>>*>(impl.get());
        if (i && i->UseMe(attr)) {
          auto more = i->GetFunc();
          auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(more, &attr, &step);
          infos.push_back(std::make_pair("More", res));
        }
      }
    }
    // benchmark the result from the Get function
    auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
    if (!tgt) {
      // Never benchmark through a null function pointer; report and skip.
      LOG(ERROR) << "Target can not be empty!";
    } else {
      auto res = BenchGRUFunc<T, jit::GRUTuples<T>>(tgt, &attr, &step);
      infos.push_back(std::make_pair("Target", res));
    }
    // print
    std::ostringstream loginfos;
    loginfos << "Kernel Type: " << jit::to_string(KT) << ", Sigmoid,Tanh, size "
             << d << ": ";
    // iterate by const reference to avoid copying each (name, time) pair
    for (const auto& pair : infos) {
      loginfos << pair.first << " takes " << pair.second << " us; ";
    }
    LOG(INFO) << loginfos.str();
  }
}
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
// Options: // Options:
...@@ -396,4 +475,9 @@ int main(int argc, char* argv[]) { ...@@ -396,4 +475,9 @@ int main(int argc, char* argv[]) {
// lstm and peephole // lstm and peephole
BenchLSTMKernel<jit::lstmctht, T, PlaceType>(); BenchLSTMKernel<jit::lstmctht, T, PlaceType>();
BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>(); BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>();
// gru functions
BenchGRUKernel<jit::gruh1, T, PlaceType>();
BenchGRUKernel<jit::gruhtpart1, T, PlaceType>();
BenchGRUKernel<jit::gruhtpart2, T, PlaceType>();
} }
...@@ -39,6 +39,9 @@ const char* to_string(KernelType kt) { ...@@ -39,6 +39,9 @@ const char* to_string(KernelType kt) {
ONE_CASE(vtanh); ONE_CASE(vtanh);
ONE_CASE(lstmctht); ONE_CASE(lstmctht);
ONE_CASE(lstmc1h1); ONE_CASE(lstmc1h1);
ONE_CASE(gruh1);
ONE_CASE(gruhtpart1);
ONE_CASE(gruhtpart2);
default: default:
PADDLE_THROW("Not support type: %d", kt); PADDLE_THROW("Not support type: %d", kt);
return "NOT JITKernel"; return "NOT JITKernel";
......
...@@ -33,7 +33,10 @@ typedef enum { ...@@ -33,7 +33,10 @@ typedef enum {
vsigmoid, vsigmoid,
vtanh, vtanh,
lstmctht, lstmctht,
lstmc1h1 lstmc1h1,
gruh1,
gruhtpart1,
gruhtpart2
} KernelType; } KernelType;
template <typename T> template <typename T>
...@@ -98,6 +101,13 @@ struct LSTMTuples { ...@@ -98,6 +101,13 @@ struct LSTMTuples {
typedef void (*func_type)(lstm_t*, const lstm_attr_t*); typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
}; };
template <typename T>
struct GRUTuples {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
// Just for adding to kernel pool without template // Just for adding to kernel pool without template
class Kernel { class Kernel {
public: public:
......
...@@ -23,9 +23,10 @@ size_t JitCodeKey<int>(const int& d) { ...@@ -23,9 +23,10 @@ size_t JitCodeKey<int>(const int& d) {
return d; return d;
} }
constexpr int act_type_shift = 3; // suppot 2^3 act types
template <> template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
constexpr int act_type_shift = 3; // suppot 2^3 act types
size_t key = attr.d; size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1; int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift); int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
...@@ -33,6 +34,14 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { ...@@ -33,6 +34,14 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole; attr.use_peephole;
} }
// Build the jit-code cache key for a GRU attribute: the gate activation
// occupies the lowest act_type_shift bits, the candidate activation the
// next act_type_shift bits, and the dimension the remaining high bits.
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
  const int gate_bits = static_cast<int>(attr.act_gate);
  const int cand_bits = static_cast<int>(attr.act_cand) << act_type_shift;
  const size_t dim_bits = static_cast<size_t>(attr.d)
                          << (act_type_shift * 2);
  return dim_bits + gate_bits + cand_bits;
}
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -20,3 +20,6 @@ USE_JITKERNEL_REFER(vsigmoid) ...@@ -20,3 +20,6 @@ USE_JITKERNEL_REFER(vsigmoid)
USE_JITKERNEL_REFER(vtanh) USE_JITKERNEL_REFER(vtanh)
USE_JITKERNEL_REFER(lstmctht) USE_JITKERNEL_REFER(lstmctht)
USE_JITKERNEL_REFER(lstmc1h1) USE_JITKERNEL_REFER(lstmc1h1)
USE_JITKERNEL_REFER(gruh1)
USE_JITKERNEL_REFER(gruhtpart1)
USE_JITKERNEL_REFER(gruhtpart2)
...@@ -38,4 +38,8 @@ REGISTER_REFER_KERNEL(vtanh, VTanh); ...@@ -38,4 +38,8 @@ REGISTER_REFER_KERNEL(vtanh, VTanh);
REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt); REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt);
REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1); REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1);
REGISTER_REFER_KERNEL(gruh1, GRUH1);
REGISTER_REFER_KERNEL(gruhtpart1, GRUHtPart1);
REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL
...@@ -125,6 +125,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT ...@@ -125,6 +125,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
return nullptr; return nullptr;
} }
// TODO(TJ): add refer gemm and make LSTM kernels combine as same GRU kernels
// compute ct and ht // compute ct and ht
template <typename T> template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
...@@ -195,6 +197,51 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { ...@@ -195,6 +197,51 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
VMul(gates + d2, gates + d3, ht, d); VMul(gates + d2, gates + d3, ht, d);
} }
// Compute the first hidden state when there is no h0:
// ht = act_gate(u) * act_cand(s).
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  auto act_gate = getActFunc<T>(attr->act_gate);
  auto act_cand = getActFunc<T>(attr->act_cand);
  const int d = attr->d;
  // Candidate state s is stored after the two gate blocks of size d each.
  T* cand = gates + d * 2;
  act_gate(gates, gates, d);
  act_cand(cand, cand, d);
  VMul(gates, cand, ht, d);
}
// First part of a GRU step: ht = act_gate(r) * ht_1.
template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
  // Gate layout: {W_update, W_reset; W_state} -- reset gate r starts at
  // offset d.
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
  auto act_gate = getActFunc<T>(attr->act_gate);
  const int d = attr->d;
  T* reset = gates + d;
  act_gate(reset, reset, d);
  VMul(ht_1, reset, ht, d);
}
// Second part of a GRU step:
// ht = act_gate(u) * act_cand(s) + (1 - act_gate(u)) * ht_1.
template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
  auto act_gate = getActFunc<T>(attr->act_gate);
  auto act_cand = getActFunc<T>(attr->act_cand);
  const int d = attr->d;
  // Candidate state s is stored after the two gate blocks of size d each.
  T* cand = gates + d * 2;
  act_gate(gates, gates, d);
  act_cand(cand, cand, d);
  // Blend candidate and previous hidden state with the update gate u.
  for (int i = 0; i < d; ++i) {
    const T u = gates[i];
    ht[i] = u * cand[i] + (static_cast<T>(1) - u) * ht_1[i];
  }
}
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<tuples<T>> { \
...@@ -219,10 +266,15 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples); ...@@ -219,10 +266,15 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh, XYNTuples);
// lstm_t* , const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL
} // namespace refer } // namespace refer
......
...@@ -485,6 +485,108 @@ TEST(JITKernel, lstmc1h1) { ...@@ -485,6 +485,108 @@ TEST(JITKernel, lstmc1h1) {
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>(); TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
} }
// Run one GRU kernel implementation on a private copy of the input and
// compare the produced hidden state against the reference result.
//
// tgt:    function pointer of the kernel under test
// xsrc:   gate input of size 3*d (copied; kernels may overwrite it in place)
// ht_1:   previous hidden state of size d
// ht_ref: expected hidden state of size d (produced by the refer kernel)
// attr:   GRU attributes (dimension and activation types)
template <typename T, typename KernelTuples>
void TestGRUFunc(const typename KernelTuples::func_type tgt,
                 const std::vector<T>& xsrc, const std::vector<T>& ht_1,
                 const std::vector<T>& ht_ref,
                 const paddle::operators::jit::gru_attr_t& attr) {
  EXPECT_TRUE(tgt != nullptr);
  if (tgt == nullptr) {
    // EXPECT_* does not abort the current function, so bail out here
    // instead of calling through a null function pointer below.
    return;
  }
  EXPECT_EQ(ht_1.size(), ht_ref.size());
  EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
  // x could be changed after compute, so copy to save src
  const int d = static_cast<int>(ht_ref.size());
  std::vector<T> x(xsrc.size()), ht(ht_ref.size());
  std::copy(xsrc.begin(), xsrc.end(), x.begin());
  const T* ht_1_data = ht_1.data();
  const T* ht_ref_data = ht_ref.data();
  T* x_data = x.data();
  T* ht_data = ht.data();
  paddle::operators::jit::gru_t step;
  step.gates = x_data;
  step.ht_1 = ht_1_data;
  step.ht = ht_data;
  tgt(&step, &attr);
  ExpectEQ<T>(ht_data, ht_ref_data, d);
}
// Exhaustively test one GRU kernel type: for every test size and every
// (gate, candidate) activation combination, compute the reference result
// once, then check the jitcode, every applicable "more" implementation,
// and the Get-selected target against it.
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
  namespace jit = paddle::operators::jit;
  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
  std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
  for (int d : TestSizes()) {
    for (auto& act_gate : all_acts) {
      for (auto& act_cand : all_acts) {
        // Human-readable tag used in the VLOG messages below.
        std::string info = act_gate + act_cand + "size_" + std::to_string(d);
        const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
                                   jit::to_kerneltype(act_cand));
        auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
        EXPECT_TRUE(ref != nullptr);
        std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
        RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
        RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
        // x could be changed after compute, so copy to save src
        std::vector<T> x(xsrc.size());
        std::copy(xsrc.begin(), xsrc.end(), x.begin());
        const T* ht_1_data = ht_1.data();
        T* x_data = x.data();
        T* ht_ref_data = ht_ref.data();
        jit::gru_t step;
        step.gates = x_data;
        step.ht_1 = ht_1_data;
        step.ht = ht_ref_data;
        // Produce the reference output that all other impls are checked
        // against.
        ref(&step, &attr);
        // test jitcode
        auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
        if (jitcode) {
          VLOG(10) << "Test Jitcode Kernel " << info;
          TestGRUFunc<T, jit::GRUTuples<T>>(jitcode, xsrc, ht_1, ht_ref, attr);
        }
        // test all impls in more
        jit::KernelKey kkey(KT, PlaceType());
        auto& pool = jit::KernelPool().Instance().AllKernels();
        auto iter = pool.find(kkey);
        if (iter != pool.end()) {
          auto& impls = iter->second;
          for (auto& impl : impls) {
            auto i = dynamic_cast<const jit::KernelImpl<jit::GRUTuples<T>>*>(
                impl.get());
            // Only run implementations that declare themselves usable for
            // this attribute combination.
            if (i && i->UseMe(attr)) {
              auto more = i->GetFunc();
              VLOG(10) << "Test More Kernel " << info;
              TestGRUFunc<T, jit::GRUTuples<T>>(more, xsrc, ht_1, ht_ref, attr);
            }
          }
        }
        // Test result from Get function
        auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
        TestGRUFunc<T, jit::GRUTuples<T>>(tgt, xsrc, ht_1, ht_ref, attr);
      }
    }
  }
}
// GRU-H1 kernel (first step with no previous hidden state), float and
// double, on CPU.
TEST(JITKernel, gruh1) {
  namespace jit = paddle::operators::jit;
  TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
  TestGRUKernel<jit::gruh1, double, paddle::platform::CPUPlace>();
}

// First part of the GRU step: ht = act_gate(r) * ht_1.
TEST(JITKernel, gruhtpart1) {
  namespace jit = paddle::operators::jit;
  TestGRUKernel<jit::gruhtpart1, float, paddle::platform::CPUPlace>();
  TestGRUKernel<jit::gruhtpart1, double, paddle::platform::CPUPlace>();
}

// Second part of the GRU step:
// ht = act_gate(u) * act_cand(s) + (1 - act_gate(u)) * ht_1.
TEST(JITKernel, gruhtpart2) {
  namespace jit = paddle::operators::jit;
  TestGRUKernel<jit::gruhtpart2, float, paddle::platform::CPUPlace>();
  TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
}
// TODO(TJ): refine the tests template // TODO(TJ): refine the tests template
TEST(JITKernel, pool) { TEST(JITKernel, pool) {
......
...@@ -22,54 +22,7 @@ namespace paddle { ...@@ -22,54 +22,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace refer { namespace refer {} // namespace refer
// compute h1 without h0:
// ht = act_gate(u) * act_cand(s), for the first step of a sequence where
// no previous hidden state exists.
template <typename T>
void GRUH1(gru_t* step, const gru_attr_t* attr) {
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  auto act_gate = getActFunc<T>(attr->act_gate);
  auto act_cand = getActFunc<T>(attr->act_cand);
  int d = attr->d;
  int d2 = d * 2;  // candidate state s lives after the two gate blocks
  act_gate(gates, gates, d);
  act_cand(gates + d2, gates + d2, d);
  VMul(gates, gates + d2, ht, d);  // ht = u * s
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
template <typename T>
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
  // W: {W_update, W_reset; W_state} -- reset gate r starts at offset d
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
  auto act_gate = getActFunc<T>(attr->act_gate);
  act_gate(gates + attr->d, gates + attr->d, attr->d);
  VMul(ht_1, gates + attr->d, ht, attr->d);  // ht = r * ht_1
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
template <typename T>
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ht = reinterpret_cast<T*>(step->ht);
  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
  auto act_gate = getActFunc<T>(attr->act_gate);
  auto act_cand = getActFunc<T>(attr->act_cand);
  int d = attr->d;
  T* y = gates + d * 2;  // candidate state s lives after the two gate blocks
  act_gate(gates, gates, d);
  act_cand(y, y, d);
  // out = zt*ht~ + (1-zt)*ht_1
  for (int i = 0; i < d; ++i) {
    ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
  }
}
} // namespace refer
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册