From 45bfa70cb8a8123cfa5c32ec7323d616f9192e3d Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Mon, 3 Dec 2018 12:15:16 +0000
Subject: [PATCH] complete vmul jit kernel

---
 .../fluid/operators/jitkernels/CMakeLists.txt |  12 +-
 paddle/fluid/operators/jitkernels/README.md   |   3 +
 .../operators/jitkernels/jitcode/jitcode.cc   |  23 ++++
 .../operators/jitkernels/jitcode/jitcode.h    |   7 +-
 .../fluid/operators/jitkernels/jitcode_base.h |   9 +-
 .../fluid/operators/jitkernels/kernel_base.h  |  13 +-
 paddle/fluid/operators/jitkernels/kernels.cc  |   7 +-
 paddle/fluid/operators/jitkernels/kernels.h   | 110 ++++++++-------
 .../fluid/operators/jitkernels/refer/refer.cc |   3 +-
 .../fluid/operators/jitkernels/refer/refer.h  |   8 ++
 paddle/fluid/operators/jitkernels/registry.h  | 126 ++++++++++--------
 paddle/fluid/operators/jitkernels/test.cc     |  78 ++++++++++-
 12 files changed, 273 insertions(+), 126 deletions(-)

diff --git a/paddle/fluid/operators/jitkernels/CMakeLists.txt b/paddle/fluid/operators/jitkernels/CMakeLists.txt
index f073210542..6392d82e16 100644
--- a/paddle/fluid/operators/jitkernels/CMakeLists.txt
+++ b/paddle/fluid/operators/jitkernels/CMakeLists.txt
@@ -1,17 +1,19 @@
+# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
+# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
+# file(APPEND ${pass_file} "\#pragma once\n")
+# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
+
+
 set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
 
 cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
 
-add_subdirectory(more)
 add_subdirectory(refer)
-
+add_subdirectory(more)
 if(WITH_XBYAK)
   add_subdirectory(jitcode)
 endif()
 
-# Debug
-message(STATUS "--------${JIT_KERNEL_DEPS}")
-
 cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
 
 cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/jitkernels/README.md b/paddle/fluid/operators/jitkernels/README.md
index a0990367ef..3401e9be53 100644
--- a/paddle/fluid/operators/jitkernels/README.md
+++ b/paddle/fluid/operators/jitkernels/README.md
@@ -1 +1,4 @@
 TBD
+
+# Use me
+Add USE_JIT_KERNEL(yourname) to your CMakeLists.txt.
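A note on the "Use me" instruction above: on the C++ side, this patch's own test (test.cc, further down) pulls implementations in with the USE_JITKERNEL_* macros defined in registry.h. A minimal sketch of a consumer translation unit, using only names that appear in this patch:

    // some_consumer.cc, global namespace
    #include "paddle/fluid/operators/jitkernels/registry.h"

    USE_JITKERNEL_REFER(vmul);      // reference kernel, always available
    USE_JITKERNEL_MORE(vmul, mkl);  // optional extra implementation, if built
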
diff --git a/paddle/fluid/operators/jitkernels/jitcode/jitcode.cc b/paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
index 0dd2d049d2..8078ace7a8 100644
--- a/paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
+++ b/paddle/fluid/operators/jitkernels/jitcode/jitcode.cc
@@ -13,3 +13,26 @@
  * limitations under the License.
  */
 #include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
+
+namespace paddle {
+namespace operators {
+namespace jitkernels {
+
+template <>
+size_t GetKey(int d) {
+  return d;
+}
+
+// template <>
+// std::shared_ptr CreateJitCode(int attr)
+// {
+//   if (UseJitCode(attr)) {
+//     return std::make_shared>(attr,
+//     CodeSize(attr)));
+//   }
+//   return nullptr;
+// }
+
+}  // namespace jitkernels
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jitkernels/jitcode/jitcode.h b/paddle/fluid/operators/jitkernels/jitcode/jitcode.h
index c100444766..7e0b6442ed 100644
--- a/paddle/fluid/operators/jitkernels/jitcode/jitcode.h
+++ b/paddle/fluid/operators/jitkernels/jitcode/jitcode.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include
+#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
 #include "paddle/fluid/operators/jitkernels/kernels.h"
 
 #define XBYAK_USE_MMAP_ALLOCATOR
@@ -31,10 +32,10 @@ constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
     abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
     abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
 
-template
-class JitCode : public JitBase, public Xbyak::CodeGenerator {
+template
+class VMulJitCode : public JitBase, public Xbyak::CodeGenerator {
  public:
-  JitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
+  VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
       : Xbyak::CodeGenerator(code_size, code_ptr) {
     this->genCode();
   }
diff --git a/paddle/fluid/operators/jitkernels/jitcode_base.h b/paddle/fluid/operators/jitkernels/jitcode_base.h
index 0cd6d3c741..a164746561 100644
--- a/paddle/fluid/operators/jitkernels/jitcode_base.h
+++ b/paddle/fluid/operators/jitkernels/jitcode_base.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include
+#include <memory>  // for shared_ptr
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/platform/macros.h"
 
@@ -42,11 +43,6 @@ bool UseJitCode(Attr attr) {
 template
 size_t GetKey(Attr attr);
 
-template <>
-size_t GetKey(int d) {
-  return d;
-}
-
 class JitBase {
  public:
   JitBase() = default;
@@ -68,6 +64,9 @@ class JitBase {
   void dumpCode(const unsigned char* code);
 };
 
+template
+std::shared_ptr CreateJitCode(Attr attr);
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
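jitcode_base.h above ends by declaring CreateJitCode, and jitcode.cc keeps its vmul specialization commented out (the template-argument lists were lost in this copy of the patch). A rough sketch of what that specialization could look like once enabled; the exact template parameters of CreateJitCode, UseJitCode, CodeSize and VMulJitCode are assumptions, not taken from the patch:

    // Sketch only; mirrors the commented-out block in jitcode.cc.
    template <>
    std::shared_ptr<JitBase> CreateJitCode<KernelType::vmul, int>(int attr) {
      if (UseJitCode<KernelType::vmul, float, int>(attr)) {
        // build an xbyak-generated vmul kernel sized for this attr
        return std::make_shared<jitcode::VMulJitCode<float, int>>(attr, CodeSize(attr));
      }
      return nullptr;  // fall back to the non-JIT implementations
    }
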
diff --git a/paddle/fluid/operators/jitkernels/kernel_base.h b/paddle/fluid/operators/jitkernels/kernel_base.h
index bd95a921c5..eeaa0617cb 100644
--- a/paddle/fluid/operators/jitkernels/kernel_base.h
+++ b/paddle/fluid/operators/jitkernels/kernel_base.h
@@ -25,6 +25,7 @@ typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
 class Kernel {
  public:
   Kernel() = default;
+  virtual ~Kernel() = default;
   DISABLE_COPY_AND_ASSIGN(Kernel);
 };
 
@@ -32,16 +33,20 @@ template  // TODO(TJ): use tuple
 class KernelImpl : public Kernel {
  public:
   using ELEMENT_TYPE = T;  // TODO(TJ): remove me?
-  KernelImpl() = default;
-  virtual ~KernelImpl() = default;
-
-  virtual Func GetFunc() { return func; }
+  virtual Func GetFunc() const { return func; }
   virtual bool UseMe(Attr attr) const = 0;
 
 protected:
   Func func{nullptr};
 };
 
+template  // TODO(TJ): use tuple
+class ReferKernel : public KernelImpl {
+ public:
+  // Refer code can always be used
+  bool UseMe(Attr attr) const override { return true; }
+};
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jitkernels/kernels.cc b/paddle/fluid/operators/jitkernels/kernels.cc
index 76f49514ee..35095220e3 100644
--- a/paddle/fluid/operators/jitkernels/kernels.cc
+++ b/paddle/fluid/operators/jitkernels/kernels.cc
@@ -21,13 +21,16 @@
 namespace paddle {
 namespace operators {
 namespace jitkernels {
-// refer do not need useme, it would be the last one.
-
 KernelPool& KernelPool::Instance() {
   static KernelPool g_kernel_pool;
   return g_kernel_pool;
 }
 
+ReferKernelPool& ReferKernelPool::Instance() {
+  static ReferKernelPool g_refer_kernel_pool;
+  return g_refer_kernel_pool;
+}
+
 }  // namespace jitkernels
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jitkernels/kernels.h b/paddle/fluid/operators/jitkernels/kernels.h
index 2792b897d3..866f72cce0 100644
--- a/paddle/fluid/operators/jitkernels/kernels.h
+++ b/paddle/fluid/operators/jitkernels/kernels.h
@@ -18,22 +18,21 @@
 #include
 #include
 #include
-
 #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/operators/jitkernels/kernel_key.h"
-
-#ifdef PADDLE_WITH_XBYAK
-#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
-#endif
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace operators {
 namespace jitkernels {
+// TODO(TJ): rename file to kernel_pool
+
 template
 class JitCodePool {
  public:
+  JitCodePool() = default;
   static JitCodePool& Instance() {
     static thread_local JitCodePool g_jit_codes;
     return g_jit_codes;
   }
@@ -51,13 +50,11 @@
   }
 
  private:
-  JitCodePool() = default;
   std::unordered_map> codes_;
-
   DISABLE_COPY_AND_ASSIGN(JitCodePool);
 };
 
-// std::tuple
+// TODO(TJ): std::tuple
 template
 struct KernelAttr {
   typedef T data_type;
@@ -64,74 +61,97 @@ struct KernelAttr {
   typedef Func return_type;
   typedef Attr attr_type;
 };
+typedef std::unique_ptr KernelPtr;
+typedef std::unordered_map, KernelKey::Hash>
+    KernelMap;
 
 class KernelPool {
  public:
   static KernelPool& Instance();
-
-  typedef std::unique_ptr KernelPtr;
-  typedef std::unordered_map, KernelKey::Hash>
-      KernelMap;
+  KernelPool() = default;
   KernelMap& AllKernels() { return pool_; }
-
   void Insert(const KernelKey& key, KernelPtr value) {
     if (pool_.find(key) == pool_.end()) {
       pool_.emplace(key, std::vector());
     }
     pool_.at(key).emplace_back(std::move(value));
   }
-  KernelPool() = default;
 
  private:
   KernelMap pool_;
-
   DISABLE_COPY_AND_ASSIGN(KernelPool);
 };
 
-// TODO(TJ): create_jitcode;
+// Every kernel should have a reference implementation, and it is used in the
+// unit tests, so refer kernels get their own independent kernel pool.
+class ReferKernelPool {
+ public:
+  static ReferKernelPool& Instance();
+  ReferKernelPool() = default;
+  KernelMap& AllKernels() { return pool_; }
+  void Insert(const KernelKey& key, KernelPtr value) {
+    if (pool_.find(key) == pool_.end()) {
+      pool_.emplace(key, std::vector());
+    }
+    pool_.at(key).emplace_back(std::move(value));
+  }
+
+ private:
+  KernelMap pool_;
+  DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
+};
+
+// Refer code is not related to attr, and always runs on CPUPlace.
+template
+inline Func GetRefer() {
+  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
+  KernelKey kkey(KT, platform::CPUPlace());
+  auto ref_iter = ref_pool.find(kkey);
+  PADDLE_ENFORCE(ref_iter != ref_pool.end(),
+                 "Every Kernel should have reference function.");
+  auto& ref_impls = ref_iter->second;
+  for (auto& impl : ref_impls) {
+    auto i = dynamic_cast*>(impl.get());
+    if (i) {
+      return i->GetFunc();
+    }
+  }
+  return nullptr;
+}
 
 // TODO(TJ): make tuple? named KernelAttr
 template
 Func Get(Attr attr) {
-  size_t key = GetKey(attr);
-  auto jitcode = JitCodePool().Instance().Get(key);
-  if (jitcode) {
-    return jitcode->template getCode();
+  // size_t key = GetKey(attr);
+  // auto jitcode = JitCodePool().Instance().Get(key);
+  // if (jitcode) {
+  //   return jitcode->template getCode();
+  // }
+
+  if (std::is_same::value &&
+      std::is_same::value) {  // TODO(TJ): float move to create
+    // auto p = CreateJitCode(attr);
+    // if (p) {
+    //   JitCodePool().Instance().Insert(key, p);
+    //   return p->template getCode();
+    // }
   }
 
-#ifdef PADDLE_WITH_XBYAK
-// // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
-// if (std::is_same::value) {
-//   if (UseJitCode(attr)) {
-//     std::shared_ptr p(std::make_shared>(
-//         attr, CodeSize(attr)));
-//     JitCodePool().Instance().Insert(key, p);
-//     return p->getCode();
-//   }
-// }
-#endif
-
-  // (KernelKey(type, place), vector)
+  // pool: (KernelKey(type, place), vector)
   auto& pool = KernelPool().Instance().AllKernels();
   KernelKey kkey(KT, PlaceType());
   auto iter = pool.find(kkey);
   if (iter != pool.end()) {
-    auto impls = iter->second;
-    for (auto impl : impls) {
-      auto i = std::dynamic_pointer_cast>(impl.get());
+    auto& impls = iter->second;
+    for (auto& impl : impls) {
+      auto i = dynamic_cast*>(impl.get());
      if (i && i->UseMe(attr)) {
        return i->GetFunc();
      }
    }
  }
 
-  // The last implementation should be reference function on CPU
-  // Every kernel should have refer code.
-
-  // because of test refer should have it's own pool
-  // PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation";
-  // const auto& refer = KernelRefer().AllKernels();
-  // return refer.Get();
-
-  return nullptr;
+  // The last implementation should be the reference function on CPUPlace.
+  return GetRefer();
 }
 
 }  // namespace jitkernels
diff --git a/paddle/fluid/operators/jitkernels/refer/refer.cc b/paddle/fluid/operators/jitkernels/refer/refer.cc
index 1f6d384fc2..dbccac896c 100644
--- a/paddle/fluid/operators/jitkernels/refer/refer.cc
+++ b/paddle/fluid/operators/jitkernels/refer/refer.cc
@@ -17,4 +17,5 @@
 
 namespace refer = paddle::operators::jitkernels::refer;
 
-// REGISTER_JITKERNEL_REFER(vmul, refer::VMul, refer::VMul);
+REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel,
+                         refer::VMulKernel);
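For a call-site view of the Get/GetRefer entry points added above: the test at the end of this patch does exactly this for vmul. The template-argument order shown below is inferred from that test (the angle-bracket lists were stripped in this copy of the patch), so treat it as a sketch rather than the exact signature:

    namespace jit = paddle::operators::jitkernels;
    using VMulFunc = void (*)(const float*, const float*, float*, int);

    int d = 8;  // vector length, also used as the kernel attribute
    auto f = jit::Get<jit::vmul, float, VMulFunc, int, paddle::platform::CPUPlace>(d);
    // f(x, y, z, d) computes z[i] = x[i] * y[i]; it falls back to the refer
    // kernel when no other registered implementation claims UseMe(d).
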
diff --git a/paddle/fluid/operators/jitkernels/refer/refer.h b/paddle/fluid/operators/jitkernels/refer/refer.h
index be55c30b1e..163c6d73dc 100644
--- a/paddle/fluid/operators/jitkernels/refer/refer.h
+++ b/paddle/fluid/operators/jitkernels/refer/refer.h
@@ -13,6 +13,7 @@
  * limitations under the License.
  */
 #pragma once
+#include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -27,6 +28,13 @@ void VMul(const T* x, const T* y, T* z, int n) {
   }
 }
 
+template
+class VMulKernel
+    : public ReferKernel {
+ public:
+  VMulKernel() { this->func = VMul; }
+};
+
 }  // namespace refer
 }  // namespace jitkernels
 }  // namespace operators
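The VMul/VMulKernel pair above is the pattern every KernelType is expected to follow: a plain templated function plus a thin ReferKernel wrapper that assigns it to func, registered once in refer.cc. As a hypothetical example that is not part of this patch, a vadd refer kernel (the KernelType enum already reserves vadd) might look like the following; the ReferKernel template arguments are assumed to be <T, Func, Attr>:

    template <typename T>
    void VAdd(const T* x, const T* y, T* z, int n) {
      for (int i = 0; i < n; ++i) {
        z[i] = x[i] + y[i];
      }
    }

    template <typename T>
    class VAddKernel
        : public ReferKernel<T, void (*)(const T*, const T*, T*, int), int> {
     public:
      VAddKernel() { this->func = VAdd<T>; }
    };

    // in refer.cc:
    // REGISTER_JITKERNEL_REFER(vadd, refer::VAddKernel<float>, refer::VAddKernel<double>);
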
diff --git a/paddle/fluid/operators/jitkernels/registry.h b/paddle/fluid/operators/jitkernels/registry.h
index 1d2d47a804..62a0de3641 100644
--- a/paddle/fluid/operators/jitkernels/registry.h
+++ b/paddle/fluid/operators/jitkernels/registry.h
@@ -20,6 +20,7 @@
 #include "paddle/fluid/operators/jitkernels/kernel_base.h"
 #include "paddle/fluid/operators/jitkernels/kernels.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/platform/variant.h"  // for UNUSED
 
 namespace paddle {
 namespace operators {
@@ -32,37 +33,40 @@ inline std::unique_ptr make_unique(Args&&... args) {
   return std::unique_ptr(new T(std::forward(args)...));
 }
 
-template
+template
 struct JitKernelRegistrarFunctor;
 
-template
-struct JitKernelRegistrarFunctor {
+template
+struct JitKernelRegistrarFunctor {
   void operator()(KernelType kt) const {}
 };
 
-template
-struct JitKernelRegistrarFunctor {
+template
+struct JitKernelRegistrarFunctor {
   using KERNEL_IMPL_TYPE =
       typename std::tuple_element>::type;
   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    KernelPool().Instance().Insert(
-        kkey, std::move(make_unique()));
+    Pool().Instance().Insert(kkey,
+                             std::move(make_unique()));
     constexpr auto size = std::tuple_size>::value;
-    JitKernelRegistrarFunctor
+    JitKernelRegistrarFunctor
        func;
    func(kt);
  }
 };
 
-template
+template
 class JitKernelRegistrar {
  public:
  explicit JitKernelRegistrar(KernelType kt) {
-    JitKernelRegistrarFunctor func;
+    JitKernelRegistrarFunctor func;
    func(kt);
  }
+  void Touch() {}
 };
 
 #define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
@@ -71,17 +75,40 @@ class JitKernelRegistrar {
                    __test_global_namespace_##uniq_name##__>::value, \
                msg)
 
+// Refer always on CPUPlace
+#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                   \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                          \
+      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                \
+      "REGISTER_KERNEL_REFER must be called in global namespace");   \
+  static ::paddle::operators::jitkernels::JitKernelRegistrar<        \
+      ::paddle::operators::jitkernels::ReferKernelPool,              \
+      ::paddle::platform::CPUPlace, __VA_ARGS__>                     \
+      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(        \
+          ::paddle::operators::jitkernels::KernelType::kernel_type); \
+  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {          \
+    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();  \
+    return 0;                                                        \
+  }
+
 // kernel_type: should be in paddle::operators::jitkernels::KernelType
 // place_type: should be one of CPUPlace and GPUPlace in paddle::platform
-#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)        \
-  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                   \
-      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,             \
-      "REGISTER_KERNEL_MORE must be called in global namespace");             \
-  static ::paddle::operators::jitkernels::JitKernelRegistrar<                 \
-      ::paddle::platform::place_type, __VA_ARGS__>                            \
-      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##__(  \
-          ::paddle::operators::jitkernels::KernelType::kernel_type)
-// TODO(TJ): Add Touch and use me
+#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)         \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                    \
+      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,              \
+      "REGISTER_KERNEL_MORE must be called in global namespace");              \
+  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();              \
+  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_  \
+      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();            \
+  static ::paddle::operators::jitkernels::JitKernelRegistrar<                  \
+      ::paddle::operators::jitkernels::KernelPool,                             \
+      ::paddle::platform::place_type, __VA_ARGS__>                             \
+      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(    \
+          ::paddle::operators::jitkernels::KernelType::kernel_type);           \
+  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {      \
+    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_       \
+        .Touch();                                                              \
+    return 0;                                                                  \
+  }
 
 #define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
   REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
@@ -89,45 +116,28 @@ class JitKernelRegistrar {
 #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
   REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
 
-/*
-REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode);
-
-// refer must be only one and at least one
-REGISTER_JITKERNEL_REFER(vmul, VMul);  // Refer need support dtype
-
-// you can register more implementations and the condition when use it
-REGISTER_JITKERNEL_MORE(vmul, mkl::VMUL, UseMe, mkl::VMUL,
-                        UseMe)
-
-#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg)                   \
-  struct __test_global_namespace_##uniq_name##__ {};                          \
-  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
-                             __test_global_namespace_##uniq_name##__>::value, \
-                msg)
-
-// Register a new pass that can be applied on the IR.
-#define REGISTER_PASS(pass_type, pass_class)               \
-  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                     \
-      __reg_pass__##pass_type,                             \
-      "REGISTER_PASS must be called in global namespace"); \
-  static ::paddle::framework::ir::PassRegistrar            \
-      __pass_registrar_##pass_type##__(#pass_type);        \
-  int TouchPassRegistrar_##pass_type() {                   \
-    __pass_registrar_##pass_type##__.Touch();              \
-    return 0;                                              \
-  }                                                        \
-  static ::paddle::framework::ir::PassRegistrar&           \
-      __pass_tmp_registrar_##pass_type##__ UNUSED =        \
-          __pass_registrar_##pass_type##__
-
-#define USE_PASS(pass_type)                                \
-  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                     \
-      __use_pass_itself_##pass_type,                       \
-      "USE_PASS must be called in global namespace");      \
-  extern int TouchPassRegistrar_##pass_type();             \
-  static int use_pass_itself_##pass_type##_ UNUSED =       \
-      TouchPassRegistrar_##pass_type()
-*/
+// REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode);
+
+#define USE_JITKERNEL_REFER(kernel_type)                            \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
+      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
+      "USE_JITKERNEL_REFER must be called in global namespace");    \
+  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
+  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
+      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
+
+#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)                 \
+  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                 \
+      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,        \
+      "USE_JITKERNEL_MORE must be called in global namespace");             \
+  extern int                                                                \
+      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_();    \
+  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_   \
+      UNUSED =                                                              \
+          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
+
+#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
+  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
 
 }  // namespace jitkernels
 }  // namespace operators
diff --git a/paddle/fluid/operators/jitkernels/test.cc b/paddle/fluid/operators/jitkernels/test.cc
index 86c6669173..d11c7afe9a 100644
--- a/paddle/fluid/operators/jitkernels/test.cc
+++ b/paddle/fluid/operators/jitkernels/test.cc
@@ -19,8 +19,11 @@
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 #include "gtest/gtest.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
-#include "paddle/fluid/operators/math/jit_kernel_refer.h"
+#include "paddle/fluid/operators/jitkernels/kernels.h"
+// TODO(TJ): remove me
+#include "paddle/fluid/operators/jitkernels/registry.h"
+
+#include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/port.h"
 
 constexpr int repeat = 20000;
@@ -31,6 +34,75 @@ inline double GetCurrentUS() {
   return 1e+6 * time.tv_sec + time.tv_usec;
 }
 
-TEST(JitKernel, vmul) {}
+template
+void RandomVec(const int n, T* a, const T lower = static_cast(-20.f),
+               const T upper = static_cast(20.f)) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution uniform_dist(0, 1);
+  for (int i = 0; i < n; ++i) {
+    a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower);
+  }
+}
+
+template
+void ExpectEQ(const T* target, const T* refer, int n) {
+  if (std::is_floating_point::value) {
+    for (int i = 0; i < n; ++i) {
+      EXPECT_NEAR(target[i], refer[i], 1e-3);
+    }
+  } else {
+    for (int i = 0; i < n; ++i) {
+      EXPECT_EQ(target[i], refer[i]);
+    }
+  }
+}
+
+// TODO(TJ): remove me
+USE_JITKERNEL_MORE(vmul, mkl);
+USE_JITKERNEL_REFER(vmul);
+
+TEST(JitKernel, vmul) {
+  using T = float;
+  using PlaceType = paddle::platform::CPUPlace;
+
+  namespace jit = paddle::operators::jitkernels;
+  // TODO(TJ): test more vector size
+  for (int d = 1; d < 30; ++d) {
+    auto ref = jit::GetRefer();
+    auto tgt = jit::Get(d);
+    EXPECT_TRUE(ref != nullptr);
+    EXPECT_TRUE(tgt != nullptr);
+
+    std::vector x(d), y(d);
+    std::vector zref(d), ztgt(d);
+    RandomVec(d, x.data());
+    RandomVec(d, y.data());
+    const float* x_data = x.data();
+    const float* y_data = y.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+
+    tgt(x_data, y_data, ztgt_data, d);
+    ref(x_data, y_data, zref_data, d);
+    ExpectEQ(ztgt_data, zref_data, d);
+
+    // test inplace x
+    std::copy(x.begin(), x.end(), zref.begin());
+    std::copy(x.begin(), x.end(), ztgt.begin());
+    tgt(ztgt_data, y_data, ztgt_data, d);
+    ref(zref_data, y_data, zref_data, d);
+    ExpectEQ(ztgt_data, zref_data, d);
+
+    // test inplace y
+    std::copy(y.begin(), y.end(), zref.begin());
+    std::copy(y.begin(), y.end(), ztgt.begin());
+    tgt(x_data, ztgt_data, ztgt_data, d);
+    ref(x_data, zref_data, zref_data, d);
+    ExpectEQ(ztgt_data, zref_data, d);
+  }
+}
 
 TEST(JitKernel, pool) {}
-- 
GitLab
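A closing note on the registration flow this patch sets up: each REGISTER_* macro instantiates a JitKernelRegistrar and emits a TouchJitKernelReg_* function, and the matching USE_* macro references that function so the registrar object cannot be dropped by the linker. test.cc already says USE_JITKERNEL_MORE(vmul, mkl), but the MKL implementation itself is not included in this patch; a hypothetical sketch of such a "more" kernel, following the refer pattern (all names and template arguments below are assumptions):

    // more/mkl/mkl.h (hypothetical)
    template <typename T>
    void VMul(const T* x, const T* y, T* z, int n) {
      // a real version would call the MKL/cblas element-wise multiply;
      // a plain loop stands in as a placeholder here
      for (int i = 0; i < n; ++i) {
        z[i] = x[i] * y[i];
      }
    }

    template <typename T>
    class VMulKernel
        : public KernelImpl<T, void (*)(const T*, const T*, T*, int), int> {
     public:
      VMulKernel() { this->func = VMul<T>; }
      // only claim the work when it is likely to beat the jitcode/refer versions
      bool UseMe(int d) const override { return d > 512; }
    };

    // more/mkl/mkl.cc (hypothetical)
    // REGISTER_JITKERNEL_MORE(vmul, mkl, mkl::VMulKernel<float>, mkl::VMulKernel<double>);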