提交 8bc63815 编写于 作者: T tensor-tang

fix jitcodekey and refine test

test=develop
上级 7044cfa7
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
......@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return d;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr int act_type_shift = 3; // suppot 2^3 act types
static inline int act_type_convert(KernelType type) {
if (type == kVIdentity) {
return 0;
} else if (type == kVExp) {
return 1;
} else if (type == kVRelu) {
return 2;
} else if (type == kVSigmoid) {
return 3;
} else if (type == kVTanh) {
return 4;
}
PADDLE_THROW("Unsupported act type %d", type);
return 0;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
......@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(act_type_convert(attr.act_cand) << act_type_shift);
}
template <>
......
......@@ -40,11 +40,11 @@ template <typename T>
void ExpectEQ(const T* target, const T* refer, size_t n) {
if (std::is_floating_point<T>::value) {
for (size_t i = 0; i < n; ++i) {
EXPECT_NEAR(target[i], refer[i], FLAGS_acc);
EXPECT_NEAR(target[i], refer[i], FLAGS_acc) << " at index : " << i;
}
} else {
for (size_t i = 0; i < n; ++i) {
EXPECT_EQ(target[i], refer[i]);
EXPECT_EQ(target[i], refer[i]) << " at index : " << i;
}
}
}
......@@ -447,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
void TestKernelXYZNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
......@@ -480,7 +480,7 @@ void TestXYZNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestAXYNKernel() {
void TestKernelAXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
......@@ -506,7 +506,7 @@ void TestAXYNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXRNKernel() {
void TestKernelXRNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc;
FLAGS_acc = 1e-4;
......@@ -524,7 +524,7 @@ void TestXRNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXYNKernel() {
void TestKernelXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
......@@ -549,10 +549,12 @@ void TestXYNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() {
void TestKernelLSTMTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (bool use_peephole : {true, false}) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
......@@ -599,10 +601,12 @@ void TestLSTMKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
void TestKernelGRUTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
......@@ -633,14 +637,16 @@ void TestGRUKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
void TestKernelSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (auto type : pool_types) {
for (int w : TestSizes()) {
for (int w : test_sizes) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
for (int h : test_sizes) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
......@@ -658,11 +664,11 @@ void TestSeqPoolKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() {
void TestKernelMatMulTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc;
// TODO(intel): fix MKL acc issue
// https://github.com/PaddlePaddle/Paddle/issues/15447
// export MKL_CBWR=AVX would make MKL force to use AVX
// export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
FLAGS_acc = 1e-3;
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
......@@ -686,7 +692,7 @@ void TestMatMulKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSoftmaxKernel() {
void TestKernelSoftmaxTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
......@@ -711,12 +717,14 @@ void TestSoftmaxKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestEmbSeqPoolKernel() {
void TestKernelEmbSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
int64_t tbl_h = 1e4;
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum}; // only support sum yet
for (int tbl_w : TestSizes()) {
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int tbl_w : test_sizes) {
std::vector<T> table(tbl_h * tbl_w);
RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
const T* table_data = table.data();
......@@ -745,7 +753,7 @@ void TestEmbSeqPoolKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestSgdKernel() {
void TestKernelSgdTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
......@@ -799,7 +807,7 @@ void TestSgdKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
void TestKernelNCHW16CMulNCTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10;
auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
......@@ -852,7 +860,7 @@ void TestNCHW16CMulNCKernel() {
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestLayerNormKernel() {
void TestKernelLayerNormTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) {
......@@ -891,11 +899,13 @@ void TestLayerNormKernel() {
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestCRFDecodingKernel() {
void TestKernelCRFDecodingTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
constexpr int state_trans_base_idx = 2;
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) {
for (int tag_num : test_sizes) {
auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
EXPECT_TRUE(ref != nullptr);
int x_sz = seq_len * tag_num;
......@@ -916,148 +926,76 @@ void TestCRFDecodingKernel() {
}
}
// XYZNTuple
TEST(JITKernel, kVMul) {
TestXYZNKernel<jit::kVMul, float, CPUPlace>();
TestXYZNKernel<jit::kVMul, double, CPUPlace>();
}
TEST(JITKernel, kVAdd) {
TestXYZNKernel<jit::kVAdd, float, CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, CPUPlace>();
}
TEST(JITKernel, kVAddRelu) {
TestXYZNKernel<jit::kVAddRelu, float, CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, CPUPlace>();
}
TEST(JITKernel, kVSub) {
TestXYZNKernel<jit::kVSub, float, CPUPlace>();
TestXYZNKernel<jit::kVSub, double, CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, kVScal) {
TestAXYNKernel<jit::kVScal, float, CPUPlace>();
TestAXYNKernel<jit::kVScal, double, CPUPlace>();
}
TEST(JITKernel, kVAddBias) {
TestAXYNKernel<jit::kVAddBias, float, CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, CPUPlace>();
}
// XRNTuples
TEST(JITKernel, kHMax) {
TestXRNKernel<jit::kHMax, float, CPUPlace>();
TestXRNKernel<jit::kHMax, double, CPUPlace>();
}
TEST(JITKernel, kHSum) {
TestXRNKernel<jit::kHSum, float, CPUPlace>();
TestXRNKernel<jit::kHSum, double, CPUPlace>();
}
// XYNTuples
TEST(JITKernel, kVRelu) {
TestXYNKernel<jit::kVRelu, float, CPUPlace>();
TestXYNKernel<jit::kVRelu, double, CPUPlace>();
}
TEST(JITKernel, kVIdentity) {
TestXYNKernel<jit::kVIdentity, float, CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, CPUPlace>();
}
TEST(JITKernel, kVSquare) {
TestXYNKernel<jit::kVSquare, float, CPUPlace>();
TestXYNKernel<jit::kVSquare, double, CPUPlace>();
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
}
TEST(JITKernel, kVExp) {
TestXYNKernel<jit::kVExp, float, CPUPlace>();
TestXYNKernel<jit::kVExp, double, CPUPlace>();
}
TEST_CPU_KERNEL(XYZNTuples, kVMul);
TEST_CPU_KERNEL(XYZNTuples, kVAdd);
TEST_CPU_KERNEL(XYZNTuples, kVAddRelu);
TEST_CPU_KERNEL(XYZNTuples, kVSub);
TEST(JITKernel, kVSigmoid) {
TestXYNKernel<jit::kVSigmoid, float, CPUPlace>();
TestXYNKernel<jit::kVSigmoid, double, CPUPlace>();
}
TEST_CPU_KERNEL(AXYNTuples, kVScal);
TEST_CPU_KERNEL(AXYNTuples, kVAddBias);
TEST(JITKernel, kVTanh) {
TestXYNKernel<jit::kVTanh, float, CPUPlace>();
TestXYNKernel<jit::kVTanh, double, CPUPlace>();
}
TEST_CPU_KERNEL(XRNTuples, kHMax);
TEST_CPU_KERNEL(XRNTuples, kHSum);
// LSTM
TEST(JITKernel, kLSTMCtHt) {
TestLSTMKernel<jit::kLSTMCtHt, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, double, CPUPlace>();
}
TEST_CPU_KERNEL(XYNTuples, kVRelu);
TEST_CPU_KERNEL(XYNTuples, kVIdentity);
TEST_CPU_KERNEL(XYNTuples, kVSquare);
TEST_CPU_KERNEL(XYNTuples, kVExp);
TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
TEST_CPU_KERNEL(XYNTuples, kVTanh);
TEST(JITKernel, kLSTMC1H1) {
TestLSTMKernel<jit::kLSTMC1H1, float, CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, double, CPUPlace>();
}
TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
// GRU
TEST(JITKernel, kGRUH1) {
TestGRUKernel<jit::kGRUH1, float, CPUPlace>();
TestGRUKernel<jit::kGRUH1, double, CPUPlace>();
}
TEST_CPU_KERNEL(GRUTuples, kGRUH1);
TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1);
TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2);
TEST(JITKernel, kGRUHtPart1) {
TestGRUKernel<jit::kGRUHtPart1, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, CPUPlace>();
}
TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC);
TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, float, CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, double, CPUPlace>();
}
TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool);
TEST_CPU_KERNEL(MatMulTuples, kMatMul);
TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax);
TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
TEST_CPU_KERNEL(SgdTuples, kSgd);
TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
TEST(JITKernel, kSeqPool) {
TestSeqPoolKernel<jit::kSeqPool, float, CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, CPUPlace>();
}
TEST(JITKernel, kMatMul) {
TestMatMulKernel<jit::kMatMul, float, CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, CPUPlace>();
}
TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, float, CPUPlace>();
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
}
TEST(JITKernel_key, lstm) {
jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
TEST(JITKernel, kEmbSeqPool) {
TestEmbSeqPoolKernel<jit::kEmbSeqPool, float, CPUPlace>();
TestEmbSeqPoolKernel<jit::kEmbSeqPool, double, CPUPlace>();
}
auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);
TEST(JITKernel, kSgd) {
TestSgdKernel<jit::kSgd, float, CPUPlace>();
TestSgdKernel<jit::kSgd, double, CPUPlace>();
EXPECT_TRUE(key1 != key2);
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}
TEST(JITKernel, kNCHW16CMulNC) {
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>();
}
TEST(JITKernel_key, gru) {
jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
TEST(JITKernel, kLayerNorm) {
TestLayerNormKernel<jit::kLayerNorm, float, paddle::platform::CPUPlace>();
TestLayerNormKernel<jit::kLayerNorm, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kCRFDecoding) {
TestCRFDecodingKernel<jit::kCRFDecoding, float, paddle::platform::CPUPlace>();
TestCRFDecodingKernel<jit::kCRFDecoding, double,
paddle::platform::CPUPlace>();
}
auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
TEST(JITKernel, pool) {
// TODO(TJ): add some test
EXPECT_TRUE(key1 != key2);
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}
// TODO(TJ): add more test about key and pool
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册