From 7c1f3ad6eb18d40310e8933a937ef83b3342a532 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 14 Dec 2018 14:13:44 +0000
Subject: [PATCH] enable jitcode lstm

---
 paddle/fluid/operators/jit/README.md          |   2 +-
 paddle/fluid/operators/jit/gen/CMakeLists.txt |   2 +
 paddle/fluid/operators/jit/gen/jitcode.h      |   2 +-
 paddle/fluid/operators/jit/gen/lstm.cc        | 142 ++++++++++++++++++
 paddle/fluid/operators/jit/gen/lstm.h         | 119 +++++++++++++++
 paddle/fluid/operators/jit/test.cc            |  10 +-
 6 files changed, 268 insertions(+), 9 deletions(-)
 create mode 100644 paddle/fluid/operators/jit/gen/lstm.cc
 create mode 100644 paddle/fluid/operators/jit/gen/lstm.h

diff --git a/paddle/fluid/operators/jit/README.md b/paddle/fluid/operators/jit/README.md
index 28d21f40af3..ce31f18b63c 100644
--- a/paddle/fluid/operators/jit/README.md
+++ b/paddle/fluid/operators/jit/README.md
@@ -46,7 +46,7 @@ PaddlePaddle/Paddle/paddle/fluid/
 - Add `your_key` to `KernelType`.
 - Implement the Reference logic; a Reference implementation is required for every jitkernel and must not depend on any third-party library. Also add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt`.
 - (optional) Implement more implementations under the `more` directory; these may depend on third-party libraries such as mkl, openblas, or mkldnn.
-- (optional) Implement Xbyak-based code generation under the `gen` directory.
+- (optional) Implement Xbyak-based code generation under the `gen` directory. The jitcode needs to implement its own `JitCodeCreator` and register it on the KernelType.
 - Add a new `KernelTuples` when necessary (see `XYZNTuples`); a newly added Attr type needs a specialization of the `JitCodeKey` method.
 - Add unit tests; both float and double need to be tested.
 - Add a benchmark to make sure the implementation returned by `Get` is the fastest.
diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt
index 2be750a4d86..81a6314bd25 100644
--- a/paddle/fluid/operators/jit/gen/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -20,3 +20,5 @@ USE_JITKERNEL_GEN(videntity)
 USE_JITKERNEL_GEN(vexp)
 USE_JITKERNEL_GEN(vsigmoid)
 USE_JITKERNEL_GEN(vtanh)
+USE_JITKERNEL_GEN(lstmctht)
+USE_JITKERNEL_GEN(lstmc1h1)
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 64126e3f61a..898d7df3451 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -62,7 +62,7 @@ typedef enum {
 class JitCode : public GenBase, public Xbyak::CodeGenerator {
  public:
   explicit JitCode(size_t code_size, void* code_ptr = nullptr)
-      : Xbyak::CodeGenerator(code_size, code_ptr) {}
+      : Xbyak::CodeGenerator((code_size < 4096 ? 4096 : code_size), code_ptr) {}
 
   virtual const char* name() const = 0;
   virtual void genCode() = 0;
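The two creators registered in lstm.cc below are what the README's `Get` step ends up dispatching to: jitcode is used when `UseMe()` accepts the attribute, otherwise a `more` or reference kernel is returned. As a rough caller-side sketch (not part of the patch; the `kernels.h` header path and the `LSTMTuples`/`CPUPlace` spellings are assumptions borrowed from test.cc, while `lstm_attr_t`, `lstm_t`, and the `lstmctht` key come from the code added below):

    // Illustrative sketch only. lstm_attr_t, lstm_t and the lstmctht key come
    // from this patch; the header path and the LSTMTuples/CPUPlace names are
    // assumed from test.cc.
    #include "paddle/fluid/operators/jit/kernels.h"

    namespace jit = paddle::operators::jit;

    void RunLstmStep(float* gates, float* ct_1, float* ct, float* ht, int d) {
      const jit::lstm_attr_t attr(d, jit::KernelType::vsigmoid,
                                  jit::KernelType::vtanh, jit::KernelType::vtanh,
                                  /*use_peephole=*/false);
      // Returns the generated kernel when UseMe() accepts attr (AVX and
      // d % 8 == 0), otherwise falls back to a more/refer implementation.
      auto lstm = jit::Get<jit::KernelType::lstmctht, jit::LSTMTuples<float>,
                           paddle::platform::CPUPlace>(attr);
      jit::lstm_t step;
      step.gates = gates;
      step.ct_1 = ct_1;
      step.ct = ct;
      step.ht = ht;
      lstm(&step, &attr);
    }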
diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc
new file mode 100644
index 00000000000..7e5a7773f83
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/lstm.cc
@@ -0,0 +1,142 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/gen/lstm.h"
+#include <stddef.h>  // offsetof
+#include "paddle/fluid/operators/jit/registry.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+void LSTMJitCode::genCode() {
+  if (use_peephole_) {
+    preCode();
+  }
+  reg64_t reg_ptr_gates = rax;
+  reg64_t reg_ptr_ct_1 = r9;
+  reg64_t reg_ptr_ct = r10;
+  reg64_t reg_ptr_ht = r11;
+  reg64_t reg_ptr_wp = r12;
+  mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
+  mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
+  mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
+  mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
+  if (use_peephole_) {
+    mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
+  }
+
+  int offset = 0;
+  int d = num_ * sizeof(float);
+  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
+    /* gates: W_ch, W_ih, W_fh, W_oh */
+    ymm_t ymm_c = ymm_t(0);
+    ymm_t ymm_i = ymm_t(1);
+    ymm_t ymm_f = ymm_t(2);
+    ymm_t ymm_o = ymm_t(3);
+    ymm_t ymm_ct_1 = ymm_t(4);
+    ymm_t ymm_wp0 = ymm_t(5);
+    ymm_t ymm_wp1 = ymm_t(6);
+    ymm_t ymm_wp2 = ymm_t(7);
+    vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
+    vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
+    vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
+    vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
+    if (!compute_c1h1_) {
+      vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
+    }
+    if (use_peephole_) {
+      vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
+      vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
+      vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
+    }
+    /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
+    // act_cand(c)
+    act<ymm_t>(ymm_c, ymm_c, act_cand_);
+    // act_gate(i) or act_gate(ct_1 * wp0 + i)
+    if (!compute_c1h1_ && use_peephole_) {
+      vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
+      vaddps(ymm_i, ymm_i, ymm_wp0);
+    }
+    act<ymm_t>(ymm_i, ymm_i, act_gate_);
+    vmulps(ymm_c, ymm_c, ymm_i);
+    if (!compute_c1h1_) {
+      // act_gate(f) or act_gate(ct_1 * wp1 + f)
+      if (use_peephole_) {
+        vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
+        vaddps(ymm_f, ymm_f, ymm_wp1);
+      }
+      act<ymm_t>(ymm_f, ymm_f, act_gate_);
+      // ct
+      vmulps(ymm_f, ymm_f, ymm_ct_1);
+      vaddps(ymm_f, ymm_f, ymm_c);
+    }
+    /* H_t = act_cell(C_t) * act_gate(o) */
+    // act_cell(C_t)
+    ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
+    ymm_t ymm_tmp = ymm_i;
+    act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
+    // act_gate(o) or act_gate(ct * wp2 + o)
+    if (use_peephole_) {
+      vmulps(ymm_wp2, ymm_ct, ymm_wp2);
+      vaddps(ymm_o, ymm_o, ymm_wp2);
+    }
+    act<ymm_t>(ymm_o, ymm_o, act_gate_);
+    // ht
+    vmulps(ymm_o, ymm_o, ymm_tmp);
+    // save ct and ht
+    vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
+    vmovups(ptr[reg_ptr_ht + offset], ymm_o);
+    offset += sizeof(float) * YMM_FLOAT_BLOCK;
+  }
+
+  if (use_peephole_) {
+    postCode();
+  } else {
+    ret();
+  }
+}
+
+#define DECLARE_LSTM_CREATOR(name)                                 \
+  class name##Creator : public JitCodeCreator<lstm_attr_t> {       \
+   public:                                                         \
+    /* TODO(TJ): enable more */                                    \
+    bool UseMe(const lstm_attr_t& attr) const override {           \
+      return platform::MayIUse(platform::avx) && attr.d % 8 == 0;  \
+    }                                                              \
+    size_t CodeSize(const lstm_attr_t& attr) const override {      \
+      return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;           \
+    }                                                              \
+    std::unique_ptr<GenBase> CreateJitCode(                        \
+        const lstm_attr_t& attr) const override {                  \
+      return make_unique<name##JitCode>(attr, CodeSize(attr));     \
+    }                                                              \
+  }
+
+DECLARE_LSTM_CREATOR(LSTMCtHt);
+DECLARE_LSTM_CREATOR(LSTMC1H1);
+
+#undef DECLARE_LSTM_CREATOR
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
+
+namespace gen = paddle::operators::jit::gen;
+
+REGISTER_JITKERNEL_GEN(lstmctht, gen::LSTMCtHtCreator);
+REGISTER_JITKERNEL_GEN(lstmc1h1, gen::LSTMC1H1Creator);
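For readers who prefer not to trace the AVX instructions, each iteration of the loop in genCode() above evaluates one 8-float block of the standard LSTM cell, C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) and H_t = act_cell(C_t) * act_gate(o). A scalar sketch of the same arithmetic, using sigmoid gates and tanh candidate/cell activations (illustrative only, this is not the refer kernel):

    #include <cmath>

    // Scalar equivalent of one genCode() iteration (element k of a block).
    // The gates buffer is laid out as [c, i, f, o], each of width d, and wp as
    // [wp_i, wp_f, wp_o]; the LSTMC1H1 variant simply drops every ct_1 term.
    static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    void LstmScalarStep(const float* gates, const float* ct_1, float* ct,
                        float* ht, int d, bool use_peephole, const float* wp) {
      const float* c = gates;
      const float* i = gates + d;
      const float* f = gates + 2 * d;
      const float* o = gates + 3 * d;
      for (int k = 0; k < d; ++k) {
        float cand = std::tanh(c[k]);  // act_cand(c)
        float gi = sigmoid(use_peephole ? i[k] + wp[k] * ct_1[k] : i[k]);
        float gf = sigmoid(use_peephole ? f[k] + wp[d + k] * ct_1[k] : f[k]);
        // C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f)
        ct[k] = cand * gi + ct_1[k] * gf;
        // the output-gate peephole uses the freshly computed ct, as in genCode()
        float go = sigmoid(use_peephole ? o[k] + wp[2 * d + k] * ct[k] : o[k]);
        // H_t = act_cell(C_t) * act_gate(o)
        ht[k] = std::tanh(ct[k]) * go;
      }
    }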
diff --git a/paddle/fluid/operators/jit/gen/lstm.h b/paddle/fluid/operators/jit/gen/lstm.h
new file mode 100644
index 00000000000..cb8705c6d95
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/lstm.h
@@ -0,0 +1,119 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/act.h"
+#include "paddle/fluid/operators/jit/gen/jitcode.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+class LSTMJitCode : public VActJitCode {
+ public:
+  explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
+                       size_t code_size, void* code_ptr = nullptr)
+      : VActJitCode(attr.d, operand_type::sigmoid /* this is buggy */,
+                    code_size, code_ptr),
+        compute_c1h1_(compute_c1h1) {
+    auto typeExchange = [](KernelType type) -> gen::operand_type {
+      if (type == KernelType::vsigmoid) {
+        return operand_type::sigmoid;
+      } else if (type == KernelType::vrelu) {
+        return operand_type::relu;
+      } else if (type == KernelType::vtanh) {
+        return operand_type::tanh;
+      } else if (type == KernelType::videntity) {
+        return operand_type::identity;
+      } else {
+        LOG(FATAL) << "Do not support this jit::KernelType: " << type;
+      }
+      return operand_type::identity;
+    };
+    num_ = attr.d;
+    use_peephole_ = attr.use_peephole;
+    act_gate_ = typeExchange(attr.act_gate);
+    act_cand_ = typeExchange(attr.act_cand);
+    act_cell_ = typeExchange(attr.act_cell);
+
+    this->genCode();
+  }
+
+  const char* name() const override {
+    std::string base = "LSTMJitCode";
+    if (use_peephole_) {
+      base += "_Peephole";
+    }
+    if (compute_c1h1_) {
+      base += "_C1H1";
+    }
+    auto AddTypeStr = [&](operand_type type) {
+      switch (type) {
+        case operand_type::relu:
+          base += "_Relu";
+          break;
+        case operand_type::exp:
+          base += "_Exp";
+          break;
+        case operand_type::sigmoid:
+          base += "_Sigmoid";
+          break;
+        case operand_type::tanh:
+          base += "_Tanh";
+          break;
+        case operand_type::identity:
+          base += "_Identity";
+          break;
+        default:
+          break;
+      }
+    };
+    AddTypeStr(act_gate_);
+    AddTypeStr(act_cand_);
+    AddTypeStr(act_cell_);
+    return base.c_str();
+  }
+  void genCode() override;
+
+ protected:
+  int num_;
+  bool compute_c1h1_;
+  bool use_peephole_;
+  operand_type act_gate_;
+  operand_type act_cand_;
+  operand_type act_cell_;
+  reg64_t param1{abi_param1};
+};
+
+#define DECLARE_LSTM_JITCODE(name, compute_c1h1)                       \
+  class name##JitCode : public LSTMJitCode {                           \
+   public:                                                             \
+    explicit name##JitCode(const lstm_attr_t& attr, size_t code_size,  \
+                           void* code_ptr = nullptr)                   \
+        : LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {}      \
+  };
+
+DECLARE_LSTM_JITCODE(LSTMCtHt, false);
+DECLARE_LSTM_JITCODE(LSTMC1H1, true);
+
+#undef DECLARE_LSTM_JITCODE
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
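For reference, `DECLARE_LSTM_JITCODE(LSTMCtHt, false)` above expands to roughly the wrapper below (whitespace adjusted). `LSTMC1H1JitCode` is identical except that it passes `true` for `compute_c1h1`, which makes `genCode()` drop every `ct_1` term, matching the first time step where no previous cell state exists.

    class LSTMCtHtJitCode : public LSTMJitCode {
     public:
      explicit LSTMCtHtJitCode(const lstm_attr_t& attr, size_t code_size,
                               void* code_ptr = nullptr)
          : LSTMJitCode(false, attr, code_size, code_ptr) {}
    };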
"peephole_" : "") + "size_" + - std::to_string(d); const jit::lstm_attr_t attr( d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand), jit::to_kerneltype(act_cell), use_peephole); @@ -370,7 +367,7 @@ void TestLSTMKernel() { step.checked = checked_data; } ref(&step, &attr); - + VLOG(10) << attr; TestAllImpls, PlaceType, std::vector, std::vector, std::vector, std::vector, std::vector>(attr, xsrc, wp, ct_1, ct_ref, ht_ref, @@ -390,7 +387,6 @@ void TestGRUKernel() { for (int d : TestSizes()) { for (auto& act_gate : all_acts) { for (auto& act_cand : all_acts) { - std::string info = act_gate + act_cand + "size_" + std::to_string(d); const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand)); auto ref = jit::GetRefer>(); @@ -409,7 +405,7 @@ void TestGRUKernel() { step.ht_1 = ht_1_data; step.ht = ht_ref_data; ref(&step, &attr); - + VLOG(10) << attr; TestAllImpls, PlaceType, std::vector, std::vector, std::vector>(attr, xsrc, ht_1, ht_ref, attr); -- GitLab