提交 bc0df6a9 编写于 作者: T tensor-tang

make typename tuples

上级 194ce2e9
......@@ -94,8 +94,7 @@ int main(int argc, char* argv[]) {
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
// refer
auto refer = jit::GetRefer<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>();
auto refer = jit::GetRefer<KT, jit::VMulTuples<T>>();
if (refer) {
auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(refer, x, y, z);
......@@ -103,8 +102,7 @@ int main(int argc, char* argv[]) {
}
// test jitcode
auto jitcode = jit::GetJitCode<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
if (jitcode) {
auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, z);
......@@ -118,9 +116,7 @@ int main(int argc, char* argv[]) {
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>*>(
auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
......@@ -132,8 +128,7 @@ int main(int argc, char* argv[]) {
}
// Test result from Get function
auto tgt = jit::Get<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
if (!tgt) {
LOG(ERROR) << "Target can not be empty!";
}
......
......@@ -32,9 +32,11 @@ namespace jit {
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType>
inline Func GetJitCode(Attr attr) {
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename KernelTuples::func_type GetJitCode(
typename KernelTuples::attr_type attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
......@@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) {
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
......@@ -74,7 +76,7 @@ inline Func GetRefer() {
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) {
return i->GetFunc();
}
......@@ -82,10 +84,10 @@ inline Func GetRefer() {
return nullptr;
}
template <KernelType KT, typename T, typename Func, typename Attr,
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
Func Get(Attr attr) {
auto jitfunc = GetJitCode<KT, T, Func, Attr, PlaceType>(attr);
typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
......@@ -97,7 +99,7 @@ Func Get(Attr attr) {
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
auto i = dynamic_cast<const KernelImpl<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
......@@ -105,7 +107,7 @@ Func Get(Attr attr) {
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, T, Func, Attr>();
return GetRefer<KT, KernelTuples>();
}
} // namespace jit
......
......@@ -36,10 +36,13 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename T, typename Func, typename Attr>
template <typename KernelTuples>
class KernelImpl : public Kernel {
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
public:
using ELEMENT_TYPE = T;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(Attr attr) const = 0;
......@@ -47,11 +50,13 @@ class KernelImpl : public Kernel {
Func func{nullptr};
};
template <typename T, typename Func, typename Attr>
class ReferKernel : public KernelImpl<T, Func, Attr> {
template <typename KernelTuples>
class ReferKernel : public KernelImpl<KernelTuples> {
public:
// Refer code can always be used
bool UseMe(Attr attr) const override { return true; }
bool UseMe(typename KernelTuples::attr_type attr) const override {
return true;
}
};
} // namespace jit
......
......@@ -28,8 +28,7 @@ template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
class VMulKernel : public KernelImpl<T, typename VMulTuples<T>::func_type,
typename VMulTuples<T>::attr_type> {
class VMulKernel : public KernelImpl<VMulTuples<T>> {
public:
VMulKernel() { this->func = VMul<T>; }
bool UseMe(int d) const override {
......
......@@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) {
}
template <typename T>
class VMulKernel : public ReferKernel<T, typename VMulTuples<T>::func_type,
typename VMulTuples<T>::attr_type> {
class VMulKernel : public ReferKernel<VMulTuples<T>> {
public:
VMulKernel() { this->func = VMul<T>; }
};
......
......@@ -89,8 +89,7 @@ TEST(JitKernel, vmul) {
namespace jit = paddle::operators::jit;
const auto KT = jit::vmul;
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>();
auto ref = jit::GetRefer<KT, jit::VMulTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d);
......@@ -115,8 +114,7 @@ TEST(JitKernel, vmul) {
ExpectEQ<T>(yinp_data, zref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test jitcode, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, zref);
......@@ -129,9 +127,7 @@ TEST(JitKernel, vmul) {
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>*>(
auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
......@@ -142,8 +138,7 @@ TEST(JitKernel, vmul) {
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, zref);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册