提交 3713d08d 编写于 作者: T tensor-tang

enable jitcode gru

上级 7c1f3ad6
...@@ -22,3 +22,6 @@ USE_JITKERNEL_GEN(vsigmoid) ...@@ -22,3 +22,6 @@ USE_JITKERNEL_GEN(vsigmoid)
USE_JITKERNEL_GEN(vtanh) USE_JITKERNEL_GEN(vtanh)
USE_JITKERNEL_GEN(lstmctht) USE_JITKERNEL_GEN(lstmctht)
USE_JITKERNEL_GEN(lstmc1h1) USE_JITKERNEL_GEN(lstmc1h1)
USE_JITKERNEL_GEN(gruh1)
USE_JITKERNEL_GEN(gruhtpart1)
USE_JITKERNEL_GEN(gruhtpart2)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void GRUJitCode::genCode() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);
#undef DECLARE_GRU_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(gruh1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(gruhtpart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(gruhtpart2, gen::GRUHtPart2Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActJitCode {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::vsigmoid) {
return operand_type::sigmoid;
} else if (type == KernelType::vrelu) {
return operand_type::relu;
} else if (type == KernelType::vtanh) {
return operand_type::tanh;
} else if (type == KernelType::videntity) {
return operand_type::identity;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
this->genCode();
}
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
void genCode() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);
#undef DECLARE_GRU_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册