From c2cfb03a7277a92297b4617cb5c778bb495a998b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 20 Nov 2018 08:50:24 +0000 Subject: [PATCH] add lstm jitcode --- paddle/fluid/operators/math/jit_code.cc | 49 +++++++++ paddle/fluid/operators/math/jit_code.h | 102 ++++++++++++++++-- paddle/fluid/operators/math/jit_kernel.h | 15 ++- paddle/fluid/operators/math/jit_kernel_impl.h | 49 +++++++++ 4 files changed, 198 insertions(+), 17 deletions(-) create mode 100644 paddle/fluid/operators/math/jit_kernel_impl.h diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index e484e9a3c70..418c8433625 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/jit_code.h" +#include // offsetof #include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me namespace paddle { @@ -210,6 +211,54 @@ void VActJitCode::generate() { ret(); } +bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } + +void LSTMJitCode::generate() { + reg64_t reg_ptr_gates = rax; + reg64_t reg_ptr_ct_1 = r9; + reg64_t reg_ptr_ct = r10; + reg64_t reg_ptr_ht = r11; + mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); + mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); + mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); + mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); + + int offset = 0; + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + // c + vmovups(ymm_src, ptr[reg_ptr_gates + offset]); + act(ymm_c, ymm_src, act_cand_); + // i + vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]); + act(ymm_i, ymm_src, act_gate_); + vmulps(ymm_c, ymm_c, ymm_i); + if (first_) { + // f + vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]); + act(ymm_f, ymm_src, act_gate_); + vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]); + vmulps(ymm_f, ymm_f, ymm_i); + vaddps(ymm_f, ymm_f, ymm_c); + } + /* H_t = act_cell(C_t) * ogated */ + ymm_t ymm_ct = first_ ? ymm_c : ymm_f; + ymm_t ymm_o = first_ ? ymm_f : ymm_c; + ymm_t ymm_tmp = ymm_i; + act(ymm_tmp, ymm_ct, act_cell_); + vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]); + act(ymm_o, ymm_src, act_gate_); + vmulps(ymm_o, ymm_tmp, ymm_o); + // save ct and ht + vmovups(ptr[reg_ptr_ct + offset], ymm_ct); + vmovups(ptr[reg_ptr_ht + offset], ymm_o); + + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + + ret(); +} + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 65f83ff4846..938b5525c1c 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/jit_gen.h" +#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/platform/cpu_info.h" namespace paddle { @@ -46,14 +47,6 @@ extern const float exp_float_consts[]; extern const int exp_int_0x7f[]; extern int g_tmp_mem[]; -// TODO(TJ): move these to some proper place -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 -#define XMM_FLOAT_BLOCK 4 -#define YMM_FLOAT_BLOCK 8 -#define ZMM_FLOAT_BLOCK 16 - #define ALIGN32 __attribute__((aligned(32))) #define EXP_HIG 88.3762626647949f #define EXP_LOW -88.3762626647949f @@ -322,6 +315,99 @@ class VActJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +class LSTMJitCode : public VActJitCode { + public: + const char* name() const override { + std::string base = "LSTMJitCode"; + auto AddTypeStr = [&](operand_type type) { + switch (type) { + case operand_type::relu: + base += "_Relu"; + break; + case operand_type::exp: + base += "_Exp"; + break; + case operand_type::sigmoid: + base += "_Sigmoid"; + break; + case operand_type::tanh: + base += "_Tanh"; + break; + case operand_type::identity: + base += "_Identity"; + break; + default: + break; + } + }; + if (first_) { + base += "_C1H1"; + } + AddTypeStr(act_gate_); + AddTypeStr(act_cand_); + AddTypeStr(act_cell_); + return base.c_str(); + } + + explicit LSTMJitCode(int d, bool first, operand_type act_gate, + operand_type act_cand, operand_type act_cell, + size_t code_size = 256 * 1024, void* code_ptr = nullptr) + : VActJitCode(d, act_gate, code_size, code_ptr), + num_(d), + first_(first), + act_gate_(act_gate), + act_cand_(act_cand), + act_cell_(act_cell) {} + static bool init(int d); + void generate() override; + + protected: + int num_; + bool first_; + operand_type act_gate_; + operand_type act_cand_; + operand_type act_cell_; + reg64_t param1{abi_param1}; + + xmm_t xmm_src = xmm_t(0); + xmm_t xmm_c = xmm_t(1); + xmm_t xmm_i = xmm_t(2); + xmm_t xmm_f = xmm_t(3); + + ymm_t ymm_src = ymm_t(0); + ymm_t ymm_c = ymm_t(1); + ymm_t ymm_i = ymm_t(2); + ymm_t ymm_f = ymm_t(3); + + template + void act(JMM& dst, JMM& src, operand_type type) { // NOLINT + // use 15 + JMM zero = JMM(15); + if (type_ == operand_type::relu) { + vxorps(zero, zero, zero); + } + switch (type) { + case operand_type::relu: + relu_jmm(dst, src, zero); + break; + case operand_type::exp: + exp_jmm(dst, src, 2, 3, 4, 5); + break; + case operand_type::sigmoid: + sigmoid_jmm(dst, src, 2, 3, 4, 5); + break; + case operand_type::tanh: + tanh_jmm(dst, src, 2, 3, 4, 5); + break; + case operand_type::identity: + break; + default: + // throw error + break; + } + } +}; + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7e163c1349e..b5e54fcc1b8 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #include // for shared_ptr #include #include +#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/macros.h" @@ -26,14 +27,7 @@ namespace operators { namespace math { namespace jitkernel { -// TODO(TJ): move these to some proper place -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 -#define XMM_FLOAT_BLOCK 4 -#define YMM_FLOAT_BLOCK 8 -#define ZMM_FLOAT_BLOCK 16 - +// TODO(TJ): remove me typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; class Kernel { @@ -124,10 +118,13 @@ class LSTMKernel : public Kernel { const T *wp_data = nullptr, T *checked = nullptr) const = 0; - // compute c1 and h1 without c0 or h0 virtual void ComputeC1H1(T *gates, T *ct, T *ht, /* below only used in peephole*/ const T *wp_data = nullptr) const = 0; + + // void (*ComputeCtHt)(lstm_t *); + // // compute c1 and h1 without c0 or h0 + // void (*ComputeC1H1)(lstm_t *); }; template diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h new file mode 100644 index 00000000000..fcb6a7c0971 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_impl.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include + +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 +#define XMM_FLOAT_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define ZMM_FLOAT_BLOCK 16 + +typedef struct { + void* gates; // gates: W_ch, W_ih, W_fh, W_oh + const void* ct_1; + void* ct; + void* ht; + /* below only used in peephole*/ + const void* wp_data{nullptr}; + void* checked{nullptr}; +} lstm_t; + +typedef struct { + int d; + std::string act_gate, act_cand, act_cell; +} lstm_attr_t; + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle -- GitLab