提交 28eb7d84 编写于 作者: T tensor-tang

test all impls and all inplace cases

上级 d4cab7d9
......@@ -28,33 +28,16 @@ namespace paddle {
namespace operators {
namespace jit {
// Refer code do not related with attr, and always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) {
typename PlaceType>
inline const Func GetJitCode(Attr attr) {
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
......@@ -73,8 +56,38 @@ const Func Get(Attr attr) {
}
}
}
return nullptr;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) {
auto jitfunc = GetJitCode<KT, T, Func, Attr, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
......
......@@ -55,46 +55,105 @@ void ExpectEQ(const T* target, const T* refer, int n) {
}
}
TEST(JitKernel, vmul) {
using T = float;
using PlaceType = paddle::platform::CPUPlace;
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i < 30; ++i) {
s.push_back(i);
}
// test some large size
s.push_back(100);
s.push_back(1000);
return s;
}
namespace jit = paddle::operators::jit;
// TODO(TJ): test more vector size
for (int d = 1; d < 30; ++d) {
auto ref = jit::GetRefer<jit::vmul, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>();
auto tgt = jit::Get<jit::vmul, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
EXPECT_TRUE(ref != nullptr);
template <typename T, typename Func>
void TestTartgetFunc(const Func tgt, const std::vector<T>& x,
const std::vector<T>& y, const std::vector<T>& zref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(zref.size(), x.size());
EXPECT_EQ(zref.size(), y.size());
const T* x_data = x.data();
const T* y_data = y.data();
const T* zref_data = zref.data();
const int d = zref.size();
std::vector<T> x(d), y(d);
std::vector<T> zref(d), ztgt(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
std::vector<T> ztgt(d);
T* ztgt_data = ztgt.data();
// test normal
tgt(x_data, y_data, ztgt_data, d);
ref(x_data, y_data, zref_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), zref.begin());
std::copy(x.begin(), x.end(), ztgt.begin());
tgt(ztgt_data, y_data, ztgt_data, d);
ref(zref_data, y_data, zref_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
// test inplace y
std::copy(y.begin(), y.end(), zref.begin());
std::copy(y.begin(), y.end(), ztgt.begin());
tgt(x_data, ztgt_data, ztgt_data, d);
ref(x_data, zref_data, zref_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
}
TEST(JitKernel, vmul) {
using T = float;
using PlaceType = paddle::platform::CPUPlace;
namespace jit = paddle::operators::jit;
const auto KT = jit::vmul;
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
std::vector<T> xinp(d), yinp(d); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
std::copy(y.begin(), y.end(), yinp.begin());
const T* x_data = x.data();
const T* y_data = y.data();
T* zref_data = zref.data();
T* xinp_data = xinp.data();
T* yinp_data = yinp.data();
// test refer code inplace
ref(x_data, y_data, zref_data, d);
ref(x_data, yinp_data, yinp_data, d);
ref(xinp_data, y_data, xinp_data, d);
ExpectEQ<T>(xinp_data, zref_data, d);
ExpectEQ<T>(yinp_data, zref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test jitcode, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, zref);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i =
dynamic_cast<const jit::KernelImpl<T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(more, x, y, zref);
}
}
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, zref);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册