提交 417d031f 编写于 作者: T tensor-tang

add refer vadd, vaddrelu, vsub and tests and benchmark

上级 f3250097
......@@ -41,6 +41,6 @@ PaddlePaddle/Paddle/paddle/fluid/
- 性能测试
# 如何添加新的算子
TBD
## Use me
Add `USE_JIT_KERNEL(yourname)` to CMakeLists.txt.
- `KernelType` 中添加 `your_key`
- 实现 Reference 的逻辑,每个 jitkernel 的 Reference 实现是必须的。不要依赖任何第三方库。并在 `refer/CMakeLists.txt` 中添加 `USE_JITKERNEL_REFER(your_key)`。
......@@ -52,9 +52,10 @@ std::vector<int> TestSizes() {
}
// return this function avg time
template <typename T, typename Func>
double BenchTartgetFunc(const Func tgt, const std::vector<T>& x,
const std::vector<T>& y, std::vector<T>& z) { // NOLINT
template <typename T, typename KernelTuples>
double BenchTartgetFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& x, const std::vector<T>& y,
std::vector<T>& z) { // NOLINT
const T* x_data = x.data();
const T* y_data = y.data();
const int d = z.size();
......@@ -71,40 +72,25 @@ double BenchTartgetFunc(const Func tgt, const std::vector<T>& x,
return (end - start) / FLAGS_repeat;
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
using T = float;
using PlaceType = paddle::platform::CPUPlace;
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
namespace jit = paddle::operators::jit;
const auto KT = jit::vmul;
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
for (int d : TestSizes()) {
// for (kernels type) { // TODO(TJ): more jit::KernelType
std::vector<std::pair<std::string, double>> infos;
std::vector<T> x(d), y(d), z(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
// refer
auto refer = jit::GetRefer<KT, jit::VMulTuples<T>>();
auto refer = jit::GetRefer<KT, jit::XYZNTuples<T>>();
if (refer) {
auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(refer, x, y, z);
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(refer, x, y, z);
infos.push_back(std::make_pair("Refer", res));
}
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
if (jitcode) {
auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, z);
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, z);
infos.push_back(std::make_pair("JitCode", res));
}
......@@ -115,32 +101,50 @@ int main(int argc, char* argv[]) {
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
auto i = dynamic_cast<const jit::KernelImpl<jit::XYZNTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
auto res =
BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(more, x, y, z);
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, z);
infos.push_back(std::make_pair("More", res));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
if (!tgt) {
LOG(ERROR) << "Target can not be empty!";
}
auto res = BenchTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, z);
auto res = BenchTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, z);
infos.push_back(std::make_pair("Target", res));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type: " << KT << ", size " << d << ": ";
loginfos << "Kernel Type: " << jit::to_string(KT) << ", size " << d << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
// }
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
// Entry point: parse command-line flags, initialize logging, then benchmark
// every element-wise (x, y) -> z kernel type with float data on CPU.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
// Report the configuration controlled by --burning and --repeat.
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
// NOTE: only float on CPUPlace is benchmarked here.
using T = float;
using PlaceType = paddle::platform::CPUPlace;
namespace jit = paddle::operators::jit;
// Benchmark each XYZN-shaped kernel (refer, jitcode, more impls, and Get).
BenchXYZNKernel<jit::vmul, T, PlaceType>();
BenchXYZNKernel<jit::vadd, T, PlaceType>();
BenchXYZNKernel<jit::vaddrelu, T, PlaceType>();
BenchXYZNKernel<jit::vsub, T, PlaceType>();
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
namespace paddle {
namespace operators {
namespace jit {
// Human-readable name of a KernelType, used in logs and benchmark output.
// Returns "NOT JITKernel" for any type without a registered name.
// (The previous version had an unreachable `return nullptr;` after the
// switch — every path already returns through a case or `default`.)
const char* to_string(KernelType kt) {
  switch (kt) {
    case vmul:
      return "vmul";
    case vadd:
      return "vadd";
    case vaddrelu:
      return "vaddrelu";
    case vsub:
      return "vsub";
    case vscal:
      return "vscal";
    case vexp:
      return "vexp";
    default:
      return "NOT JITKernel";
  }
}
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -112,6 +112,8 @@ typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
return GetRefer<KT, KernelTuples>();
}
const char* to_string(KernelType kt);
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -19,10 +19,10 @@ namespace paddle {
namespace operators {
namespace jit {
typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
typedef enum { vmul = 0, vadd = 1, vaddrelu, vsub, vscal, vexp } KernelType;
template <typename T>
struct VMulTuples {
struct XYZNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
......
......@@ -28,7 +28,7 @@ template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
class VMulKernel : public KernelImpl<VMulTuples<T>> {
class VMulKernel : public KernelImpl<XYZNTuples<T>> {
public:
VMulKernel() { this->func = VMul<T>; }
bool UseMe(int d) const override {
......
......@@ -17,5 +17,13 @@
namespace refer = paddle::operators::jit::refer;
REGISTER_JITKERNEL_REFER(vmul, refer::VMulKernel<float>,
refer::VMulKernel<double>);
// Helper macro: register both the float and double refer (reference) kernels
// for one KernelType key. `func` is the kernel class-name prefix, e.g.
// VAdd expands to refer::VAddKernel<float> and refer::VAddKernel<double>.
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
// Register the element-wise reference kernels of shape z = f(x, y).
REGISTER_REFER_KERNEL(vmul, VMul);
REGISTER_REFER_KERNEL(vadd, VAdd);
REGISTER_REFER_KERNEL(vaddrelu, VAddRelu);
REGISTER_REFER_KERNEL(vsub, VSub);
// Keep the macro local to this translation unit.
#undef REGISTER_REFER_KERNEL
......@@ -13,6 +13,7 @@
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -21,6 +22,7 @@ namespace operators {
namespace jit {
namespace refer {
// Refer code only focus on correctness
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
......@@ -29,10 +31,47 @@ void VMul(const T* x, const T* y, T* z, int n) {
}
// Reference-kernel wrapper for element-wise multiply: binds the scalar
// VMul<T> function as this kernel's callable.
template <typename T>
class VMulKernel : public ReferKernel<VMulTuples<T>> {
public:
VMulKernel() { this->func = VMul<T>; }
};
// Element-wise addition: z[i] = x[i] + y[i] for i in [0, n).
// NOTE: the `template <typename T>` header was detached from this function
// in the previous revision (it bound to the class above); restored here so
// the definition is well-formed.
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
  for (int i = 0; i < n; ++i) {
    z[i] = x[i] + y[i];
  }
}
// Fused add + ReLU: z[i] = max(x[i] + y[i], 0) for i in [0, n).
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
  for (int k = 0; k < n; ++k) {
    const T sum = x[k] + y[k];
    z[k] = (sum > 0) ? sum : static_cast<T>(0);
  }
}
// Element-wise subtraction: z[i] = x[i] - y[i] for i in [0, n).
template <typename T>
void VSub(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    z[k] = x[k] - y[k];
    ++k;
  }
}
// Scale by a scalar: y[i] = a[0] * x[i] for i in [0, n).
// a[0] is re-read each iteration, matching the reference semantics even
// if the output aliases the scalar.
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
  int idx = 0;
  while (idx < n) {
    y[idx] = a[0] * x[idx];
    ++idx;
  }
}
// Helper macro: declare a ReferKernel subclass `name##Kernel` whose ctor
// binds the scalar reference function `name<T>` via the `tuples<T>` traits.
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// Declare the element-wise kernels sharing the XYZN signature
// (const T* x, const T* y, T* z, int n).
DECLARE_REFER_KERNEL(VMul, XYZNTuples);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
DECLARE_REFER_KERNEL(VSub, XYZNTuples);
// Keep the macro local to this header.
#undef DECLARE_REFER_KERNEL
} // namespace refer
} // namespace jit
......
......@@ -48,18 +48,20 @@ void ExpectEQ(const T* target, const T* refer, int n) {
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i < 30; ++i) {
for (int i = 1; i < 10; ++i) {
s.push_back(i);
}
// test some large size
s.push_back(100);
s.push_back(1000);
// // test some large size
// s.push_back(100);
// s.push_back(1000);
// s.push_back(2000);
return s;
}
template <typename T, typename Func>
void TestTartgetFunc(const Func tgt, const std::vector<T>& x,
const std::vector<T>& y, const std::vector<T>& zref) {
template <typename T, typename KernelTuples>
void TestTartgetFunc(const typename KernelTuples::func_type tgt,
const std::vector<T>& x, const std::vector<T>& y,
const std::vector<T>& zref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(zref.size(), x.size());
EXPECT_EQ(zref.size(), y.size());
......@@ -83,13 +85,13 @@ void TestTartgetFunc(const Func tgt, const std::vector<T>& x,
ExpectEQ<T>(ztgt_data, zref_data, d);
}
TEST(JitKernel, vmul) {
using T = float;
using PlaceType = paddle::platform::CPUPlace;
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
namespace jit = paddle::operators::jit;
const auto KT = jit::vmul;
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::VMulTuples<T>>();
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT)
<< ", size: " << d;
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d);
......@@ -114,10 +116,10 @@ TEST(JitKernel, vmul) {
ExpectEQ<T>(yinp_data, zref_data, d);
// test jitcode
auto jitcode = jit::GetJitCode<KT, jit::VMulTuples<T>, PlaceType>(d);
auto jitcode = jit::GetJitCode<KT, jit::XYZNTuples<T>, PlaceType>(d);
if (jitcode) {
VLOG(10) << "Test jitcode, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(jitcode, x, y, zref);
VLOG(10) << "Test Jitcode Kernel, size: " << d;
TestTartgetFunc<T, jit::XYZNTuples<T>>(jitcode, x, y, zref);
}
// test all impls in more
......@@ -127,20 +129,45 @@ TEST(JitKernel, vmul) {
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelImpl<jit::VMulTuples<T>>*>(
auto i = dynamic_cast<const jit::KernelImpl<jit::XYZNTuples<T>>*>(
impl.get());
if (i && i->UseMe(d)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel, size: " << d;
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(more, x, y, zref);
TestTartgetFunc<T, jit::XYZNTuples<T>>(more, x, y, zref);
}
}
}
// Test result from Get function
VLOG(10) << "Test Get function, size: " << d;
auto tgt = jit::Get<KT, jit::VMulTuples<T>, PlaceType>(d);
TestTartgetFunc<T, jit::VMulTuples<T>::func_type>(tgt, x, y, zref);
auto tgt = jit::Get<KT, jit::XYZNTuples<T>, PlaceType>(d);
TestTartgetFunc<T, jit::XYZNTuples<T>>(tgt, x, y, zref);
}
}
TEST(JitKernel, pool) {}
// End-to-end check of the vmul kernel (refer, jitcode, more, Get) on CPU.
TEST(JITKernel, vmul) {
  using CPUPlace = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vmul, float, CPUPlace>();
  // TODO(TJ): fix double issue
  // TestXYZNKernel<paddle::operators::jit::vmul, double, CPUPlace>();
}
// End-to-end check of the vadd kernel for both float and double on CPU.
TEST(JITKernel, vadd) {
  using CPUPlace = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vadd, float, CPUPlace>();
  TestXYZNKernel<paddle::operators::jit::vadd, double, CPUPlace>();
}
// End-to-end check of the fused vaddrelu kernel for float and double on CPU.
TEST(JITKernel, vaddrelu) {
  using CPUPlace = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vaddrelu, float, CPUPlace>();
  TestXYZNKernel<paddle::operators::jit::vaddrelu, double, CPUPlace>();
}
// End-to-end check of the vsub kernel for both float and double on CPU.
TEST(JITKernel, vsub) {
  using CPUPlace = paddle::platform::CPUPlace;
  TestXYZNKernel<paddle::operators::jit::vsub, float, CPUPlace>();
  TestXYZNKernel<paddle::operators::jit::vsub, double, CPUPlace>();
}
// Placeholder: kernel-pool tests not yet implemented.
TEST(JITKernel, pool) {}
......@@ -23,36 +23,6 @@ namespace operators {
namespace math {
namespace jitkernel {
namespace refer {
/* Refer code only focus on correctness */
// Element-wise multiply: z[i] = x[i] * y[i] for i in [0, n).
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
  for (int k = 0; k < n; ++k) {
    z[k] = x[k] * y[k];
  }
}
// Element-wise addition: z[i] = x[i] + y[i] for i in [0, n).
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    z[k] = x[k] + y[k];
    ++k;
  }
}
// Fused add + ReLU: z[i] = max(x[i] + y[i], 0) for i in [0, n).
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
  for (int k = 0; k < n; ++k) {
    const T s = x[k] + y[k];
    z[k] = (s > 0) ? s : static_cast<T>(0);
  }
}
// Scale by a scalar: y[i] = a[0] * x[i] for i in [0, n).
// a[0] is re-read each iteration to preserve the reference aliasing behavior.
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    y[k] = a[0] * x[k];
    ++k;
  }
}
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册