提交 facfecbd 编写于 作者: T tensor-tang

follow comment: reuse time function and change to upper case

test=develop
上级 f5532877
......@@ -2,7 +2,7 @@
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector mul, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
这里实现的函数可以非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
## 目录结构
......
......@@ -19,6 +19,7 @@
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
......@@ -26,17 +27,10 @@ DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
std::mt19937 rng(seed);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
......@@ -58,12 +52,12 @@ struct BenchFunc {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
auto start = GetCurrentUS();
auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...);
}
auto end = GetCurrentUS();
return (end - start) / FLAGS_repeat;
auto end = paddle::platform::PosixInNsec() / 1e-3;
return static_cast<double>(end - start) / FLAGS_repeat;
}
};
......
......@@ -67,7 +67,7 @@ class VActFunc : public JitCode {
virtual void genCode() = 0;
protected:
// compute relu with ymm, xmm
// compute RELU with ymm, xmm
template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx);
......@@ -75,7 +75,7 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero);
}
// compute exp with ymm, xmm
// compute EXP with ymm, xmm
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
......@@ -159,7 +159,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global);
}
// compute sigmoid with ymm, xmm
// compute SIGMOID with ymm, xmm
template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
......@@ -184,7 +184,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global);
}
// compute tanh with ymm, xmm
// compute TANH with ymm, xmm
template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
......@@ -211,7 +211,7 @@ class VActFunc : public JitCode {
pop(reg_ptr_global);
}
// compute identity with ymm, xmm
// compute IDENTITY with ymm, xmm
template <typename JMM>
void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT
JMM zero = JMM(zero_idx);
......@@ -225,19 +225,19 @@ class VActFunc : public JitCode {
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15
switch (type) {
case operand_type::relu:
case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::exp:
case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::sigmoid:
case operand_type::SIGMOID:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::tanh:
case operand_type::TANH:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
case operand_type::identity:
case operand_type::IDENTITY:
identity_jmm<JMM>(dst, src, 15);
break;
default:
......@@ -252,9 +252,9 @@ class VActJitCode : public VActFunc {
explicit VActJitCode(int d, operand_type type, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
type_ == operand_type::identity)) {
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::IDENTITY)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -263,19 +263,19 @@ class VActJitCode : public VActFunc {
const char* name() const override {
std::string base = "VActJitCode";
switch (type_) {
case operand_type::relu:
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::exp:
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::sigmoid:
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::tanh:
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::identity:
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
......@@ -305,11 +305,11 @@ class VActJitCode : public VActFunc {
: VActJitCode(d, op_type, code_size, code_ptr) {} \
};
DECLARE_ACT_JITCODE(VRelu, operand_type::relu);
DECLARE_ACT_JITCODE(VIdentity, operand_type::identity);
DECLARE_ACT_JITCODE(VExp, operand_type::exp);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::sigmoid);
DECLARE_ACT_JITCODE(VTanh, operand_type::tanh);
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
#undef DECLARE_ACT_JITCODE
......
......@@ -39,9 +39,9 @@ void VXXJitCode::genCode() {
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
......@@ -79,10 +79,10 @@ void VXXJitCode::genCode() {
}
}
switch (type_) {
case operand_type::mul:
case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::add:
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
......
......@@ -34,7 +34,7 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::mul || type_ == operand_type::add)) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -47,9 +47,9 @@ class VXXJitCode : public JitCode {
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
if (type_ == operand_type::MUL) {
base += "_Mul";
} else if (type_ == operand_type::add) {
} else if (type_ == operand_type::ADD) {
base += "_Add";
}
if (scalar_index_ == 2) {
......@@ -90,12 +90,12 @@ class VXXJitCode : public JitCode {
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE
......
......@@ -31,17 +31,17 @@ class GRUJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) {
return operand_type::sigmoid;
return operand_type::SIGMOID;
} else if (type == KernelType::vrelu) {
return operand_type::relu;
return operand_type::RELU;
} else if (type == KernelType::vtanh) {
return operand_type::tanh;
return operand_type::TANH;
} else if (type == KernelType::videntity) {
return operand_type::identity;
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::identity;
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
......@@ -60,19 +60,19 @@ class GRUJitCode : public VActFunc {
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::exp:
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::sigmoid:
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::tanh:
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::identity:
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
......
......@@ -46,14 +46,14 @@ using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
MUL = 0,
ADD,
SUB,
RELU,
EXP,
SIGMOID,
TANH,
IDENTITY
} operand_type;
#define DECLARE_JIT_CODE(codename) \
......
......@@ -34,17 +34,17 @@ class LSTMJitCode : public VActFunc {
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) {
return operand_type::sigmoid;
return operand_type::SIGMOID;
} else if (type == KernelType::vrelu) {
return operand_type::relu;
return operand_type::RELU;
} else if (type == KernelType::vtanh) {
return operand_type::tanh;
return operand_type::TANH;
} else if (type == KernelType::videntity) {
return operand_type::identity;
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::identity;
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
......@@ -63,19 +63,19 @@ class LSTMJitCode : public VActFunc {
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::exp:
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::sigmoid:
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::tanh:
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::identity:
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册