diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 98d9231faa6afb58d2cd879e7b3d37f843cffc72..ef74a7118be59658dc03c04db64cea9c69b8166d 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -10,3 +10,8 @@ endfunction() # use gen jitcode kernel by name USE_JITKERNEL_GEN(vmul) +USE_JITKERNEL_GEN(vadd) +#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me +USE_JITKERNEL_GEN(vaddrelu) +USE_JITKERNEL_GEN(vscal) +USE_JITKERNEL_GEN(vaddbias) diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index 3e5ce540647fa83fcfc43d74fd3e9c7c5ead0d8f..b24f44c9f3ba9ae23808561e8a5528e8cb0b6447 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -104,18 +104,28 @@ void VXXJitCode::genCode() { ret(); } -class VMulCreator : public JitCodeCreator { - public: - bool UseMe(const int& attr) const override { - return platform::MayIUse(platform::avx); +#define DECLARE_BLAS_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + bool UseMe(const int& attr) const override { \ + return platform::MayIUse(platform::avx); \ + } \ + size_t CodeSize(const int& d) const override { \ + return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ + } \ + std::unique_ptr CreateJitCode(const int& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ } - size_t CodeSize(const int& d) const override { - return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - } - std::unique_ptr CreateJitCode(const int& attr) const override { - return make_unique(attr, CodeSize(attr)); - } -}; + +DECLARE_BLAS_CREATOR(VMul); +DECLARE_BLAS_CREATOR(VAdd); +DECLARE_BLAS_CREATOR(VSub); +DECLARE_BLAS_CREATOR(VAddRelu); +DECLARE_BLAS_CREATOR(VScal); +DECLARE_BLAS_CREATOR(VAddBias); + +#undef DECLARE_BLAS_CREATOR } // namespace gen } // namespace jit @@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator { namespace gen = paddle::operators::jit::gen; REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator); +REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator); +// TODO(TJ): enable sub +// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator); +REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator); +REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator); +REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index 60f328056787982ecce5ca1928d44621ed6ff06a..5a2192052f8315f6f0eb92a9796ead30fc5071d6 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -15,6 +15,7 @@ #pragma once #include +#include "glog/logging.h" #include "paddle/fluid/operators/jit/gen/jitcode.h" namespace paddle { @@ -33,6 +34,9 @@ class VXXJitCode : public JitCode { type_(type), scalar_index_(scalar_index), with_relu_(with_relu) { + if (!(type_ == operand_type::mul || type_ == operand_type::add)) { + LOG(FATAL) << "Do not support this operand type: " << type_; + } this->genCode(); } @@ -78,11 +82,22 @@ class VXXJitCode : public JitCode { ymm_t ymm_zero = ymm_t(3); }; -class VMulJitCode : public VXXJitCode { - public: - explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr) - : VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {} -}; +#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \ + class name##JitCode : public VXXJitCode { \ + public: \ + explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \ + : VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \ + } \ + }; + +DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false); +DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false); +DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false); +DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true); +DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false); +DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false); + +#undef DECLARE_BLAS_JITCODE } // namespace gen } // namespace jit