diff --git a/paddle/fluid/operators/jitkernels/CMakeLists.txt b/paddle/fluid/operators/jitkernels/CMakeLists.txt index 6392d82e16d3e8eb231543678f2fb3579f7636ec..e82e6c3026f936bb96178710751002d297bf3b41 100644 --- a/paddle/fluid/operators/jitkernels/CMakeLists.txt +++ b/paddle/fluid/operators/jitkernels/CMakeLists.txt @@ -7,7 +7,7 @@ set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) -cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS}) +cc_library(jit_kernel_base SRCS kernels.cc jitcode_base.cc DEPS ${JIT_KERNEL_DEPS}) add_subdirectory(refer) add_subdirectory(more) diff --git a/paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt b/paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt index 1a5e457309e9131dd4fe45e6e8ba2cb94241a39a..c678ea33b8e0043dade1c83523032e8e6e9e59d0 100644 --- a/paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt +++ b/paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt @@ -1,3 +1,5 @@ -cc_library(jit_kernel_jitcode SRCS jitcode.cc DEPS jit_kernel_base xbyak) +file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") + +cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE) diff --git a/paddle/fluid/operators/jitkernels/jitcode/blas.cc b/paddle/fluid/operators/jitkernels/jitcode/blas.cc new file mode 100644 index 0000000000000000000000000000000000000000..2691bee0fdf1669535540dfc5f217c55fa4aca60 --- /dev/null +++ b/paddle/fluid/operators/jitkernels/jitcode/blas.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ +#include "paddle/fluid/operators/jitkernels/jitcode/blas.h" +#include "paddle/fluid/operators/jitkernels/registry.h" + +namespace paddle { +namespace operators { +namespace jitkernels { +namespace jitcode { + +void VXXJitCode::genCode() { + // do not need push stack, and do not need save avx512reg if do not use avx512 + int offset = 0; + if (with_relu_) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } + if (scalar_index_ == 1) { + vbroadcastss(ymm_src1, ptr[param1]); + } else if (scalar_index_ == 2) { + vbroadcastss(ymm_src2, ptr[param2]); + } + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + if (scalar_index_ != 1) { + vmovups(ymm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(ymm_src2, ptr[param2 + offset]); + } + if (type_ == operand_type::mul) { + vmulps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::add) { + vaddps(ymm_dst, ymm_src1, ymm_src2); + } + if (with_relu_) { + vmaxps(ymm_dst, ymm_zero, ymm_dst); + } + vmovups(ptr[param3 + offset], ymm_dst); + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + int rest = num_ % YMM_FLOAT_BLOCK; + while (rest > 0) { + int block = XMM_FLOAT_BLOCK; + if (rest >= 4) { + block = 4; + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } + } else if (rest >= 2) { + block = 2; + if (scalar_index_ != 1) { + vmovq(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + } + } else { + block = 1; + if (scalar_index_ != 1) { + vmovss(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovss(xmm_src2, ptr[param2 + offset]); + } + } + switch (type_) { + case operand_type::mul: + vmulps(xmm_dst, xmm_src1, xmm_src2); + break; + case operand_type::add: + vaddps(xmm_dst, xmm_src1, xmm_src2); + break; + default: + break; + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } + if (rest >= 4) { + vmovups(ptr[param3 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param3 + offset], xmm_dst); + } else { + vmovss(ptr[param3 + offset], xmm_dst); + } + offset += sizeof(float) * block; + rest -= block; + } + ret(); +} + +} // namespace jitcode + +template <> +std::unique_ptr CreateJitCode(int attr) { + if (UseJitCode(attr)) { + return make_unique( + attr, CodeSize(attr)); + } + return nullptr; +} + +} // namespace jitkernels +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jitkernels/jitcode/blas.h b/paddle/fluid/operators/jitkernels/jitcode/blas.h new file mode 100644 index 0000000000000000000000000000000000000000..a1aca97723e75eed3b3fefe9bf8471d4326bc812 --- /dev/null +++ b/paddle/fluid/operators/jitkernels/jitcode/blas.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h" + +namespace paddle { +namespace operators { +namespace jitkernels { +namespace jitcode { + +// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) +class VXXJitCode : public JitCode { + public: + const char* name() const override { + std::string base = "VXXJitCode"; + if (scalar_index_ == 1) { + base += "_Scalar"; + } else { + base += "_Vec"; + } + if (type_ == operand_type::mul) { + base += "_Mul"; + } else if (type_ == operand_type::add) { + base += "_Add"; + } + if (scalar_index_ == 2) { + base += "_Scalar"; + } else { + base += "_Vec"; + } + base += (with_relu_ ? "_Relu" : ""); + return base.c_str(); + } + explicit VXXJitCode(int d, operand_type type, int scalar_index, + bool with_relu, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), + num_(d), + type_(type), + scalar_index_(scalar_index), + with_relu_(with_relu) {} + // static bool init(int d, int scalar_index = 0); + void genCode() override; + + private: + int num_; + operand_type type_; + int scalar_index_; + bool with_relu_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + + xmm_t xmm_src1 = xmm_t(0); + xmm_t xmm_src2 = xmm_t(1); + xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_zero = xmm_t(3); + + ymm_t ymm_src1 = ymm_t(0); + ymm_t ymm_src2 = ymm_t(1); + ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_zero = ymm_t(3); +}; + +class VMulJitCode : public VXXJitCode { + public: + explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr) + : VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {} +}; + +} // namespace jitcode +} // namespace jitkernels +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jitkernels/jitcode/jitcode.h b/paddle/fluid/operators/jitkernels/jitcode/jitcode.h index 7e0b6442edd239b9d3165f9ade4ee728c378a49c..a3582e5284c84ab97cdde3caa1bdc07b0ddc4ac9 100644 --- a/paddle/fluid/operators/jitkernels/jitcode/jitcode.h +++ b/paddle/fluid/operators/jitkernels/jitcode/jitcode.h @@ -16,7 +16,7 @@ #include #include "paddle/fluid/operators/jitkernels/jitcode_base.h" -#include "paddle/fluid/operators/jitkernels/kernels.h" +#include "paddle/fluid/platform/cpu_info.h" #define XBYAK_USE_MMAP_ALLOCATOR #include "xbyak/xbyak.h" @@ -30,23 +30,102 @@ namespace jitcode { // Application Binary Interface constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), - abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX); + abi_param4(Xbyak::Operand::RCX); -template -class VMulJitCode : public JitBase, public Xbyak::CodeGenerator { +constexpr Xbyak::Operand::Code g_abi_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15}; + +constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]); + +using reg64_t = const Xbyak::Reg64; +using reg32_t = const Xbyak::Reg32; +using xmm_t = const Xbyak::Xmm; +using ymm_t = const Xbyak::Ymm; +using zmm_t = const Xbyak::Zmm; +using Label = Xbyak::Label; + +typedef enum { + mul = 0, + add, + sub, + relu, + exp, + sigmoid, + tanh, + identity +} operand_type; + +#define XMM_FLOAT_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define ZMM_FLOAT_BLOCK 16 + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +#define DECLARE_JIT_CODE(codename) \ + const char* name() const override { return #codename; } + +class JitCode : public JitBase, public Xbyak::CodeGenerator { public: - VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr) + explicit JitCode(size_t code_size, void* code_ptr = nullptr) : Xbyak::CodeGenerator(code_size, code_ptr) { this->genCode(); } - virtual const char* name() const = 0; - virtual void genCode() = 0; - + size_t getSize() const override { return CodeGenerator::getSize(); } const unsigned char* getCodeInternal() override { const Xbyak::uint8* code = CodeGenerator::getCode(); return code; } + + virtual const char* name() const = 0; + virtual void genCode() = 0; + + protected: + Xbyak::Reg64 param1{abi_param1}; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + virtual void preCode() { + for (int i = 0; i < num_g_abi_regs; ++i) { + push(Xbyak::Reg64(g_abi_regs[i])); + } + if (platform::jit::MayIUse(platform::jit::avx512f)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + virtual void postCode() { + for (int i = 0; i < num_g_abi_regs; ++i) { + pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i])); + } + ret(); + } + void L(const char* label) { Xbyak::CodeGenerator::L(label); } + void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + // Enhanced vector extension + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt, + bool bcast = false) { + int scale = 0; + // Learn from https://github.com/intel/mkl-dnn + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + auto re = Xbyak::RegExp() + base + offt; + if (scale) { + re = re + reg_EVEX_max_8b_offt * scale; + } + if (bcast) { + return zword_b[re]; + } else { + return zword[re]; + } + } }; } // namespace jitcode diff --git a/paddle/fluid/operators/jitkernels/jitcode_base.cc b/paddle/fluid/operators/jitkernels/jitcode_base.cc index 417c4d4b9e25edb8c38eaedbf7531e8455f5dbf8..1da2af51f410b3d296324bac6bd9b00f9b31bbbc 100644 --- a/paddle/fluid/operators/jitkernels/jitcode_base.cc +++ b/paddle/fluid/operators/jitkernels/jitcode_base.cc @@ -13,6 +13,9 @@ * limitations under the License. */ #include "paddle/fluid/operators/jitkernels/jitcode_base.h" +#include +#include +#include DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); @@ -29,7 +32,7 @@ void JitBase::dumpCode(const unsigned char* code) const { counter++; std::ofstream fout(filename.str(), std::ios::out); if (fout.is_open()) { - fout.write(reinterpret_cast(code), getSize()); + fout.write(reinterpret_cast(code), this->getSize()); fout.close(); } } diff --git a/paddle/fluid/operators/jitkernels/jitcode_base.h b/paddle/fluid/operators/jitkernels/jitcode_base.h index a164746561ee97beffce7b2d18014d6dc4952316..ffec62163a70ebb9a1e43d5ab736f4f509719fc9 100644 --- a/paddle/fluid/operators/jitkernels/jitcode_base.h +++ b/paddle/fluid/operators/jitkernels/jitcode_base.h @@ -28,7 +28,7 @@ namespace jitkernels { // TODO(TJ): make these functions as virtual of a class // Every JitCode should estimate the code size itself -template +template size_t CodeSize(Attr attr) { return 4096; } @@ -43,13 +43,11 @@ bool UseJitCode(Attr attr) { template size_t GetKey(Attr attr); -class JitBase { +class JitBase : public Kernel { public: - JitBase() = default; - virtual ~JitBase() = default; virtual const char* name() const = 0; virtual const unsigned char* getCodeInternal() = 0; - + virtual size_t getSize() const = 0; template const FUNC getCode() { const unsigned char* code = this->getCodeInternal(); @@ -58,14 +56,17 @@ class JitBase { } return reinterpret_cast(code); } - DISABLE_COPY_AND_ASSIGN(JitBase); protected: - void dumpCode(const unsigned char* code); + void dumpCode(const unsigned char* code) const; }; -template -std::shared_ptr CreateJitCode(Attr attr); +template +std::unique_ptr CreateJitCode(Attr attr); //{ +// if (UseJitCode) { +// return make_unique(attr, CodeSize()); +// } +// } } // namespace jitkernels } // namespace operators diff --git a/paddle/fluid/operators/jitkernels/kernels.h b/paddle/fluid/operators/jitkernels/kernels.h index 866f72cce04962cef36b318886019eada9de6053..f398093dfe2b1bf66b5e1d547ba824de96ef5ea5 100644 --- a/paddle/fluid/operators/jitkernels/kernels.h +++ b/paddle/fluid/operators/jitkernels/kernels.h @@ -31,6 +31,9 @@ namespace jitkernels { template class JitCodePool { + typedef std::unique_ptr JitBasePtr; + typedef std::unordered_map JitBaseMap; + public: JitCodePool() = default; static JitCodePool& Instance() { @@ -38,29 +41,26 @@ class JitCodePool { return g_jit_codes; } - std::shared_ptr Get(size_t key) const { - if (codes_.find(key) == codes_.end()) { - return nullptr; - } - return codes_.at(key); - } + const JitBaseMap& AllKernels() { return codes_; } + + bool Has(size_t key) const { return codes_.find(key) != codes_.end(); } - void Insert(size_t key, const std::shared_ptr& value) { - codes_.insert({key, value}); + void Insert(size_t key, JitBasePtr value) { + codes_.emplace(key, std::move(value)); } private: - std::unordered_map> codes_; + JitBaseMap codes_; DISABLE_COPY_AND_ASSIGN(JitCodePool); }; // TODO(TJ): std::tuple -template -struct KernelAttr { - typedef T data_type; - typedef Func return_type; - typedef Attr attr_type; -}; +// template +// struct KernelAttr { +// typedef T data_type; +// typedef Func return_type; +// typedef Attr attr_type; +// }; typedef std::unique_ptr KernelPtr; typedef std::unordered_map, KernelKey::Hash> @@ -123,20 +123,21 @@ inline Func GetRefer() { // TODO(TJ): make tuple? named KernelAttr template -Func Get(Attr attr) { - // size_t key = GetKey(attr); - // auto jitcode = JitCodePool().Instance().Get(key); - // if (jitcode) { - // return jitcode->template getCode(); - // } - - if (std::is_same::value && - std::is_same::value) { // TODO(TJ): float move to create - // auto p = CreateJitCode(attr); - // if (p) { - // JitCodePool().Instance().Insert(key, p); - // return p->template getCode(); - // } +const Func Get(Attr attr) { + size_t key = GetKey(attr); + auto& codes = JitCodePool().Instance(); + if (codes.Has(key)) { + return codes.AllKernels().at(key)->template getCode(); + } + + if (std::is_same::value) { // TODO(TJ): float + // move to create + auto p = CreateJitCode(attr); + if (p) { + auto f = p->template getCode(); + codes.Insert(key, std::move(p)); + return f; + } } // pool: (KernelKey(type, place), vector) diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index 663a9fbf4e14fefbfa773eaecacc94bd1be60777..fd31ef77b46d5b5b641983a0421da31914c87c18 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -39,7 +39,7 @@ size_t CUDAPinnedMinChunkSize(); //! Get the maximum chunk size for buddy allocator. size_t CUDAPinnedMaxChunkSize(); -namespace jit { // remove this namespace +namespace jit { typedef enum { isa_any, sse42,