diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 81a6314bd25644b427374b691004e4ea30964dac..8ad9587b5ef171d51fb06db58eff5aec6044a96b 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -22,3 +22,6 @@ USE_JITKERNEL_GEN(vsigmoid) USE_JITKERNEL_GEN(vtanh) USE_JITKERNEL_GEN(lstmctht) USE_JITKERNEL_GEN(lstmc1h1) +USE_JITKERNEL_GEN(gruh1) +USE_JITKERNEL_GEN(gruhtpart1) +USE_JITKERNEL_GEN(gruhtpart2) diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec89880a0c3b3f65146cdaeae569d9539b25399a --- /dev/null +++ b/paddle/fluid/operators/jit/gen/gru.cc @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/gru.h" +#include // offsetof +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void GRUJitCode::genCode() { + reg64_t reg_ptr_gates = rax; + reg64_t reg_ptr_ht_1 = r9; + reg64_t reg_ptr_ht = r10; + mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); + mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); + mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); + ymm_t ymm_one = ymm_t(0); + + if (id_ == 2) { + reg64_t reg_ptr_tmp = r11; + mov(reg_ptr_tmp, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); + } + int offset = 0; + int d = num_ * sizeof(float); + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + ymm_t ymm_u = ymm_t(1); + ymm_t ymm_r = ymm_t(2); + ymm_t ymm_s = ymm_t(3); + ymm_t ymm_ht_1 = ymm_t(4); + // W: {W_update, W_reset; W_state} + if (id_ == 0 || id_ == 2) { + vmovups(ymm_u, ptr[reg_ptr_gates + offset]); + vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); + } + if (id_ == 1) { + vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); + } + if (id_ == 1 || id_ == 2) { + vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); + } + + if (id_ == 0) { + // ht = act_gate(u) * act_cand(s) + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + vmovups(ptr[reg_ptr_ht + offset], ymm_s); + } else if (id_ == 1) { + // ht = act_gate(r) * ht_1 + act(ymm_r, ymm_r, act_gate_); + vmulps(ymm_r, ymm_r, ymm_ht_1); + vmovups(ptr[reg_ptr_ht + offset], ymm_r); + } else if (id_ == 2) { + // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 + ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + vsubps(ymm_u, ymm_one_inner, ymm_u); + vmulps(ymm_u, ymm_ht_1, ymm_u); + vaddps(ymm_u, ymm_s, ymm_u); + vmovups(ptr[reg_ptr_ht + offset], ymm_u); + } + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + ret(); +} + +#define DECLARE_GRU_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + /* TODO(TJ): enable more */ \ + bool UseMe(const gru_attr_t& attr) const override { \ + return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ + } \ + size_t CodeSize(const gru_attr_t& attr) const override { \ + return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \ + } \ + std::unique_ptr CreateJitCode( \ + const gru_attr_t& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ + } + +DECLARE_GRU_CREATOR(GRUH1); +DECLARE_GRU_CREATOR(GRUHtPart1); +DECLARE_GRU_CREATOR(GRUHtPart2); + +#undef DECLARE_GRU_CREATOR + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(gruh1, gen::GRUH1Creator); +REGISTER_JITKERNEL_GEN(gruhtpart1, gen::GRUHtPart1Creator); +REGISTER_JITKERNEL_GEN(gruhtpart2, gen::GRUHtPart2Creator); diff --git a/paddle/fluid/operators/jit/gen/gru.h b/paddle/fluid/operators/jit/gen/gru.h new file mode 100644 index 0000000000000000000000000000000000000000..bab1c6a4eee5900994a19c73c76f457bcf5ba7c9 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/gru.h @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/act.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class GRUJitCode : public VActJitCode { + public: + explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size, + void* code_ptr = nullptr) + : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, + code_ptr), + id_(id) { + auto typeExchange = [](KernelType type) -> gen::operand_type { + if (type == KernelType::vsigmoid) { + return operand_type::sigmoid; + } else if (type == KernelType::vrelu) { + return operand_type::relu; + } else if (type == KernelType::vtanh) { + return operand_type::tanh; + } else if (type == KernelType::videntity) { + return operand_type::identity; + } else { + LOG(FATAL) << "Do not support this jit::KernelType: " << type; + } + return operand_type::identity; + }; + num_ = attr.d; + act_gate_ = typeExchange(attr.act_gate); + act_cand_ = typeExchange(attr.act_cand); + + this->genCode(); + } + + const char* name() const override { + std::string base = "GRUJitCode"; + if (id_ == 0) { + base += "_H1"; + } else if (id_ == 1) { + base += "_HtPart1"; + } else if (id_ == 2) { + base += "_HtPart2"; + } + auto AddTypeStr = [&](operand_type type) { + switch (type) { + case operand_type::relu: + base += "_Relu"; + break; + case operand_type::exp: + base += "_Exp"; + break; + case operand_type::sigmoid: + base += "_Sigmoid"; + break; + case operand_type::tanh: + base += "_Tanh"; + break; + case operand_type::identity: + base += "_Identity"; + break; + default: + break; + } + }; + AddTypeStr(act_gate_); + AddTypeStr(act_cand_); + return base.c_str(); + } + void genCode() override; + + protected: + int id_; + int num_; + operand_type act_gate_; + operand_type act_cand_; + reg64_t param1{abi_param1}; +}; + +#define DECLARE_GRU_JITCODE(name, id) \ + class name##JitCode : public GRUJitCode { \ + public: \ + explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \ + void* code_ptr = nullptr) \ + : GRUJitCode(id, attr, code_size, code_ptr) {} \ + }; + +DECLARE_GRU_JITCODE(GRUH1, 0); +DECLARE_GRU_JITCODE(GRUHtPart1, 1); +DECLARE_GRU_JITCODE(GRUHtPart2, 2); + +#undef DECLARE_GRU_JITCODE + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle