提交 facfecbd 编写于 作者: T tensor-tang

follow comment: reuse time function and change to upper case

test=develop
上级 f5532877
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
结合函数模板和JIT生成需要的kernel函数。 结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。 这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector mul, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。 这里实现的函数可以非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。 目前仅支持CPU上的高性能计算。
## 目录结构 ## 目录结构
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -26,17 +27,10 @@ DEFINE_int32(burning, 10, "Burning times."); ...@@ -26,17 +27,10 @@ DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times."); DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested."); DEFINE_int32(max_size, 1000, "The Max size would be tested.");
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
template <typename T> template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) { const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
static unsigned int seed = 100; std::mt19937 rng(seed);
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1); std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower); a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
...@@ -58,12 +52,12 @@ struct BenchFunc { ...@@ -58,12 +52,12 @@ struct BenchFunc {
for (int i = 0; i < FLAGS_burning; ++i) { for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...); tgt(args...);
} }
auto start = GetCurrentUS(); auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) { for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...); tgt(args...);
} }
auto end = GetCurrentUS(); auto end = paddle::platform::PosixInNsec() / 1e-3;
return (end - start) / FLAGS_repeat; return static_cast<double>(end - start) / FLAGS_repeat;
} }
}; };
......
...@@ -67,7 +67,7 @@ class VActFunc : public JitCode { ...@@ -67,7 +67,7 @@ class VActFunc : public JitCode {
virtual void genCode() = 0; virtual void genCode() = 0;
protected: protected:
// compute relu with ymm, xmm // compute RELU with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx); JMM zero = JMM(zero_idx);
...@@ -75,7 +75,7 @@ class VActFunc : public JitCode { ...@@ -75,7 +75,7 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute EXP with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
...@@ -159,7 +159,7 @@ class VActFunc : public JitCode { ...@@ -159,7 +159,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute sigmoid with ymm, xmm // compute SIGMOID with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -184,7 +184,7 @@ class VActFunc : public JitCode { ...@@ -184,7 +184,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute tanh with ymm, xmm // compute TANH with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -211,7 +211,7 @@ class VActFunc : public JitCode { ...@@ -211,7 +211,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute identity with ymm, xmm // compute IDENTITY with ymm, xmm
template <typename JMM> template <typename JMM>
void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT
JMM zero = JMM(zero_idx); JMM zero = JMM(zero_idx);
...@@ -225,19 +225,19 @@ class VActFunc : public JitCode { ...@@ -225,19 +225,19 @@ class VActFunc : public JitCode {
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15 // use 11~15
switch (type) { switch (type) {
case operand_type::relu: case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15); relu_jmm<JMM>(dst, src, 15);
break; break;
case operand_type::exp: case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::tanh: case operand_type::TANH:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::identity: case operand_type::IDENTITY:
identity_jmm<JMM>(dst, src, 15); identity_jmm<JMM>(dst, src, 15);
break; break;
default: default:
...@@ -252,9 +252,9 @@ class VActJitCode : public VActFunc { ...@@ -252,9 +252,9 @@ class VActJitCode : public VActFunc {
explicit VActJitCode(int d, operand_type type, size_t code_size, explicit VActJitCode(int d, operand_type type, size_t code_size,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), num_(d), type_(type) { : VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::relu || type_ == operand_type::exp || if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::sigmoid || type_ == operand_type::tanh || type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::identity)) { type_ == operand_type::IDENTITY)) {
LOG(FATAL) << "Do not support this operand type: " << type_; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
this->genCode(); this->genCode();
...@@ -263,19 +263,19 @@ class VActJitCode : public VActFunc { ...@@ -263,19 +263,19 @@ class VActJitCode : public VActFunc {
const char* name() const override { const char* name() const override {
std::string base = "VActJitCode"; std::string base = "VActJitCode";
switch (type_) { switch (type_) {
case operand_type::relu: case operand_type::RELU:
base += "_Relu"; base += "_Relu";
break; break;
case operand_type::exp: case operand_type::EXP:
base += "_Exp"; base += "_Exp";
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
base += "_Sigmoid"; base += "_Sigmoid";
break; break;
case operand_type::tanh: case operand_type::TANH:
base += "_Tanh"; base += "_Tanh";
break; break;
case operand_type::identity: case operand_type::IDENTITY:
base += "_Identity"; base += "_Identity";
break; break;
default: default:
...@@ -305,11 +305,11 @@ class VActJitCode : public VActFunc { ...@@ -305,11 +305,11 @@ class VActJitCode : public VActFunc {
: VActJitCode(d, op_type, code_size, code_ptr) {} \ : VActJitCode(d, op_type, code_size, code_ptr) {} \
}; };
DECLARE_ACT_JITCODE(VRelu, operand_type::relu); DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VIdentity, operand_type::identity); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::exp); DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::sigmoid); DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
DECLARE_ACT_JITCODE(VTanh, operand_type::tanh); DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
#undef DECLARE_ACT_JITCODE #undef DECLARE_ACT_JITCODE
......
...@@ -39,9 +39,9 @@ void VXXJitCode::genCode() { ...@@ -39,9 +39,9 @@ void VXXJitCode::genCode() {
if (scalar_index_ != 2) { if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]); vmovups(ymm_src2, ptr[param2 + offset]);
} }
if (type_ == operand_type::mul) { if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2); vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2); vaddps(ymm_dst, ymm_src1, ymm_src2);
} }
if (with_relu_) { if (with_relu_) {
...@@ -79,10 +79,10 @@ void VXXJitCode::genCode() { ...@@ -79,10 +79,10 @@ void VXXJitCode::genCode() {
} }
} }
switch (type_) { switch (type_) {
case operand_type::mul: case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2); vmulps(xmm_dst, xmm_src1, xmm_src2);
break; break;
case operand_type::add: case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2); vaddps(xmm_dst, xmm_src1, xmm_src2);
break; break;
default: default:
......
...@@ -34,7 +34,7 @@ class VXXJitCode : public JitCode { ...@@ -34,7 +34,7 @@ class VXXJitCode : public JitCode {
type_(type), type_(type),
scalar_index_(scalar_index), scalar_index_(scalar_index),
with_relu_(with_relu) { with_relu_(with_relu) {
if (!(type_ == operand_type::mul || type_ == operand_type::add)) { if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
this->genCode(); this->genCode();
...@@ -47,9 +47,9 @@ class VXXJitCode : public JitCode { ...@@ -47,9 +47,9 @@ class VXXJitCode : public JitCode {
} else { } else {
base += "_Vec"; base += "_Vec";
} }
if (type_ == operand_type::mul) { if (type_ == operand_type::MUL) {
base += "_Mul"; base += "_Mul";
} else if (type_ == operand_type::add) { } else if (type_ == operand_type::ADD) {
base += "_Add"; base += "_Add";
} }
if (scalar_index_ == 2) { if (scalar_index_ == 2) {
...@@ -90,12 +90,12 @@ class VXXJitCode : public JitCode { ...@@ -90,12 +90,12 @@ class VXXJitCode : public JitCode {
} \ } \
}; };
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false); DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false); DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false); DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true); DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false); DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false); DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE #undef DECLARE_BLAS_JITCODE
......
...@@ -31,17 +31,17 @@ class GRUJitCode : public VActFunc { ...@@ -31,17 +31,17 @@ class GRUJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) { : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type { auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) { if (type == KernelType::vsigmoid) {
return operand_type::sigmoid; return operand_type::SIGMOID;
} else if (type == KernelType::vrelu) { } else if (type == KernelType::vrelu) {
return operand_type::relu; return operand_type::RELU;
} else if (type == KernelType::vtanh) { } else if (type == KernelType::vtanh) {
return operand_type::tanh; return operand_type::TANH;
} else if (type == KernelType::videntity) { } else if (type == KernelType::videntity) {
return operand_type::identity; return operand_type::IDENTITY;
} else { } else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type; LOG(FATAL) << "Do not support this jit::KernelType: " << type;
} }
return operand_type::identity; return operand_type::IDENTITY;
}; };
act_gate_ = typeExchange(attr.act_gate); act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand); act_cand_ = typeExchange(attr.act_cand);
...@@ -60,19 +60,19 @@ class GRUJitCode : public VActFunc { ...@@ -60,19 +60,19 @@ class GRUJitCode : public VActFunc {
} }
auto AddTypeStr = [&](operand_type type) { auto AddTypeStr = [&](operand_type type) {
switch (type) { switch (type) {
case operand_type::relu: case operand_type::RELU:
base += "_Relu"; base += "_Relu";
break; break;
case operand_type::exp: case operand_type::EXP:
base += "_Exp"; base += "_Exp";
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
base += "_Sigmoid"; base += "_Sigmoid";
break; break;
case operand_type::tanh: case operand_type::TANH:
base += "_Tanh"; base += "_Tanh";
break; break;
case operand_type::identity: case operand_type::IDENTITY:
base += "_Identity"; base += "_Identity";
break; break;
default: default:
......
...@@ -46,14 +46,14 @@ using zmm_t = const Xbyak::Zmm; ...@@ -46,14 +46,14 @@ using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label; using Label = Xbyak::Label;
typedef enum { typedef enum {
mul = 0, MUL = 0,
add, ADD,
sub, SUB,
relu, RELU,
exp, EXP,
sigmoid, SIGMOID,
tanh, TANH,
identity IDENTITY
} operand_type; } operand_type;
#define DECLARE_JIT_CODE(codename) \ #define DECLARE_JIT_CODE(codename) \
......
...@@ -34,17 +34,17 @@ class LSTMJitCode : public VActFunc { ...@@ -34,17 +34,17 @@ class LSTMJitCode : public VActFunc {
use_peephole_(attr.use_peephole) { use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type { auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) { if (type == KernelType::vsigmoid) {
return operand_type::sigmoid; return operand_type::SIGMOID;
} else if (type == KernelType::vrelu) { } else if (type == KernelType::vrelu) {
return operand_type::relu; return operand_type::RELU;
} else if (type == KernelType::vtanh) { } else if (type == KernelType::vtanh) {
return operand_type::tanh; return operand_type::TANH;
} else if (type == KernelType::videntity) { } else if (type == KernelType::videntity) {
return operand_type::identity; return operand_type::IDENTITY;
} else { } else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type; LOG(FATAL) << "Do not support this jit::KernelType: " << type;
} }
return operand_type::identity; return operand_type::IDENTITY;
}; };
act_gate_ = typeExchange(attr.act_gate); act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand); act_cand_ = typeExchange(attr.act_cand);
...@@ -63,19 +63,19 @@ class LSTMJitCode : public VActFunc { ...@@ -63,19 +63,19 @@ class LSTMJitCode : public VActFunc {
} }
auto AddTypeStr = [&](operand_type type) { auto AddTypeStr = [&](operand_type type) {
switch (type) { switch (type) {
case operand_type::relu: case operand_type::RELU:
base += "_Relu"; base += "_Relu";
break; break;
case operand_type::exp: case operand_type::EXP:
base += "_Exp"; base += "_Exp";
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
base += "_Sigmoid"; base += "_Sigmoid";
break; break;
case operand_type::tanh: case operand_type::TANH:
base += "_Tanh"; base += "_Tanh";
break; break;
case operand_type::identity: case operand_type::IDENTITY:
base += "_Identity"; base += "_Identity";
break; break;
default: default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册