提交 bc0df6a9 编写于 作者: T tensor-tang

make typename tuples

上级 194ce2e9
...@@ -94,8 +94,7 @@ int main(int argc, char* argv[]) { ...@@ -94,8 +94,7 @@ int main(int argc, char* argv[]) {
RandomVec<T>(d, x.data()); RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data()); RandomVec<T>(d, y.data());
// refer // refer
auto refer = jit::GetRefer<KT, T, jit::VMulTuples<T>::func_type, auto refer = jit::GetRefer<KT, jit::VMulTuples<T>>();
jit::VMulTuples<T>::attr_type>();
if (refer) { if (refer) {
auto res = auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(refer, x, y, z); BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(refer, x, y, z);
...@@ -103,8 +102,7 @@ int main(int argc, char* argv[]) { ...@@ -103,8 +102,7 @@ int main(int argc, char* argv[]) {
} }
// test jitcode // test jitcode
auto jitcode = jit::GetJitCode<KT, T, jit::VMulTuples<T>::func_type, auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
jit::VMulTuples<T>::attr_type, PlaceType>(d);
if (jitcode) { if (jitcode) {
auto res = auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, z); BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, z);
...@@ -118,10 +116,8 @@ int main(int argc, char* argv[]) { ...@@ -118,10 +116,8 @@ int main(int argc, char* argv[]) {
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
dynamic_cast<const jit::KernelImpl<T, jit::VMulTuples<T>::func_type, impl.get());
jit::VMulTuples<T>::attr_type>*>(
impl.get());
if (i && i->UseMe(d)) { if (i && i->UseMe(d)) {
auto more = i->GetFunc(); auto more = i->GetFunc();
auto res = auto res =
...@@ -132,8 +128,7 @@ int main(int argc, char* argv[]) { ...@@ -132,8 +128,7 @@ int main(int argc, char* argv[]) {
} }
// Test result from Get function // Test result from Get function
auto tgt = jit::Get<KT, T, jit::VMulTuples<T>::func_type, auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
jit::VMulTuples<T>::attr_type, PlaceType>(d);
if (!tgt) { if (!tgt) {
LOG(ERROR) << "Target can not be empty!"; LOG(ERROR) << "Target can not be empty!";
} }
......
...@@ -32,9 +32,11 @@ namespace jit { ...@@ -32,9 +32,11 @@ namespace jit {
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
template <KernelType KT, typename T, typename Func, typename Attr, template <KernelType KT, typename KernelTuples, typename PlaceType>
typename PlaceType> inline typename KernelTuples::func_type GetJitCode(
inline Func GetJitCode(Attr attr) { typename KernelTuples::attr_type attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr); size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance(); auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
...@@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) { ...@@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) {
// Refer code do not related with attr, which is just for cast // Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr> template <KernelType KT, typename KernelTuples>
inline Func GetRefer() { inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels(); auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace()); KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
...@@ -74,7 +76,7 @@ inline Func GetRefer() { ...@@ -74,7 +76,7 @@ inline Func GetRefer() {
"Every Kernel should have reference function."); "Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) { if (i) {
return i->GetFunc(); return i->GetFunc();
} }
...@@ -82,10 +84,10 @@ inline Func GetRefer() { ...@@ -82,10 +84,10 @@ inline Func GetRefer() {
return nullptr; return nullptr;
} }
template <KernelType KT, typename T, typename Func, typename Attr, template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace> typename PlaceType = platform::CPUPlace>
Func Get(Attr attr) { typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
auto jitfunc = GetJitCode<KT, T, Func, Attr, PlaceType>(attr); auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) { if (jitfunc) {
return jitfunc; return jitfunc;
} }
...@@ -97,7 +99,7 @@ Func Get(Attr attr) { ...@@ -97,7 +99,7 @@ Func Get(Attr attr) {
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get()); auto i = dynamic_cast<const KernelImpl<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->UseMe(attr)) {
return i->GetFunc(); return i->GetFunc();
} }
...@@ -105,7 +107,7 @@ Func Get(Attr attr) { ...@@ -105,7 +107,7 @@ Func Get(Attr attr) {
} }
// The last implementation should be reference function on CPUPlace. // The last implementation should be reference function on CPUPlace.
return GetRefer<KT, T, Func, Attr>(); return GetRefer<KT, KernelTuples>();
} }
} // namespace jit } // namespace jit
......
...@@ -36,10 +36,13 @@ class Kernel { ...@@ -36,10 +36,13 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
template <typename T, typename Func, typename Attr> template <typename KernelTuples>
class KernelImpl : public Kernel { class KernelImpl : public Kernel {
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
public: public:
using ELEMENT_TYPE = T;
virtual Func GetFunc() const { return func; } virtual Func GetFunc() const { return func; }
virtual bool UseMe(Attr attr) const = 0; virtual bool UseMe(Attr attr) const = 0;
...@@ -47,11 +50,13 @@ class KernelImpl : public Kernel { ...@@ -47,11 +50,13 @@ class KernelImpl : public Kernel {
Func func{nullptr}; Func func{nullptr};
}; };
template <typename T, typename Func, typename Attr> template <typename KernelTuples>
class ReferKernel : public KernelImpl<T, Func, Attr> { class ReferKernel : public KernelImpl<KernelTuples> {
public: public:
// Refer code can always be used // Refer code can always be used
bool UseMe(Attr attr) const override { return true; } bool UseMe(typename KernelTuples::attr_type attr) const override {
return true;
}
}; };
} // namespace jit } // namespace jit
......
...@@ -28,8 +28,7 @@ template <typename T> ...@@ -28,8 +28,7 @@ template <typename T>
void VMul(const T* x, const T* y, T* z, int n); void VMul(const T* x, const T* y, T* z, int n);
template <typename T> template <typename T>
class VMulKernel : public KernelImpl<T, typename VMulTuples<T>::func_type, class VMulKernel : public KernelImpl<VMulTuples<T>> {
typename VMulTuples<T>::attr_type> {
public: public:
VMulKernel() { this->func = VMul<T>; } VMulKernel() { this->func = VMul<T>; }
bool UseMe(int d) const override { bool UseMe(int d) const override {
......
...@@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) { ...@@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) {
} }
template <typename T> template <typename T>
class VMulKernel : public ReferKernel<T, typename VMulTuples<T>::func_type, class VMulKernel : public ReferKernel<VMulTuples<T>> {
typename VMulTuples<T>::attr_type> {
public: public:
VMulKernel() { this->func = VMul<T>; } VMulKernel() { this->func = VMul<T>; }
}; };
......
...@@ -89,8 +89,7 @@ TEST(JitKernel, vmul) { ...@@ -89,8 +89,7 @@ TEST(JitKernel, vmul) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
const auto KT = jit::vmul; const auto KT = jit::vmul;
for (int d : TestSizes()) { for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, T, jit::VMulTuples<T>::func_type, auto ref = jit::GetRefer<KT, jit::VMulTuples<T>>();
jit::VMulTuples<T>::attr_type>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d); std::vector<T> x(d), y(d), zref(d);
...@@ -115,8 +114,7 @@ TEST(JitKernel, vmul) { ...@@ -115,8 +114,7 @@ TEST(JitKernel, vmul) {
ExpectEQ<T>(yinp_data, zref_data, d); ExpectEQ<T>(yinp_data, zref_data, d);
// test jitcode // test jitcode
auto jitcode = jit::GetJitCode<KT, T, jit::VMulTuples<T>::func_type, auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
jit::VMulTuples<T>::attr_type, PlaceType>(d);
if (jitcode) { if (jitcode) {
VLOG(10) << "Test jitcode, size: " << d; VLOG(10) << "Test jitcode, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, zref); TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, zref);
...@@ -129,10 +127,8 @@ TEST(JitKernel, vmul) { ...@@ -129,10 +127,8 @@ TEST(JitKernel, vmul) {
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
dynamic_cast<const jit::KernelImpl<T, jit::VMulTuples<T>::func_type, impl.get());
jit::VMulTuples<T>::attr_type>*>(
impl.get());
if (i && i->UseMe(d)) { if (i && i->UseMe(d)) {
auto more = i->GetFunc(); auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d; VLOG(10) << "Test More Kernel, size: " << d;
...@@ -142,8 +138,7 @@ TEST(JitKernel, vmul) { ...@@ -142,8 +138,7 @@ TEST(JitKernel, vmul) {
} }
// Test result from Get function // Test result from Get function
VLOG(10) << "Test Get function, size: " << d; VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, T, jit::VMulTuples<T>::func_type, auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
jit::VMulTuples<T>::attr_type, PlaceType>(d);
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, zref); TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, zref);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册