提交 a3703888 编写于 作者: T tensor-tang

fix unit test with double type

上级 417d031f
......@@ -33,8 +33,11 @@ namespace jit {
#define EXP_MAX_INPUT 40.0
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename KernelTuples::func_type GetJitCode(
typename KernelTuples::attr_type attr) {
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
......@@ -45,21 +48,19 @@ inline typename KernelTuples::func_type GetJitCode(
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
......@@ -67,6 +68,15 @@ inline typename KernelTuples::func_type GetJitCode(
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(typename KernelTuples::attr_type attr) {
return nullptr;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples>
......
......@@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) {
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i < 10; ++i) {
for (int i = 1; i < 32; ++i) {
s.push_back(i);
}
// // test some large size
// s.push_back(100);
// s.push_back(1000);
// s.push_back(2000);
// test some large size
s.push_back(100);
s.push_back(1000);
s.push_back(2000);
return s;
}
......@@ -148,8 +148,7 @@ void TestXYZNKernel() {
TEST(JITKernel, vmul) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
// TODO(TJ): fix double issue
// TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, vadd) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册