提交 fd0a954f 编写于 作者: T tensor-tang

enable blas jitcode vmul, vadd, vaddrelu, vscal and vaddbias

上级 5e97be7b
......@@ -10,3 +10,8 @@ endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(vmul)
USE_JITKERNEL_GEN(vadd)
#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(vaddrelu)
USE_JITKERNEL_GEN(vscal)
USE_JITKERNEL_GEN(vaddbias)
......@@ -104,18 +104,28 @@ void VXXJitCode::genCode() {
ret();
}
class VMulCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx);
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
size_t CodeSize(const int& d) const override {
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<VMulJitCode>(attr, CodeSize(attr));
}
};
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
......@@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> {
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator);
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
......@@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::mul || type_ == operand_type::add)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
......@@ -78,11 +82,22 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3);
};
class VMulJitCode : public VXXJitCode {
public:
explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr)
: VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {}
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
#undef DECLARE_BLAS_JITCODE
} // namespace gen
} // namespace jit
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册