提交 c2cfb03a 编写于 作者: T tensor-tang

add lstm jitcode

上级 8bc1c5d2
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace paddle {
......@@ -210,6 +211,54 @@ void VActJitCode::generate() {
ret();
}
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void LSTMJitCode::generate() {
  // Emits one fused LSTM step over num_ floats, YMM_FLOAT_BLOCK (8) at a time:
  //   ct = ct_1 * f + cand(c) * i        (general step)
  //   ct = cand(c) * i                   (first step: no previous cell state)
  //   ht = act_cell(ct) * o
  // All arguments arrive packed in a single lstm_t* passed in abi_param1.
  reg64_t reg_ptr_gates = rax;
  reg64_t reg_ptr_ct_1 = r9;
  reg64_t reg_ptr_ct = r10;
  reg64_t reg_ptr_ht = r11;
  mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
  mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
  mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
  mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
  // gates holds four contiguous sections of num_ floats each
  // (candidate, input, forget, output); section stride is in BYTES.
  const int gate_stride = num_ * sizeof(float);
  int offset = 0;
  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
    /* C_t = C_t-1 * fgated + cand_gated * igated */
    // candidate: gates[0, num_)
    vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
    act<ymm_t>(ymm_c, ymm_src, act_cand_);
    // input gate: gates[num_, 2*num_)
    // BUGFIX: section offset must be in bytes (num_ * sizeof(float)),
    // not in elements, since `offset` is a byte offset.
    vmovups(ymm_src, ptr[reg_ptr_gates + offset + gate_stride]);
    act<ymm_t>(ymm_i, ymm_src, act_gate_);
    vmulps(ymm_c, ymm_c, ymm_i);
    if (!first_) {
      // BUGFIX: the forget-gate path must run only when a previous cell
      // state exists; the original `if (first_)` was inverted and would
      // read the absent ct_1 on the first step (C1H1).
      // forget gate: gates[2*num_, 3*num_)
      vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * gate_stride]);
      act<ymm_t>(ymm_f, ymm_src, act_gate_);
      vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
      vmulps(ymm_f, ymm_f, ymm_i);
      vaddps(ymm_f, ymm_f, ymm_c);
    }
    /* H_t = act_cell(C_t) * ogated */
    // Register reuse: on the first step ct lives in ymm_c and ymm_f is free
    // for the o-gate; otherwise ct lives in ymm_f and ymm_c is free.
    ymm_t ymm_ct = first_ ? ymm_c : ymm_f;
    ymm_t ymm_o = first_ ? ymm_f : ymm_c;
    ymm_t ymm_tmp = ymm_i;  // the i-gate value is no longer needed
    act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
    // output gate: gates[3*num_, 4*num_)
    vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * gate_stride]);
    act<ymm_t>(ymm_o, ymm_src, act_gate_);
    vmulps(ymm_o, ymm_tmp, ymm_o);
    // save ct and ht
    vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
    vmovups(ptr[reg_ptr_ht + offset], ymm_o);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }
  ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -46,14 +47,6 @@ extern const float exp_float_consts[];
extern const int exp_int_0x7f[];
extern int g_tmp_mem[];
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
......@@ -322,6 +315,99 @@ class VActJitCode : public JitCode {
ymm_t ymm_dst = ymm_t(1);
};
// JIT code generator for a fused LSTM step. Reuses VActJitCode's activation
// emitters (relu/exp/sigmoid/tanh) for the gate, candidate and cell
// activations configured at construction time.
class LSTMJitCode : public VActJitCode {
 public:
  const char* name() const override {
    std::string base = "LSTMJitCode";
    auto AddTypeStr = [&](operand_type type) {
      switch (type) {
        case operand_type::relu:
          base += "_Relu";
          break;
        case operand_type::exp:
          base += "_Exp";
          break;
        case operand_type::sigmoid:
          base += "_Sigmoid";
          break;
        case operand_type::tanh:
          base += "_Tanh";
          break;
        case operand_type::identity:
          base += "_Identity";
          break;
        default:
          break;
      }
    };
    if (first_) {
      base += "_C1H1";
    }
    AddTypeStr(act_gate_);
    AddTypeStr(act_cand_);
    AddTypeStr(act_cell_);
    // BUGFIX: cache the composed name in a member; returning base.c_str()
    // would hand the caller a pointer into a destroyed local string.
    name_ = base;
    return name_.c_str();
  }
  // d: hidden/cell dimension (floats per gate section).
  // first: true when computing c1/h1 without a previous cell state.
  explicit LSTMJitCode(int d, bool first, operand_type act_gate,
                       operand_type act_cand, operand_type act_cell,
                       size_t code_size = 256 * 1024, void* code_ptr = nullptr)
      : VActJitCode(d, act_gate, code_size, code_ptr),
        num_(d),
        first_(first),
        act_gate_(act_gate),
        act_cand_(act_cand),
        act_cell_(act_cell) {}
  static bool init(int d);
  void generate() override;

 protected:
  int num_;                 // floats per gate section
  bool first_;              // true for the C1H1 (no ct_1/ht_1) variant
  operand_type act_gate_;   // activation for i/f/o gates
  operand_type act_cand_;   // activation for the candidate
  operand_type act_cell_;   // activation applied to ct before ht
  mutable std::string name_;  // backing storage for name()
  reg64_t param1{abi_param1};
  xmm_t xmm_src = xmm_t(0);
  xmm_t xmm_c = xmm_t(1);
  xmm_t xmm_i = xmm_t(2);
  xmm_t xmm_f = xmm_t(3);
  ymm_t ymm_src = ymm_t(0);
  ymm_t ymm_c = ymm_t(1);
  ymm_t ymm_i = ymm_t(2);
  ymm_t ymm_f = ymm_t(3);
  // Emits dst = activation(src) of the requested `type`.
  // NOTE(review): the *_jmm helpers use registers 2-5 (and 15) as scratch,
  // which alias ymm_i/ymm_f — callers must not rely on those surviving.
  template <typename JMM>
  void act(JMM& dst, JMM& src, operand_type type) {  // NOLINT
    // use 15
    JMM zero = JMM(15);
    // BUGFIX: test the requested activation `type`, not the member `type_`
    // inherited from VActJitCode; otherwise the relu zero register is never
    // (or wrongly) initialized for this call.
    if (type == operand_type::relu) {
      vxorps(zero, zero, zero);
    }
    switch (type) {
      case operand_type::relu:
        relu_jmm<JMM>(dst, src, zero);
        break;
      case operand_type::exp:
        exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
        break;
      case operand_type::sigmoid:
        sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
        break;
      case operand_type::tanh:
        tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
        break;
      case operand_type::identity:
        break;
      default:
        // TODO(review): unsupported type is silently ignored; should report
        // an error once a proper error path exists here.
        break;
    }
  }
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"
......@@ -26,14 +27,7 @@ namespace operators {
namespace math {
namespace jitkernel {
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
// TODO(TJ): remove me
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel {
......@@ -124,10 +118,13 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr,
T *checked = nullptr) const = 0;
// compute c1 and h1 without c0 or h0
virtual void ComputeC1H1(T *gates, T *ct, T *ht,
/* below only used in peephole*/
const T *wp_data = nullptr) const = 0;
// void (*ComputeCtHt)(lstm_t *);
// // compute c1 and h1 without c0 or h0
// void (*ComputeC1H1)(lstm_t *);
};
template <typename T>
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
// Argument pack for one LSTM step computed by a jit kernel. Pointers are
// typeless (void*) so a single struct serves kernels of any element type;
// the element count comes from the kernel's own configuration.
typedef struct {
  void* gates;  // gates: W_ch, W_ih, W_fh, W_oh
  const void* ct_1;  // previous cell state (read-only)
  void* ct;  // output: new cell state
  void* ht;  // output: new hidden state
  /* below only used in peephole*/
  const void* wp_data{nullptr};  // peephole weights — TODO confirm layout
  void* checked{nullptr};  // presumably scratch for peephole — verify against caller
} lstm_t;
// Static attributes describing an LSTM kernel: the dimension and the
// activation names for the gates, the candidate, and the cell output.
typedef struct {
  int d;  // hidden/cell dimension (elements per gate section)
  std::string act_gate, act_cand, act_cell;  // activation names — TODO confirm accepted values
} lstm_attr_t;
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册