提交 93e75c5a 编写于 作者: T tensor-tang

refine jitcode of vsub and vsquare

test=develop
上级 d618e483
...@@ -11,11 +11,12 @@ endfunction() ...@@ -11,11 +11,12 @@ endfunction()
# use gen jitcode kernel by name # use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul) USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd) USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me USE_JITKERNEL_GEN(kVSub)
USE_JITKERNEL_GEN(kVAddRelu) USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal) USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias) USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu) USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVSquare)
USE_JITKERNEL_GEN(kVIdentity) USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp) USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid) USE_JITKERNEL_GEN(kVSigmoid)
......
...@@ -91,6 +91,7 @@ void VActJitCode::genCode() { ...@@ -91,6 +91,7 @@ void VActJitCode::genCode() {
} }
DECLARE_ACT_CREATOR(VRelu); DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VSquare);
DECLARE_ACT_CREATOR(VIdentity); DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp); DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid); DECLARE_ACT_CREATOR(VSigmoid);
...@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const { ...@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
8 /* average bytes for each instruction */; 8 /* average bytes for each instruction */;
} }
// Estimates the buffer size, in bytes, reserved for the generated VSquare
// kernel when it processes `d` elements.
//   - 96 is a fixed base allowance.
//   - (d / YMM_FLOAT_BLOCK + 3) is the number of blocks budgeted for.
//   - Each block is budgeted 4 * 8 bytes (presumably 4 instructions at an
//     average of 8 bytes each -- TODO confirm against the emitted code).
size_t VSquareCreator::CodeSize(const int& d) const {
  const int budgeted_blocks = d / YMM_FLOAT_BLOCK + 3;
  const int bytes_for_blocks = budgeted_blocks * 4 * 8;
  return 96 + bytes_for_blocks;
}
size_t VIdentityCreator::CodeSize(const int& d) const { size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
} }
...@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const { ...@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace gen = paddle::operators::jit::gen; namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator); REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
......
...@@ -75,6 +75,12 @@ class VActFunc : public JitCode { ...@@ -75,6 +75,12 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute SQUARE with ymm, xmm
// Emits one vmulps that multiplies `src` by itself and writes the result to
// `dst`, i.e. dst = src * src for every packed float lane.
// JMM is the vector register type to operate on (ymm or xmm).
// `src` is only passed as a source operand of the emitted instruction.
template <typename JMM>
void square_jmm(JMM& dst, JMM& src) { // NOLINT
vmulps(dst, src, src);
}
// compute EXP with ymm, xmm // compute EXP with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
...@@ -228,6 +234,9 @@ class VActFunc : public JitCode { ...@@ -228,6 +234,9 @@ class VActFunc : public JitCode {
case operand_type::RELU: case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15); relu_jmm<JMM>(dst, src, 15);
break; break;
case operand_type::SQUARE:
square_jmm<JMM>(dst, src);
break;
case operand_type::EXP: case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
...@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc { ...@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), num_(d), type_(type) { : VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP || if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::SIGMOID || type_ == operand_type::TANH || type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::IDENTITY)) { type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
LOG(FATAL) << "Do not support this operand type: " << type_; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
this->genCode(); this->genCode();
...@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc { ...@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
case operand_type::RELU: case operand_type::RELU:
base += "_Relu"; base += "_Relu";
break; break;
case operand_type::SQUARE:
base += "_Square";
break;
case operand_type::EXP: case operand_type::EXP:
base += "_Exp"; base += "_Exp";
break; break;
...@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc { ...@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
}; };
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU); DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE);
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::EXP); DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID); DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
......
...@@ -43,6 +43,8 @@ void VXXJitCode::genCode() { ...@@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
vmulps(ymm_dst, ymm_src1, ymm_src2); vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) { } else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2); vaddps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::SUB) {
vsubps(ymm_dst, ymm_src1, ymm_src2);
} }
if (with_relu_) { if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst); vmaxps(ymm_dst, ymm_zero, ymm_dst);
...@@ -85,6 +87,9 @@ void VXXJitCode::genCode() { ...@@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
case operand_type::ADD: case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2); vaddps(xmm_dst, xmm_src1, xmm_src2);
break; break;
case operand_type::SUB:
vsubps(xmm_dst, xmm_src1, xmm_src2);
break;
default: default:
break; break;
} }
...@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen; ...@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
......
...@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode { ...@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
type_(type), type_(type),
scalar_index_(scalar_index), scalar_index_(scalar_index),
with_relu_(with_relu) { with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) { if (!(type_ == operand_type::MUL || type_ == operand_type::ADD ||
type_ == operand_type::SUB)) {
LOG(FATAL) << "Do not support this operand type: " << type_; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
this->genCode(); this->genCode();
...@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode { ...@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
base += "_Mul"; base += "_Mul";
} else if (type_ == operand_type::ADD) { } else if (type_ == operand_type::ADD) {
base += "_Add"; base += "_Add";
} else if (type_ == operand_type::SUB) {
base += "_SUB";
} }
if (scalar_index_ == 2) { if (scalar_index_ == 2) {
base += "_Scalar"; base += "_Scalar";
......
...@@ -51,6 +51,7 @@ typedef enum { ...@@ -51,6 +51,7 @@ typedef enum {
SUB, SUB,
RELU, RELU,
EXP, EXP,
SQUARE,
SIGMOID, SIGMOID,
TANH, TANH,
IDENTITY IDENTITY
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册