提交 83d075aa 编写于 作者: T tensor-tang

fix lstm and gru jitcode

test=develop
上级 20392be0
......@@ -59,43 +59,12 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
class VActJitCode : public JitCode {
class VActFunc : public JitCode {
public:
explicit VActJitCode(int d, operand_type type, size_t code_size,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
type_ == operand_type::identity)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
const char* name() const override {
std::string base = "VActJitCode";
switch (type_) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
void genCode() override;
explicit VActFunc(size_t code_size, void* code_ptr)
: JitCode(code_size, code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
protected:
// compute relu with ymm, xmm
......@@ -272,10 +241,49 @@ class VActJitCode : public JitCode {
identity_jmm<JMM>(dst, src, 15);
break;
default:
LOG(FATAL) << "Do not support this operand type: " << type_;
LOG(FATAL) << "Do not support this operand type: " << type;
break;
}
}
};
class VActJitCode : public VActFunc {
public:
explicit VActJitCode(int d, operand_type type, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::relu || type_ == operand_type::exp ||
type_ == operand_type::sigmoid || type_ == operand_type::tanh ||
type_ == operand_type::identity)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
const char* name() const override {
std::string base = "VActJitCode";
switch (type_) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
void genCode() override;
protected:
int num_;
......
......@@ -24,13 +24,11 @@ namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActJitCode {
class GRUJitCode : public VActFunc {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) {
return operand_type::sigmoid;
......@@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode {
}
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
......
......@@ -62,7 +62,9 @@ typedef enum {
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator((code_size < 4096 ? 4096 : code_size), code_ptr) {}
: Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
......
......@@ -24,13 +24,14 @@ namespace operators {
namespace jit {
namespace gen {
class LSTMJitCode : public VActJitCode {
class LSTMJitCode : public VActFunc {
public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
: VActFunc(code_size, code_ptr),
num_(attr.d),
compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) {
return operand_type::sigmoid;
......@@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode {
}
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
......
......@@ -80,7 +80,7 @@ struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
......@@ -88,8 +88,8 @@ struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
......@@ -145,6 +145,8 @@ class Kernel {
template <typename KernelTuples>
class KernelImpl : public Kernel {
// TODO(TJ): rename KernelImpl to KernelMore which seems only used in more
// and add name interface for more implements easy for debug
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册