/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include // for memcpy #include #include #include #include "gflags/gflags.h" #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/platform/place.h" template void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), const T upper = static_cast(20.f)) { static unsigned int seed = 100; std::mt19937 rng(seed++); std::uniform_real_distribution uniform_dist(0, 1); for (int i = 0; i < n; ++i) { a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } } template void ExpectEQ(const T* target, const T* refer, int n) { if (std::is_floating_point::value) { for (int i = 0; i < n; ++i) { EXPECT_NEAR(target[i], refer[i], 1e-3); } } else { for (int i = 0; i < n; ++i) { EXPECT_EQ(target[i], refer[i]); } } } std::vector TestSizes() { std::vector s; for (int i = 1; i < 10; ++i) { s.push_back(i); } // // test some large size // s.push_back(100); // s.push_back(1000); // s.push_back(2000); return s; } template void TestTartgetFunc(const typename KernelTuples::func_type tgt, const std::vector& x, const std::vector& y, const std::vector& zref) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(zref.size(), x.size()); EXPECT_EQ(zref.size(), y.size()); const T* x_data = x.data(); const T* y_data = y.data(); const T* zref_data = zref.data(); const int d = zref.size(); std::vector ztgt(d); T* ztgt_data = ztgt.data(); // test normal tgt(x_data, y_data, ztgt_data, d); ExpectEQ(ztgt_data, zref_data, d); // test inplace x std::copy(x.begin(), x.end(), ztgt.begin()); tgt(ztgt_data, y_data, ztgt_data, d); ExpectEQ(ztgt_data, zref_data, d); // test inplace y std::copy(y.begin(), y.end(), ztgt.begin()); tgt(x_data, ztgt_data, ztgt_data, d); ExpectEQ(ztgt_data, zref_data, d); } template void TestXYZNKernel() { namespace jit = paddle::operators::jit; for (int d : TestSizes()) { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT) << ", size: " << d; auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); std::vector x(d), y(d), zref(d); RandomVec(d, x.data()); RandomVec(d, y.data()); std::vector xinp(d), yinp(d); // inplace test std::copy(x.begin(), x.end(), xinp.begin()); std::copy(y.begin(), y.end(), yinp.begin()); const T* x_data = x.data(); const T* y_data = y.data(); T* zref_data = zref.data(); T* xinp_data = xinp.data(); T* yinp_data = yinp.data(); // test refer code inplace ref(x_data, y_data, zref_data, d); ref(x_data, yinp_data, yinp_data, d); ref(xinp_data, y_data, xinp_data, d); ExpectEQ(xinp_data, zref_data, d); ExpectEQ(yinp_data, zref_data, d); // test jitcode auto jitcode = jit::GetJitCode, PlaceType>(d); if (jitcode) { VLOG(10) << "Test Jitcode Kernel, size: " << d; TestTartgetFunc>(jitcode, x, y, zref); } // test all impls in more jit::KernelKey kkey(KT, PlaceType()); auto& pool = jit::KernelPool().Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { auto i = dynamic_cast>*>( impl.get()); if (i && i->UseMe(d)) { auto more = i->GetFunc(); VLOG(10) << "Test More Kernel, size: " << d; TestTartgetFunc>(more, x, y, zref); } } } // Test result from Get function VLOG(10) << "Test Get function, size: " << d; auto tgt = jit::Get, PlaceType>(d); TestTartgetFunc>(tgt, x, y, zref); } } TEST(JITKernel, vmul) { namespace jit = paddle::operators::jit; TestXYZNKernel(); // TODO(TJ): fix double issue // TestXYZNKernel(); } TEST(JITKernel, vadd) { namespace jit = paddle::operators::jit; TestXYZNKernel(); TestXYZNKernel(); } TEST(JITKernel, vaddrelu) { namespace jit = paddle::operators::jit; TestXYZNKernel(); TestXYZNKernel(); } TEST(JITKernel, vsub) { namespace jit = paddle::operators::jit; TestXYZNKernel(); TestXYZNKernel(); } TEST(JITKernel, pool) {}