提交 83d075aa 编写于 作者: T tensor-tang

fix lstm and gru jitcode

test=develop
上级 20392be0
...@@ -59,43 +59,12 @@ extern int g_tmp_mem[]; ...@@ -59,43 +59,12 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
class VActJitCode : public JitCode { class VActFunc : public JitCode {
public: public:
explicit VActJitCode(int d, operand_type type, size_t code_size, explicit VActFunc(size_t code_size, void* code_ptr)
void* code_ptr = nullptr) : JitCode(code_size, code_ptr) {}
: JitCode(code_size, code_ptr), num_(d), type_(type) { virtual const char* name() const = 0;
if (!(type_ == operand_type::relu || type_ == operand_type::exp || virtual void genCode() = 0;
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
type_ == operand_type::identity)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
// Returns the kernel name, suffixed with the activation selected by type_.
// Every branch returns a string literal, which has static storage duration,
// so the returned pointer remains valid after the call. Building the name in
// a function-local std::string and returning its c_str() would hand the
// caller a dangling pointer once that string is destroyed on return.
const char* name() const override {
  switch (type_) {
    case operand_type::relu:
      return "VActJitCode_Relu";
    case operand_type::exp:
      return "VActJitCode_Exp";
    case operand_type::sigmoid:
      return "VActJitCode_Sigmoid";
    case operand_type::tanh:
      return "VActJitCode_Tanh";
    case operand_type::identity:
      return "VActJitCode_Identity";
    default:
      // Unknown types keep the bare base name, as before.
      return "VActJitCode";
  }
}
void genCode() override;
protected: protected:
// compute relu with ymm, xmm // compute relu with ymm, xmm
...@@ -272,10 +241,49 @@ class VActJitCode : public JitCode { ...@@ -272,10 +241,49 @@ class VActJitCode : public JitCode {
identity_jmm<JMM>(dst, src, 15); identity_jmm<JMM>(dst, src, 15);
break; break;
default: default:
LOG(FATAL) << "Do not support this operand type: " << type_; LOG(FATAL) << "Do not support this operand type: " << type;
break; break;
} }
} }
};
class VActJitCode : public VActFunc {
public:
// Stores the element count `d` in num_ and the activation `type` in type_,
// rejects any activation outside the five handled kinds via LOG(FATAL),
// and then calls genCode() as the last step of construction.
explicit VActJitCode(int d, operand_type type, size_t code_size,
                     void* code_ptr = nullptr)
    : VActFunc(code_size, code_ptr), num_(d), type_(type) {
  switch (type_) {
    case operand_type::relu:
    case operand_type::exp:
    case operand_type::sigmoid:
    case operand_type::tanh:
    case operand_type::identity:
      // Supported activation: nothing to report.
      break;
    default:
      LOG(FATAL) << "Do not support this operand type: " << type_;
      break;
  }
  this->genCode();
}
// Returns the kernel name, suffixed with the activation selected by type_.
// Every branch returns a string literal, which has static storage duration,
// so the returned pointer remains valid after the call. Building the name in
// a function-local std::string and returning its c_str() would hand the
// caller a dangling pointer once that string is destroyed on return.
const char* name() const override {
  switch (type_) {
    case operand_type::relu:
      return "VActJitCode_Relu";
    case operand_type::exp:
      return "VActJitCode_Exp";
    case operand_type::sigmoid:
      return "VActJitCode_Sigmoid";
    case operand_type::tanh:
      return "VActJitCode_Tanh";
    case operand_type::identity:
      return "VActJitCode_Identity";
    default:
      // Unknown types keep the bare base name, as before.
      return "VActJitCode";
  }
}
void genCode() override;
protected: protected:
int num_; int num_;
......
...@@ -24,13 +24,11 @@ namespace operators { ...@@ -24,13 +24,11 @@ namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
class GRUJitCode : public VActJitCode { class GRUJitCode : public VActFunc {
public: public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size, explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
code_ptr),
id_(id) {
auto typeExchange = [](KernelType type) -> gen::operand_type { auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) { if (type == KernelType::vsigmoid) {
return operand_type::sigmoid; return operand_type::sigmoid;
...@@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode { ...@@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode {
} }
return operand_type::identity; return operand_type::identity;
}; };
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate); act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand); act_cand_ = typeExchange(attr.act_cand);
......
...@@ -62,7 +62,9 @@ typedef enum { ...@@ -62,7 +62,9 @@ typedef enum {
class JitCode : public GenBase, public Xbyak::CodeGenerator { class JitCode : public GenBase, public Xbyak::CodeGenerator {
public: public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr) explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator((code_size < 4096 ? 4096 : code_size), code_ptr) {} : Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0; virtual const char* name() const = 0;
virtual void genCode() = 0; virtual void genCode() = 0;
......
...@@ -24,13 +24,14 @@ namespace operators { ...@@ -24,13 +24,14 @@ namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
class LSTMJitCode : public VActJitCode { class LSTMJitCode : public VActFunc {
public: public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr) size_t code_size, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, : VActFunc(code_size, code_ptr),
code_ptr), num_(attr.d),
compute_c1h1_(compute_c1h1) { compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type { auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) { if (type == KernelType::vsigmoid) {
return operand_type::sigmoid; return operand_type::sigmoid;
...@@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode { ...@@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode {
} }
return operand_type::identity; return operand_type::identity;
}; };
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate); act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand); act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell); act_cell_ = typeExchange(attr.act_cell);
......
...@@ -80,7 +80,7 @@ struct rnn_attr_s { ...@@ -80,7 +80,7 @@ struct rnn_attr_s {
int d; int d;
KernelType act_gate, act_cand; KernelType act_gate, act_cand;
rnn_attr_s() = default; rnn_attr_s() = default;
rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand) explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {} : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
}; };
...@@ -88,8 +88,8 @@ struct lstm_attr_s : public rnn_attr_s { ...@@ -88,8 +88,8 @@ struct lstm_attr_s : public rnn_attr_s {
bool use_peephole; bool use_peephole;
KernelType act_cell; KernelType act_cell;
lstm_attr_s() = default; lstm_attr_s() = default;
lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand, explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false) KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand), : rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole), use_peephole(_use_peephole),
act_cell(_act_cell) {} act_cell(_act_cell) {}
...@@ -145,6 +145,8 @@ class Kernel { ...@@ -145,6 +145,8 @@ class Kernel {
template <typename KernelTuples> template <typename KernelTuples>
class KernelImpl : public Kernel { class KernelImpl : public Kernel {
// TODO(TJ): rename KernelImpl to KernelMore, since it seems to be used only
// in `more`, and add a name interface to the `more` implementations to make
// debugging easier
using T = typename KernelTuples::data_type; using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type; using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuples::attr_type;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册