提交 8e785fec 编写于 作者: T tensor-tang

clean code and refine tests template

上级 00d3afbc
......@@ -57,8 +57,17 @@ std::vector<int> TestSizes() {
return s;
}
template <typename T, typename KernelTuples>
void TestXYZNFunc(const typename KernelTuples::func_type tgt,
namespace jit = paddle::operators::jit;
template <typename KernelTuples, typename... Args>
struct TestFuncWithRefer {
void operator()(const typename KernelTuples::func_type tgt, Args... args) {}
};
template <typename T>
struct TestFuncWithRefer<jit::XYZNTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::XYZNTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& y,
const std::vector<T>& zref) {
EXPECT_TRUE(tgt != nullptr);
......@@ -82,6 +91,154 @@ void TestXYZNFunc(const typename KernelTuples::func_type tgt,
std::copy(y.begin(), y.end(), ztgt.begin());
tgt(x_data, ztgt_data, ztgt_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::AXYNTuples<T>::func_type tgt, const T a,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(&a, x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(&a, ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::XYNTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::LSTMTuples<T>::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& wp,
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
const std::vector<T>& ht_ref,
const typename jit::LSTMTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ct_ref.size(), ht_ref.size());
EXPECT_EQ(ct_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
EXPECT_EQ(wp.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
std::vector<T> checked(2 * d);
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
const T* ct_ref_data = ct_ref.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
T* checked_data = checked.data();
paddle::operators::jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (attr.use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
tgt(&step, &attr);
ExpectEQ<T>(ct_data, ct_ref_data, d);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::GRUTuples<T>::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& ht_1,
const std::vector<T>& ht_ref,
const typename jit::GRUTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ht_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ht(ht_ref.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ht_1_data = ht_1.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ht_data = ht.data();
paddle::operators::jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
tgt(&step, &attr);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
TestFuncWithRefer<KernelTuples, Args...> test;
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel ";
test(jitcode, args...);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel ";
test(more, args...);
}
}
}
// test result from Get function
VLOG(10) << "Test Get function ";
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
test(tgt, args...);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
......@@ -113,77 +270,9 @@ void TestXYZNKernel() {
ExpectEQ<T>(xinp_data, zref_data, d);
ExpectEQ<T>(yinp_data, zref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel, size: " << d;
TestXYZNFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::XYZNTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d;
TestXYZNFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
}
}
TestAllImpls<KT, jit::XYZNTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(d, x, y, zref);
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
TestXYZNFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
}
}
TEST(JITKernel, vmul) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vadd) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vadd, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vadd, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vaddrelu) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vaddrelu, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vaddrelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vsub) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vsub, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
}
template <typename T, typename KernelTuples>
void TestAXYNFunc(const typename KernelTuples::func_type tgt, const T a,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(&a, x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(&a, ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
......@@ -208,65 +297,9 @@ void TestAXYNKernel() {
ref(&a, xinp_data, xinp_data, d);
ExpectEQ<T>(xinp_data, yref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::AXYNTuples<T>, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel, size: " << d;
TestAXYNFunc<T, jit::AXYNTuples<T>>(jitcode, a, x, yref);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::AXYNTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d;
TestAXYNFunc<T, jit::AXYNTuples<T>>(more, a, x, yref);
}
}
TestAllImpls<KT, jit::AXYNTuples<T>, PlaceType, T, std::vector<T>,
std::vector<T>>(d, a, x, yref);
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, jit::AXYNTuples<T>, PlaceType>(d);
TestAXYNFunc<T, jit::AXYNTuples<T>>(tgt, a, x, yref);
}
}
TEST(JITKernel, vscal) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vaddbias) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
}
template <typename T, typename KernelTuples>
void TestXYNFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
......@@ -290,106 +323,9 @@ void TestXYNKernel() {
ref(xinp_data, xinp_data, d);
ExpectEQ<T>(xinp_data, yref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::XYNTuples<T>, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel, size: " << d;
TestXYNFunc<T, jit::XYNTuples<T>>(jitcode, x, yref);
TestAllImpls<KT, jit::XYNTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(d, x, yref);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<jit::XYNTuples<T>>*>(impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d;
TestXYNFunc<T, jit::XYNTuples<T>>(more, x, yref);
}
}
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, jit::XYNTuples<T>, PlaceType>(d);
TestXYNFunc<T, jit::XYNTuples<T>>(tgt, x, yref);
}
}
TEST(JITKernel, vrelu) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vrelu, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vrelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, videntity) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::videntity, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::videntity, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vexp) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vexp, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vexp, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vsigmoid) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vsigmoid, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vsigmoid, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vtanh) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vtanh, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
}
template <typename T, typename KernelTuples>
void TestLSTMFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& wp,
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
const std::vector<T>& ht_ref,
const paddle::operators::jit::lstm_attr_t& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ct_ref.size(), ht_ref.size());
EXPECT_EQ(ct_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
EXPECT_EQ(wp.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
std::vector<T> checked(2 * d);
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
const T* ct_ref_data = ct_ref.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
T* checked_data = checked.data();
paddle::operators::jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (attr.use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
tgt(&step, &attr);
ExpectEQ<T>(ct_data, ct_ref_data, d);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
......@@ -435,37 +371,10 @@ void TestLSTMKernel() {
}
ref(&step, &attr);
// test jitcode
auto jitcode =
jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel " << info;
TestLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, xsrc, wp, ct_1,
ct_ref, ht_ref, attr);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel " << info;
TestLSTMFunc<T, jit::LSTMTuples<T>>(more, xsrc, wp, ct_1,
ct_ref, ht_ref, attr);
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
TestLSTMFunc<T, jit::LSTMTuples<T>>(tgt, xsrc, wp, ct_1, ct_ref,
ht_ref, attr);
TestAllImpls<KT, jit::LSTMTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>, std::vector<T>,
std::vector<T>>(attr, xsrc, wp, ct_1, ct_ref, ht_ref,
attr);
}
}
}
......@@ -473,43 +382,6 @@ void TestLSTMKernel() {
}
}
TEST(JITKernel, lstmctht) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, lstmc1h1) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
}
template <typename T, typename KernelTuples>
void TestGRUFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& ht_1,
const std::vector<T>& ht_ref,
const paddle::operators::jit::gru_attr_t& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ht_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ht(ht_ref.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ht_1_data = ht_1.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ht_data = ht.data();
paddle::operators::jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
tgt(&step, &attr);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
namespace jit = paddle::operators::jit;
......@@ -538,37 +410,97 @@ void TestGRUKernel() {
step.ht = ht_ref_data;
ref(&step, &attr);
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::GRUTuples<T>, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel " << info;
TestGRUFunc<T, jit::GRUTuples<T>>(jitcode, xsrc, ht_1, ht_ref, attr);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::GRUTuples<T>>*>(
impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel " << info;
TestGRUFunc<T, jit::GRUTuples<T>>(more, xsrc, ht_1, ht_ref, attr);
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, jit::GRUTuples<T>, PlaceType>(attr);
TestGRUFunc<T, jit::GRUTuples<T>>(tgt, xsrc, ht_1, ht_ref, attr);
TestAllImpls<KT, jit::GRUTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(attr, xsrc, ht_1, ht_ref,
attr);
}
}
}
}
// XYZNTuple
TEST(JITKernel, vmul) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vadd) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vadd, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vadd, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vaddrelu) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vaddrelu, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vaddrelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vsub) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vsub, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, vscal) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vaddbias) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
}
// XYNTuples
TEST(JITKernel, vrelu) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vrelu, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vrelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, videntity) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::videntity, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::videntity, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vexp) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vexp, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vexp, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vsigmoid) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vsigmoid, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vsigmoid, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vtanh) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::vtanh, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
}
// LSTM
TEST(JITKernel, lstmctht) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, lstmc1h1) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
}
// GRU
TEST(JITKernel, gruh1) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册