提交 93e75c5a 编写于 作者: T tensor-tang

refine jitcode of vsub and vsquare

test=develop
上级 d618e483
......@@ -11,11 +11,12 @@ endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVSub)
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVSquare)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
......
......@@ -91,6 +91,7 @@ void VActJitCode::genCode() {
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VSquare);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
......@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
8 /* average bytes for each instruction */;
}
// Returns the size estimate used for the generated VSquare kernel when the
// vector length is d (presumably in bytes -- confirm against the caller).
size_t VSquareCreator::CodeSize(const int& d) const {
  // Number of YMM_FLOAT_BLOCK-sized chunks in d, plus a constant 3.
  const auto block_count = d / YMM_FLOAT_BLOCK + 3;
  // Each chunk is budgeted 4 * 8 (presumably 4 instructions of ~8 bytes each).
  const auto per_block_total = block_count * 4 * 8;
  // A fixed 96 is added on top of the per-chunk budget.
  return 96 + per_block_total;
}
// Returns the size estimate used for the generated VIdentity kernel when the
// vector length is d (presumably in bytes -- confirm against the caller).
size_t VIdentityCreator::CodeSize(const int& d) const {
  // Number of YMM_FLOAT_BLOCK-sized chunks in d, plus a constant 3.
  const auto chunk_total = d / YMM_FLOAT_BLOCK + 3;
  // Each chunk is budgeted 4 * 8 (presumably 4 instructions of ~8 bytes each).
  const auto chunk_budget = chunk_total * 4 * 8;
  // A fixed 96 is added on top of the per-chunk budget.
  return 96 + chunk_budget;
}
......@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
......
......@@ -75,6 +75,12 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero);
}
// Compute SQUARE with ymm, xmm: dst = src * src.
// Emits a single vmulps that uses src for both source operands and writes the
// product to dst. Both registers are taken by non-const reference, the same
// signature shape as exp_jmm; the NOLINT presumably silences the lint rule
// against non-const reference parameters.
template <typename JMM>
void square_jmm(JMM& dst, JMM& src) { // NOLINT
vmulps(dst, src, src);
}
// compute EXP with ymm, xmm
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
......@@ -228,6 +234,9 @@ class VActFunc : public JitCode {
case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::SQUARE:
square_jmm<JMM>(dst, src);
break;
case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
......@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::IDENTITY)) {
type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::SQUARE:
base += "_Square";
break;
case operand_type::EXP:
base += "_Exp";
break;
......@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
};
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE);
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
......
......@@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::SUB) {
vsubps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
......@@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::SUB:
vsubps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
......@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
......
......@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD ||
type_ == operand_type::SUB)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
} else if (type_ == operand_type::SUB) {
base += "_SUB";
}
if (scalar_index_ == 2) {
base += "_Scalar";
......
......@@ -51,6 +51,7 @@ typedef enum {
SUB,
RELU,
EXP,
SQUARE,
SIGMOID,
TANH,
IDENTITY
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册