提交 fd0a954f 编写于 作者: T tensor-tang

enable blas jitcode vmul, vadd, vaddrelu, vscal and vaddbias

上级 5e97be7b
...@@ -10,3 +10,8 @@ endfunction() ...@@ -10,3 +10,8 @@ endfunction()
# use gen jitcode kernel by name # use gen jitcode kernel by name
USE_JITKERNEL_GEN(vmul) USE_JITKERNEL_GEN(vmul)
USE_JITKERNEL_GEN(vadd)
#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(vaddrelu)
USE_JITKERNEL_GEN(vscal)
USE_JITKERNEL_GEN(vaddbias)
...@@ -104,18 +104,28 @@ void VXXJitCode::genCode() { ...@@ -104,18 +104,28 @@ void VXXJitCode::genCode() {
ret(); ret();
} }
class VMulCreator : public JitCodeCreator<int> { #define DECLARE_BLAS_CREATOR(name) \
public: class name##Creator : public JitCodeCreator<int> { \
bool UseMe(const int& attr) const override { public: \
return platform::MayIUse(platform::avx); bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
} }
size_t CodeSize(const int& d) const override {
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; DECLARE_BLAS_CREATOR(VMul);
} DECLARE_BLAS_CREATOR(VAdd);
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { DECLARE_BLAS_CREATOR(VSub);
return make_unique<VMulJitCode>(attr, CodeSize(attr)); DECLARE_BLAS_CREATOR(VAddRelu);
} DECLARE_BLAS_CREATOR(VScal);
}; DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
...@@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> { ...@@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> {
namespace gen = paddle::operators::jit::gen; namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator); REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator);
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle { namespace paddle {
...@@ -33,6 +34,9 @@ class VXXJitCode : public JitCode { ...@@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
type_(type), type_(type),
scalar_index_(scalar_index), scalar_index_(scalar_index),
with_relu_(with_relu) { with_relu_(with_relu) {
if (!(type_ == operand_type::mul || type_ == operand_type::add)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode(); this->genCode();
} }
...@@ -78,11 +82,22 @@ class VXXJitCode : public JitCode { ...@@ -78,11 +82,22 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3); ymm_t ymm_zero = ymm_t(3);
}; };
class VMulJitCode : public VXXJitCode { #define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
public: class name##JitCode : public VXXJitCode { \
explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr) public: \
: VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {} explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
}; : VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
#undef DECLARE_BLAS_JITCODE
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册