提交 2139b9f6 编写于 作者: T tensor-tang

add jit gencode

上级 cb27a921
......@@ -76,6 +76,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas)
SRCS jit_kernel.cc jit_gen.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas gflags)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
void JitCode::preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::jit::MayIUse(platform::jit::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
void JitCode::postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void JitCode::dumpCode(const Xbyak::uint8 *code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char *>(code), getSize());
fout.close();
}
}
}
Xbyak::Address JitCode::EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast) {
int scale = 0;
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
} // namespace gen
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
class JitCode : public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size = 256 * 1024, void *code_ptr = nullptr)
: Xbyak::CodeGenerator(code_size, code_ptr) {}
virtual ~JitCode() {}
virtual const char *name() const = 0;
virtual void generate() = 0;
template <typename FUNC>
const FUNC getCode() {
this->generate();
const Xbyak::uint8 *code = CodeGenerator::getCode();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
return reinterpret_cast<const FUNC>(code);
}
DISABLE_COPY_AND_ASSIGN(JitCode);
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
void preCode();
void postCode();
void dumpCode(const Xbyak::uint8 *code) const;
void L(const char *label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -40,6 +40,7 @@ class Kernel {
Kernel() = default;
virtual ~Kernel() = default;
int num_{0};
// TODO(TJ): below two should be reomved.
int end_{0};
int rest_{0};
DISABLE_COPY_AND_ASSIGN(Kernel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册