diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 2b8c758a032fd7edff0d4b7e23bd8e685eb3ab15..40310c2d2b372a414054f75348e8e1b4471bf3d2 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -11,11 +11,12 @@ endfunction() # use gen jitcode kernel by name USE_JITKERNEL_GEN(kVMul) USE_JITKERNEL_GEN(kVAdd) -#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me +USE_JITKERNEL_GEN(kVSub) USE_JITKERNEL_GEN(kVAddRelu) USE_JITKERNEL_GEN(kVScal) USE_JITKERNEL_GEN(kVAddBias) USE_JITKERNEL_GEN(kVRelu) +USE_JITKERNEL_GEN(kVSquare) USE_JITKERNEL_GEN(kVIdentity) USE_JITKERNEL_GEN(kVExp) USE_JITKERNEL_GEN(kVSigmoid) diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc index 3ea076f217dc7c8a755055d3f48c22b7a3627012..a2a5661b93ad3d885983c502566860aa313d110f 100644 --- a/paddle/fluid/operators/jit/gen/act.cc +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -91,6 +91,7 @@ void VActJitCode::genCode() { } DECLARE_ACT_CREATOR(VRelu); +DECLARE_ACT_CREATOR(VSquare); DECLARE_ACT_CREATOR(VIdentity); DECLARE_ACT_CREATOR(VExp); DECLARE_ACT_CREATOR(VSigmoid); @@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const { 8 /* average bytes for each instruction */; } +size_t VSquareCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; +} + size_t VIdentityCreator::CodeSize(const int& d) const { return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; } @@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const { namespace gen = paddle::operators::jit::gen; REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator); +REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator); REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); diff --git a/paddle/fluid/operators/jit/gen/act.h b/paddle/fluid/operators/jit/gen/act.h index 81503c42ab5cd46961378847584f68f2cbed0ed5..68e66f9298c4eafabb55c20195d46fed800f4ec4 100644 --- a/paddle/fluid/operators/jit/gen/act.h +++ b/paddle/fluid/operators/jit/gen/act.h @@ -75,6 +75,12 @@ class VActFunc : public JitCode { vmaxps(dst, src, zero); } + // compute SQUARE with ymm, xmm + template + void square_jmm(JMM& dst, JMM& src) { // NOLINT + vmulps(dst, src, src); + } + // compute EXP with ymm, xmm template void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT @@ -228,6 +234,9 @@ class VActFunc : public JitCode { case operand_type::RELU: relu_jmm(dst, src, 15); break; + case operand_type::SQUARE: + square_jmm(dst, src); + break; case operand_type::EXP: exp_jmm(dst, src, 11, 12, 13, 14, 15); break; @@ -254,7 +263,7 @@ class VActJitCode : public VActFunc { : VActFunc(code_size, code_ptr), num_(d), type_(type) { if (!(type_ == operand_type::RELU || type_ == operand_type::EXP || type_ == operand_type::SIGMOID || type_ == operand_type::TANH || - type_ == operand_type::IDENTITY)) { + type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) { LOG(FATAL) << "Do not support this operand type: " << type_; } this->genCode(); @@ -266,6 +275,9 @@ class VActJitCode : public VActFunc { case operand_type::RELU: base += "_Relu"; break; + case operand_type::SQUARE: + base += "_Square"; + break; case operand_type::EXP: base += "_Exp"; break; @@ -306,6 +318,7 @@ class VActJitCode : public VActFunc { }; DECLARE_ACT_JITCODE(VRelu, operand_type::RELU); +DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY); DECLARE_ACT_JITCODE(VExp, operand_type::EXP); DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID); diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index c1198773088faa594bac0714dd8449b240b3ce4d..dee6c7b9d3ee9756c1b11d10d55fdca341cbee85 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -43,6 +43,8 @@ void VXXJitCode::genCode() { vmulps(ymm_dst, ymm_src1, ymm_src2); } else if (type_ == operand_type::ADD) { vaddps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::SUB) { + vsubps(ymm_dst, ymm_src1, ymm_src2); } if (with_relu_) { vmaxps(ymm_dst, ymm_zero, ymm_dst); @@ -85,6 +87,9 @@ void VXXJitCode::genCode() { case operand_type::ADD: vaddps(xmm_dst, xmm_src1, xmm_src2); break; + case operand_type::SUB: + vsubps(xmm_dst, xmm_src1, xmm_src2); + break; default: break; } @@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen; REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); -// TODO(TJ): enable sub -// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); +REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index c46ec15fb788c0c7a90cfc8732aad375a9e226a1..de6b33f467279124d7acd97709516c31706ec4f9 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -34,7 +34,8 @@ class VXXJitCode : public JitCode { type_(type), scalar_index_(scalar_index), with_relu_(with_relu) { - if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) { + if (!(type_ == operand_type::MUL || type_ == operand_type::ADD || + type_ == operand_type::SUB)) { LOG(FATAL) << "Do not support this operand type: " << type_; } this->genCode(); @@ -51,6 +52,8 @@ class VXXJitCode : public JitCode { base += "_Mul"; } else if (type_ == operand_type::ADD) { base += "_Add"; + } else if (type_ == operand_type::SUB) { + base += "_SUB"; } if (scalar_index_ == 2) { base += "_Scalar"; diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 5b7234c1cb5d15d290685a3dceb3b757be1ef0c6..f63d40ad5a559ab87a9b3735406671cfd936d9e4 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -51,6 +51,7 @@ typedef enum { SUB, RELU, EXP, + SQUARE, SIGMOID, TANH, IDENTITY