提交 a3703888 编写于 作者: T tensor-tang

fix unit test with double type

上级 417d031f
...@@ -33,8 +33,11 @@ namespace jit { ...@@ -33,8 +33,11 @@ namespace jit {
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
template <KernelType KT, typename KernelTuples, typename PlaceType> template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename KernelTuples::func_type GetJitCode( inline typename std::enable_if<
typename KernelTuples::attr_type attr) { std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
using Func = typename KernelTuples::func_type; using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
...@@ -45,7 +48,6 @@ inline typename KernelTuples::func_type GetJitCode( ...@@ -45,7 +48,6 @@ inline typename KernelTuples::func_type GetJitCode(
// creator is not related with attr, so can use KernelKey as key // creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// pool: (KernelKey(type, place), vector<GenCreatorPtr>) // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey); auto iter = creator_map.find(kkey);
...@@ -63,7 +65,15 @@ inline typename KernelTuples::func_type GetJitCode( ...@@ -63,7 +65,15 @@ inline typename KernelTuples::func_type GetJitCode(
} }
} }
} }
} return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
return nullptr; return nullptr;
} }
......
...@@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) { ...@@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) {
std::vector<int> TestSizes() { std::vector<int> TestSizes() {
std::vector<int> s; std::vector<int> s;
for (int i = 1; i < 10; ++i) { for (int i = 1; i < 32; ++i) {
s.push_back(i); s.push_back(i);
} }
// // test some large size // test some large size
// s.push_back(100); s.push_back(100);
// s.push_back(1000); s.push_back(1000);
// s.push_back(2000); s.push_back(2000);
return s; return s;
} }
...@@ -148,8 +148,7 @@ void TestXYZNKernel() { ...@@ -148,8 +148,7 @@ void TestXYZNKernel() {
TEST(JITKernel, vmul) { TEST(JITKernel, vmul) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
// TODO(TJ): fix double issue TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
// TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
} }
TEST(JITKernel, vadd) { TEST(JITKernel, vadd) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册