提交 a3703888 编写于 作者: T tensor-tang

fix unit test with double type

上级 417d031f
...@@ -33,8 +33,11 @@ namespace jit { ...@@ -33,8 +33,11 @@ namespace jit {
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
template <KernelType KT, typename KernelTuples, typename PlaceType> template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename KernelTuples::func_type GetJitCode( inline typename std::enable_if<
typename KernelTuples::attr_type attr) { std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
using Func = typename KernelTuples::func_type; using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
...@@ -45,21 +48,19 @@ inline typename KernelTuples::func_type GetJitCode( ...@@ -45,21 +48,19 @@ inline typename KernelTuples::func_type GetJitCode(
// creator is not related with attr, so can use KernelKey as key // creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) { // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
// pool: (KernelKey(type, place), vector<GenCreatorPtr>) auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto iter = creator_map.find(kkey);
auto iter = creator_map.find(kkey); if (iter != creator_map.end()) {
if (iter != creator_map.end()) { auto& creators = iter->second;
auto& creators = iter->second; for (auto& cur : creators) {
for (auto& cur : creators) { auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get()); if (i && i->UseMe(attr)) {
if (i && i->UseMe(attr)) { auto p = i->CreateJitCode(attr);
auto p = i->CreateJitCode(attr); if (p) {
if (p) { auto f = p->template getCode<Func>();
auto f = p->template getCode<Func>(); codes.Insert(key, std::move(p));
codes.Insert(key, std::move(p)); return f;
return f;
}
} }
} }
} }
...@@ -67,6 +68,15 @@ inline typename KernelTuples::func_type GetJitCode( ...@@ -67,6 +68,15 @@ inline typename KernelTuples::func_type GetJitCode(
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
return nullptr;
}
// Refer code do not related with attr, which is just for cast // Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples> template <KernelType KT, typename KernelTuples>
......
...@@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) { ...@@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) {
std::vector<int> TestSizes() { std::vector<int> TestSizes() {
std::vector<int> s; std::vector<int> s;
for (int i = 1; i < 10; ++i) { for (int i = 1; i < 32; ++i) {
s.push_back(i); s.push_back(i);
} }
// // test some large size // test some large size
// s.push_back(100); s.push_back(100);
// s.push_back(1000); s.push_back(1000);
// s.push_back(2000); s.push_back(2000);
return s; return s;
} }
...@@ -148,8 +148,7 @@ void TestXYZNKernel() { ...@@ -148,8 +148,7 @@ void TestXYZNKernel() {
TEST(JITKernel, vmul) { TEST(JITKernel, vmul) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>(); TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
// TODO(TJ): fix double issue TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
// TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
} }
TEST(JITKernel, vadd) { TEST(JITKernel, vadd) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册