提交 e3b61cf5 编写于 作者: T tensor-tang

init gru jitcode and fix lstm jitcode

test=develop
上级 0f254465
...@@ -214,6 +214,9 @@ void VActJitCode::generate() { ...@@ -214,6 +214,9 @@ void VActJitCode::generate() {
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void LSTMJitCode::generate() { void LSTMJitCode::generate() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax; reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9; reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10; reg64_t reg_ptr_ct = r10;
...@@ -224,18 +227,19 @@ void LSTMJitCode::generate() { ...@@ -224,18 +227,19 @@ void LSTMJitCode::generate() {
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
int offset = 0; int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* C_t = C_t-1 * fgated + cand_gated * igated*/ /* C_t = C_t-1 * fgated + cand_gated * igated*/
// c // c
vmovups(ymm_src, ptr[reg_ptr_gates + offset]); vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
act<ymm_t>(ymm_c, ymm_src, act_cand_); act<ymm_t>(ymm_c, ymm_src, act_cand_);
// i // i
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
act<ymm_t>(ymm_i, ymm_src, act_gate_); act<ymm_t>(ymm_i, ymm_src, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i); vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) { if (!compute_c1h1_) {
// f // f
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
act<ymm_t>(ymm_f, ymm_src, act_gate_); act<ymm_t>(ymm_f, ymm_src, act_gate_);
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]); vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
vmulps(ymm_f, ymm_f, ymm_i); vmulps(ymm_f, ymm_f, ymm_i);
...@@ -245,20 +249,36 @@ void LSTMJitCode::generate() { ...@@ -245,20 +249,36 @@ void LSTMJitCode::generate() {
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c; ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
ymm_t ymm_tmp = ymm_i; ymm_t ymm_tmp = ymm_i;
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_); act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]); vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
act<ymm_t>(ymm_o, ymm_src, act_gate_); act<ymm_t>(ymm_o, ymm_src, act_gate_);
vmulps(ymm_o, ymm_tmp, ymm_o); vmulps(ymm_o, ymm_tmp, ymm_o);
// save ct and ht vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK; offset += sizeof(float) * YMM_FLOAT_BLOCK;
} }
if (use_peephole_) {
postCode();
} else {
ret(); ret();
}
} }
// Reports whether the GRU JIT generator may be used for a frame width of `d`:
// requires that MayIUse(avx) holds and that `d` is an exact multiple of 8.
bool GRUJitCode::init(int d) {
  const bool width_is_multiple_of_8 = (d % 8 == 0);
  return MayIUse(avx) && width_is_multiple_of_8;
}
void GRUJitCode::generate() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
ret();
}
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -302,6 +302,34 @@ class VActJitCode : public JitCode { ...@@ -302,6 +302,34 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// Emits code applying the activation selected by `type` to `src`, leaving the
// result in `dst`.
//
// @tparam JMM  vector register type used for `dst`, `src` and the temporaries.
// @param dst   register that receives the activated value.
// @param src   register holding the input value.
// @param type  activation to emit; dispatch is on this argument, not on the
//              class-wide `type_` member, so callers may request a different
//              activation on every call.
//
// Register 15 is used as the zero operand for relu. The literals 2, 3, 4, 5
// are forwarded unchanged to the exp/sigmoid/tanh helpers (presumably
// scratch-register indices — confirm against the helper definitions).
template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) {  // NOLINT
  switch (type) {
    case operand_type::relu: {
      // use 15
      // Clear the zero register whenever relu is actually requested. The
      // previous check looked at `type_`, which left register 15 uncleared
      // when the caller's `type` was relu but `type_` was not.
      JMM zero = JMM(15);
      vxorps(zero, zero, zero);
      relu_jmm<JMM>(dst, src, zero);
      break;
    }
    case operand_type::exp:
      exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
      break;
    case operand_type::sigmoid:
      sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
      break;
    case operand_type::tanh:
      tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
      break;
    case operand_type::identity:
      // Identity must still produce its output in `dst`; without this copy
      // `dst` would keep whatever the register held before the call.
      vmovups(dst, src);
      break;
    default:
      // throw error
      break;
  }
}
protected: protected:
int num_; int num_;
operand_type type_; operand_type type_;
...@@ -386,44 +414,94 @@ class LSTMJitCode : public VActJitCode { ...@@ -386,44 +414,94 @@ class LSTMJitCode : public VActJitCode {
operand_type act_cand_; operand_type act_cand_;
operand_type act_cell_; operand_type act_cell_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
xmm_t xmm_src = xmm_t(0); xmm_t xmm_src = xmm_t(0);
xmm_t xmm_c = xmm_t(1); xmm_t xmm_c = xmm_t(1);
xmm_t xmm_i = xmm_t(2); xmm_t xmm_i = xmm_t(6);
xmm_t xmm_f = xmm_t(3); xmm_t xmm_f = xmm_t(7);
ymm_t ymm_src = ymm_t(0); ymm_t ymm_src = ymm_t(0);
ymm_t ymm_c = ymm_t(1); ymm_t ymm_c = ymm_t(1); // 2~5 for act
ymm_t ymm_i = ymm_t(2); ymm_t ymm_i = ymm_t(6);
ymm_t ymm_f = ymm_t(3); ymm_t ymm_f = ymm_t(7);
};
template <typename JMM> class GRUJitCode : public VActJitCode {
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT public:
// use 15 const char* name() const override {
JMM zero = JMM(15); std::string base = "GRUJitCode";
if (type_ == operand_type::relu) { if (id_ == 0) {
vxorps(zero, zero, zero); base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
} }
auto AddTypeStr = [&](operand_type type) {
switch (type) { switch (type) {
case operand_type::relu: case operand_type::relu:
relu_jmm<JMM>(dst, src, zero); base += "_Relu";
break; break;
case operand_type::exp: case operand_type::exp:
exp_jmm<JMM>(dst, src, 2, 3, 4, 5); base += "_Exp";
break; break;
case operand_type::sigmoid: case operand_type::sigmoid:
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5); base += "_Sigmoid";
break; break;
case operand_type::tanh: case operand_type::tanh:
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5); base += "_Tanh";
break; break;
case operand_type::identity: case operand_type::identity:
base += "_Identity";
break; break;
default: default:
// throw error
break; break;
} }
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
explicit GRUJitCode(int id, const gru_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
} }
static bool init(int d);
void generate() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
xmm_t xmm_src = xmm_t(0);
xmm_t xmm_c = xmm_t(1);
xmm_t xmm_i = xmm_t(6);
xmm_t xmm_f = xmm_t(7);
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_c = ymm_t(1);
ymm_t ymm_i = ymm_t(6);
ymm_t ymm_f = ymm_t(7);
}; };
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -40,7 +40,7 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -40,7 +40,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() { explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) { if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt = this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>(); jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
...@@ -66,7 +66,7 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -66,7 +66,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
template <> template <>
bool LSTMKernelImpl<float>::useJIT(int d) { bool LSTMKernelImpl<float>::useJIT(int d) {
return false; // not ready yet gen::LSTMJitCode::init(d); return gen::LSTMJitCode::init(d);
} }
#endif #endif
...@@ -82,7 +82,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> { ...@@ -82,7 +82,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() { explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) { if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8;
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt = this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>(); jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
...@@ -175,12 +175,42 @@ class GRUKernelImpl : public GRUKernel<T> { ...@@ -175,12 +175,42 @@ class GRUKernelImpl : public GRUKernel<T> {
static inline bool useJIT(int d) { return false; } static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; } static inline bool useMKL(int d) { return false; }
// Builds the GRU kernel for `attr`. When compiled with PADDLE_WITH_XBYAK and
// useJIT(attr.d) is true, one GRUJitCode generator is created per step
// function (ids 0, 1, 2) and each function pointer is taken from its own
// generator; otherwise the reference implementations are installed.
explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
  if (useJIT(attr.d)) {
    using gru_func_t = void (*)(gru_t*, const gru_attr_t*);
    size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8;  // should change
    // Never request less than 4096 bytes of code buffer.
    const size_t code_size = sz > 4096 ? sz : 4096;

    jitcode0_.reset(new gen::GRUJitCode(0, attr, code_size));
    this->ComputeH1 = jitcode0_->getCode<gru_func_t>();

    jitcode1_.reset(new gen::GRUJitCode(1, attr, code_size));
    this->ComputeHtPart1 = jitcode1_->getCode<gru_func_t>();

    jitcode2_.reset(new gen::GRUJitCode(2, attr, code_size));
    // Fixed: this previously read the code from jitcode1_, so ComputeHtPart2
    // pointed at the HtPart1 kernel and jitcode2_ was never used.
    this->ComputeHtPart2 = jitcode2_->getCode<gru_func_t>();
    return;
  }
#endif
  this->ComputeH1 = refer::GRUH1<T>;
  this->ComputeHtPart1 = refer::GRUHtPart1<T>;
  this->ComputeHtPart2 = refer::GRUHtPart2<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::GRUJitCode> jitcode0_{nullptr}, jitcode1_{nullptr},
jitcode2_{nullptr};
#endif
}; };
#ifdef PADDLE_WITH_XBYAK
// float specialization of the JIT eligibility check. It unconditionally
// reports "not usable" for every width `d`, so the GRU kernel always takes the
// non-JIT path.
template <>
bool GRUKernelImpl<float>::useJIT(int d) {
  // jitcode not ready yet
  constexpr bool kGruJitReady = false;
  return kGruJitReady;
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \ #define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \ template <> \
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \ std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册