diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 301d038546b08e778b779edb6770e7f5642a5a18..4a14eb941cd98e333a3e85aff064e6099b3be396 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(metrics) add_subdirectory(optimizers) add_subdirectory(reduce_ops) add_subdirectory(sequence_ops) +add_subdirectory(jit) if(WITH_DISTRIBUTE) add_subdirectory(distributed) @@ -64,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h index e9d2e84a434d7084c526a6e75363a65577197262..72774a878d98b431da05cf870139752421b2df8d 100644 --- a/paddle/fluid/operators/crf_decoding_op.h +++ b/paddle/fluid/operators/crf_decoding_op.h @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/jit_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel { Tensor track; int* track_value = track.mutable_data(emission_dims, platform::CPUPlace()); - const auto& ker = math::jitkernel::KernelPool::Instance() - .template Get>( - static_cast(tag_num)); - ker->Compute(static_cast(seq_len), x, w, alpha_value, track_value); + auto ker = jit::Get, + platform::CPUPlace>(tag_num); + ker(static_cast(seq_len), x, w, alpha_value, track_value, tag_num); T max_score = -std::numeric_limits::max(); int max_i = 0; for (size_t i = 0; i < tag_num; ++i) { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index c600d1e3d76f7a989dd61e72caf4967aa5923c6f..ec85fb80f4852cc6de1e8aeda86f0e98c9e1470a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/operators/math/jit_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "xbyak/xbyak.h" #include "xbyak/xbyak_util.h" @@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; - const auto& multiply = - math::jitkernel::KernelPool::Instance() - .template Get>(n); - + auto multiply = jit::Get, + platform::CPUPlace>(0); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { @@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto ptr_z = z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); + multiply(ptr_x, ptr_y, ptr_z, h, w); } } } diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 4ce67e16dd0c4b15db26bc6556ab4715436c091b..66acba49e5ac25c5097042225ccfe30b258040fa 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -15,9 +15,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fused/fusion_gru_op.h" #include // for memcpy #include +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { @@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel { const int total_T = x_dims[0]; \ const int D3 = wh_dims[1] -#define INIT_OTHER_DEFINES \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* bias = ctx.Input("Bias"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const math::jitkernel::gru_attr_t attr( \ - D, ctx.Attr("gate_activation"), \ - ctx.Attr("activation")); \ - math::jitkernel::gru_t one_step; \ - const auto& ker = \ - math::jitkernel::KernelPool::Instance() \ - .template Get, \ - const math::jitkernel::gru_attr_t&>(attr); \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - auto place = ctx.GetPlace(); \ +#define INIT_OTHER_DEFINES \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* bias = ctx.Input("Bias"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const jit::gru_attr_t attr( \ + D, jit::to_kerneltype(ctx.Attr("gate_activation")), \ + jit::to_kerneltype(ctx.Attr("activation"))); \ + jit::gru_t one_step; \ + auto ComputeH1 = \ + jit::Get, platform::CPUPlace>(attr); \ + auto ComputeHtPart1 = \ + jit::Get, platform::CPUPlace>(attr); \ + auto ComputeHtPart2 = \ + jit::Get, platform::CPUPlace>(attr); \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + 
auto place = ctx.GetPlace(); \ T* xx_data = xx->mutable_data(place) void SeqCompute(const framework::ExecutionContext& ctx) const { @@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel { } else { one_step.gates = xx_data; one_step.ht = hidden_out_data; - ker->ComputeH1(&one_step, &attr); + ComputeH1(&one_step, &attr); prev_hidden_data = hidden_out_data; tstart = 1; move_step(); @@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel { one_step.gates = xx_data; one_step.ht_1 = prev_hidden_data; one_step.ht = hidden_out_data; - ker->ComputeHtPart1(&one_step, &attr); + ComputeHtPart1(&one_step, &attr); // gemm rt * Ws blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast(1), hidden_out_data, D, wh_state_data, D, static_cast(1), xx_data + D2, D3); - ker->ComputeHtPart2(&one_step, &attr); + ComputeHtPart2(&one_step, &attr); // save prev prev_hidden_data = hidden_out_data; move_step(); @@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel { for (int i = 0; i < max_bs; ++i) { one_step.gates = cur_in_data; one_step.ht = cur_out_data; - ker->ComputeH1(&one_step, &attr); + ComputeH1(&one_step, &attr); // add offset cur_in_data += D3; cur_out_data += D; @@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel { one_step.gates = cur_batched_data; one_step.ht_1 = cur_prev_hidden_data; one_step.ht = cur_out_data; - ker->ComputeHtPart1(&one_step, &attr); + ComputeHtPart1(&one_step, &attr); cur_batched_data += D3; cur_prev_hidden_data += D; @@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel { one_step.gates = cur_batched_data; one_step.ht_1 = cur_prev_hidden_data; one_step.ht = cur_out_data; - ker->ComputeHtPart2(&one_step, &attr); + ComputeHtPart2(&one_step, &attr); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 
c4e752e3f0ce7e6d5e1f692fcb9a0290369b4243..b11b7c11bfe0ae4c79d5bb39844bce618649c44d 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -14,9 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { @@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel { const int D = wh_dims[0]; \ const int D4 = wh_dims[1] -#define INIT_OTHER_DEFINES \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wp_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } \ - const math::jitkernel::lstm_attr_t attr( \ - D, ctx.Attr("gate_activation"), \ - ctx.Attr("candidate_activation"), \ - ctx.Attr("cell_activation"), use_peepholes); \ - math::jitkernel::lstm_t one_step; \ - one_step.wp = wp_data; \ - one_step.checked = checked_cell_data; \ - const auto& ker = \ - math::jitkernel::KernelPool::Instance() \ - .template Get, \ - const math::jitkernel::lstm_attr_t&>(attr) +#define INIT_OTHER_DEFINES \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* 
checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ + } \ + const jit::lstm_attr_t attr( \ + D, jit::to_kerneltype(ctx.Attr("gate_activation")), \ + jit::to_kerneltype(ctx.Attr("candidate_activation")), \ + jit::to_kerneltype(ctx.Attr("cell_activation")), \ + use_peepholes); \ + jit::lstm_t one_step; \ + one_step.wp = wp_data; \ + one_step.checked = checked_cell_data; \ + auto ComputeC1H1 = \ + jit::Get, platform::CPUPlace>(attr); \ + auto ComputeCtHt = \ + jit::Get, platform::CPUPlace>(attr) // Wh GEMM #define GEMM_WH_ADDON(bs, prev, out) \ @@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel { one_step.gates = xx_data; one_step.ct = c_out_data; one_step.ht = h_out_data; - ker->ComputeC1H1(&one_step, &attr); + ComputeC1H1(&one_step, &attr); tstart = 1; // move one step prev_h_data = h_out_data; @@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel { one_step.ct_1 = prev_c_data; one_step.ct = c_out_data; one_step.ht = h_out_data; - ker->ComputeCtHt(&one_step, &attr); + ComputeCtHt(&one_step, &attr); // move one step prev_h_data = h_out_data; prev_c_data = c_out_data; @@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel { one_step.gates = cur_in_data; one_step.ct = cur_c_out_data; one_step.ht = cur_h_out_data; - ker->ComputeC1H1(&one_step, &attr); + ComputeC1H1(&one_step, &attr); cur_in_data += D4; cur_c_out_data += D; @@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel { one_step.ct_1 = cur_prev_c_data; one_step.ct = cur_c_out_data; one_step.ht = cur_h_out_data; - ker->ComputeCtHt(&one_step, &attr); + ComputeCtHt(&one_step, &attr); // move one batch cur_in_data += D4; diff --git a/paddle/fluid/operators/jit/CMakeLists.txt b/paddle/fluid/operators/jit/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ced29741253e72a17413de51fb2c24a7fb1257d3 --- /dev/null +++ b/paddle/fluid/operators/jit/CMakeLists.txt @@ 
-0,0 +1,25 @@ + +set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h) +file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n") +file(APPEND ${jit_file} "\#pragma once\n") +file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") +file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") + +set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) + +file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") +list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc) +cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS}) + +# refer must go first +add_subdirectory(refer) +add_subdirectory(more) +if(WITH_XBYAK) + add_subdirectory(gen) +endif() + +cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS}) +cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper) +if(NOT WIN32) + cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper) +endif() diff --git a/paddle/fluid/operators/jit/README.md b/paddle/fluid/operators/jit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..89180b5900d760ce1da5bf0de879301e052db63a --- /dev/null +++ b/paddle/fluid/operators/jit/README.md @@ -0,0 +1,66 @@ +# JIT Kernel + +结合函数模板和JIT生成需要的kernel函数。 +这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多种第三方库的实现,每种实现有自己的`UseMe`函数负责判断在什么条件下可以被调用。 +这里实现的函数可以是非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。 +目前仅支持CPU上的高性能计算。 + +## 目录结构 + +```txt +PaddlePaddle/Paddle/paddle/fluid/ +├── ... +├── operators/ +│ ├── .../ +└── jit/ + ├── ... + ├── gen/ + │ └── ... + ├── more/ + │ ├── ... + │ ├── mkl/ + │ │ └── ... + │ ├── mkldnn/ + │ │ └── ... + │ ├── mix/ + │ │ └── ... + │ ├── intrinsic/ + │ │ └── ... + │ └── openblas/ + │ └── ... + └── refer/ + └── ... 
+``` + +基本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现都是可选的。 +- gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就是性能。 +- refer: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,它主要关心的是算法逻辑的正确性。 +- more: 下面可以放入更多实现,可以包括mkl,mkldnn,intrinsic,openblas等,也可以是自身已有的kernel组合。 + +## 动态获取 + +提供一个`jit::Get`方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围和当前条件动态选择需要的kernel函数。 + +## 测试 + +- 逻辑测试 + 所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型 +- 性能测试 + 所有实现的性能对比,并且与最终的`jit::Get`方法对比,该方法拿到的性能需要在各种条件下都是最好的。 + +# 如何添加新的算子 + +- 在`KernelType` 中添加 `your_key` . +- 实现Reference 的逻辑,这必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CMakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel. +- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。 +- (optional) 实现基于Xbyak的生成code,在`gen`目录下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。 +- 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。 +- 在`test.cc`中添加unit test,至少需要测试`float`和`double`两种数据类型,如有必要需要支持额外的数据类型,比如`int8`的相关函数。 +- 在`benchmark.cc`中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保`jit::Get`得到的实现一直是速度最快的。 + +# 优点 +- 统一的Get方法,接口简单。 +- 同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。 +- 目录结构清晰,不会在某个文件中有多个宏定义,导致可读性差的问题。 +- 优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。 +- 可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对性的优化。框架层面可以使用统一接口,不必关心底层实现。 diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..437005825db7e0718b52ac830dd56ac87069ed39 --- /dev/null +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -0,0 +1,231 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/fluid/platform/device_tracer.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/port.h" + +DEFINE_int32(burning, 10, "Burning times."); +DEFINE_int32(repeat, 3000, "Repeat times."); +DEFINE_int32(max_size, 1000, "The Max size would be tested."); + +template +void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), + const T upper = static_cast(20.f), unsigned int seed = 100) { + std::mt19937 rng(seed); + std::uniform_real_distribution uniform_dist(0, 1); + for (int i = 0; i < n; ++i) { + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); + } +} + +std::vector TestSizes() { + std::vector s; + for (int i = 1; i <= FLAGS_max_size; ++i) { + s.push_back(i); + } + return s; +} + +template +struct BenchFunc { + // return this function avg time in us (assumes PosixInNsec yields ns, hence / 1e3) + double operator()(const typename KernelTuples::func_type tgt, Args... args) { + for (int i = 0; i < FLAGS_burning; ++i) { + tgt(args...); + } + auto start = paddle::platform::PosixInNsec() / 1e3; + for (int i = 0; i < FLAGS_repeat; ++i) { + tgt(args...); + } + auto end = paddle::platform::PosixInNsec() / 1e3; + return static_cast(end - start) / FLAGS_repeat; + } +}; + +namespace jit = paddle::operators::jit; + +template +void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... 
args) { + BenchFunc benchmark; + std::vector> infos; + // test refer + auto refer = jit::GetRefer(); + if (!refer) { + LOG(FATAL) << "Refer can not be empty!"; + } + infos.push_back(std::make_pair("Refer", benchmark(refer, args...))); + + // test jitcode + auto jitcode = jit::GetJitCode(attr); + if (jitcode) { + infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...))); + } + // test all impls in more + jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast*>(impl.get()); + if (i && i->UseMe(attr)) { + auto more = i->GetFunc(); + infos.push_back( + std::make_pair(i->ImplType(), benchmark(more, args...))); + } + } + } + // Test result from Get function + auto tgt = jit::Get(attr); + if (!tgt) { + LOG(FATAL) << "Target can not be empty!"; + } + infos.push_back(std::make_pair("Target", benchmark(tgt, args...))); + + // print + std::ostringstream loginfos; + loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; + for (auto pair : infos) { + loginfos << pair.first << " takes " << pair.second << " us; "; + } + LOG(INFO) << loginfos.str(); +} + +template +void BenchXYZNKernel() { + for (int d : TestSizes()) { + std::vector x(d), y(d), z(d); + RandomVec(d, x.data()); + RandomVec(d, y.data()); + BenchAllImpls, PlaceType>(d, x.data(), y.data(), + z.data(), d); + } +} + +template +void BenchAXYNKernel() { + for (int d : TestSizes()) { + const T a = static_cast(3); + std::vector x(d), y(d); + RandomVec(d, x.data()); + BenchAllImpls, PlaceType>(d, &a, x.data(), y.data(), + d); + } +} + +template +void BenchXYNKernel() { + for (int d : TestSizes()) { + std::vector x(d), y(d); + RandomVec(d, x.data()); + BenchAllImpls, PlaceType>(d, x.data(), y.data(), d); + } +} + +template +void BenchLSTMKernel() { + for (bool use_peephole : {true, false}) { + for (int d : 
TestSizes()) { + const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, + use_peephole); + std::vector x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d); + RandomVec(4 * d, x.data(), -2.f, 2.f); + RandomVec(3 * d, wp.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + T* x_data = x.data(); + T* checked_data = checked.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + BenchAllImpls, PlaceType>(attr, &step, &attr); + } + } +} + +template +void BenchGRUKernel() { + for (int d : TestSizes()) { + const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); + std::vector x(3 * d), ht_1(d), ht(d); + RandomVec(3 * d, x.data(), -2.f, 2.f); + RandomVec(d, ht_1.data(), -2.f, 2.f); + const T* ht_1_data = ht_1.data(); + T* x_data = x.data(); + T* ht_data = ht.data(); + jit::gru_t step; + step.gates = x_data; + step.ht_1 = ht_1_data; + step.ht = ht_data; + BenchAllImpls, PlaceType>(attr, &step, &attr); + } +} + +// Benchmark all jit kernels including jitcode, mkl and refer. +// To use this tool, run command: ./benchmark [options...] 
+// Options: +// --burning: the burning time before count +// --repeat: the repeat times +// --max_size: the max size would be tested +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat + << " times."; + using T = float; + using PlaceType = paddle::platform::CPUPlace; + // xyzn + BenchXYZNKernel(); + BenchXYZNKernel(); + BenchXYZNKernel(); + BenchXYZNKernel(); + + // axyn + BenchAXYNKernel(); + BenchAXYNKernel(); + + // xyn + BenchXYNKernel(); + BenchXYNKernel(); + BenchXYNKernel(); + BenchXYNKernel(); + BenchXYNKernel(); + + // lstm and peephole + BenchLSTMKernel(); + BenchLSTMKernel(); + + // gru functions + BenchGRUKernel(); + BenchGRUKernel(); + BenchGRUKernel(); +} diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a540108302f77e1ca3bfe1db0013d76a22d5eb4 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -0,0 +1,28 @@ + +file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") + +cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak) +set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE) + +function(USE_JITKERNEL_GEN TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n") +endfunction() + +# use gen jitcode kernel by name +USE_JITKERNEL_GEN(kVMul) +USE_JITKERNEL_GEN(kVAdd) +#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me +USE_JITKERNEL_GEN(kVAddRelu) +USE_JITKERNEL_GEN(kVScal) +USE_JITKERNEL_GEN(kVAddBias) +USE_JITKERNEL_GEN(kVRelu) +USE_JITKERNEL_GEN(kVIdentity) +USE_JITKERNEL_GEN(kVExp) +USE_JITKERNEL_GEN(kVSigmoid) +USE_JITKERNEL_GEN(kVTanh) +USE_JITKERNEL_GEN(kLSTMCtHt) +USE_JITKERNEL_GEN(kLSTMC1H1) +USE_JITKERNEL_GEN(kGRUH1) +USE_JITKERNEL_GEN(kGRUHtPart1) 
+USE_JITKERNEL_GEN(kGRUHtPart2) +USE_JITKERNEL_GEN(kNCHW16CMulNC) diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ea076f217dc7c8a755055d3f48c22b7a3627012 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -0,0 +1,135 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/act.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = { + REPEAT_8TIMES(1.f), + REPEAT_8TIMES(2.f), + REPEAT_8TIMES(0.5f), + REPEAT_8TIMES(EXP_HIG), + REPEAT_8TIMES(EXP_LOW), + REPEAT_8TIMES(CEPHES_LOG2EF), + REPEAT_8TIMES(CEPHES_EXP_C1), + REPEAT_8TIMES(CEPHES_EXP_C2), + REPEAT_8TIMES(CEPHES_EXP_P0), + REPEAT_8TIMES(CEPHES_EXP_P1), + REPEAT_8TIMES(CEPHES_EXP_P2), + REPEAT_8TIMES(CEPHES_EXP_P3), + REPEAT_8TIMES(CEPHES_EXP_P4), + REPEAT_8TIMES(CEPHES_EXP_P5), + REPEAT_8TIMES(EXP_MAX_INPUT), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; + +const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)}; +int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0}; + +void VActJitCode::genCode() { + int offset = 0; + for (int i = 0; i < num_ / 
YMM_FLOAT_BLOCK; ++i) { + vmovups(ymm_src, ptr[param1 + offset]); + act(ymm_dst, ymm_src, type_); + vmovups(ptr[param2 + offset], ymm_dst); + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + int rest = num_ % YMM_FLOAT_BLOCK; + while (rest > 0) { + int block = XMM_FLOAT_BLOCK; + if (rest >= 4) { + block = 4; + vmovups(xmm_src, ptr[param1 + offset]); + } else if (rest >= 2) { + block = 2; + vmovq(xmm_src, ptr[param1 + offset]); + } else { + block = 1; + vmovss(xmm_src, ptr[param1 + offset]); + } + act(xmm_dst, xmm_src, type_); + if (rest >= 4) { + vmovups(ptr[param2 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param2 + offset], xmm_dst); + } else { + vmovss(ptr[param2 + offset], xmm_dst); + } + offset += sizeof(float) * block; + rest -= block; + } + ret(); +} + +#define DECLARE_ACT_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + bool UseMe(const int& attr) const override { \ + return platform::MayIUse(platform::avx); \ + } \ + size_t CodeSize(const int& d) const override; \ + std::unique_ptr CreateJitCode(const int& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ + } + +DECLARE_ACT_CREATOR(VRelu); +DECLARE_ACT_CREATOR(VIdentity); +DECLARE_ACT_CREATOR(VExp); +DECLARE_ACT_CREATOR(VSigmoid); +DECLARE_ACT_CREATOR(VTanh); + +// TODO(TJ): tuning use me +size_t VReluCreator::CodeSize(const int& d) const { + return 96 /* init size */ + + (d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ * + 8 /* average bytes for each instruction */; +} + +size_t VIdentityCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; +} + +size_t VExpCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8; +} + +size_t VSigmoidCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8; +} + +size_t VTanhCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8; +} + +#undef DECLARE_ACT_CREATOR + +} 
// namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator); +REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); +REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); +REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); +REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator); diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/jit/gen/act.h similarity index 52% rename from paddle/fluid/operators/math/jit_code.h rename to paddle/fluid/operators/jit/gen/act.h index 6d22bf675724166d0701e9a51d0d23ae00ef1048..81503c42ab5cd46961378847584f68f2cbed0ed5 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/jit/gen/act.h @@ -1,48 +1,28 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #pragma once #include -#include "paddle/fluid/operators/math/jit_gen.h" -#include "paddle/fluid/operators/math/jit_kernel_impl.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" namespace paddle { namespace operators { -namespace math { -namespace jitkernel { +namespace jit { namespace gen { -using reg64_t = const Xbyak::Reg64; -using reg32_t = const Xbyak::Reg32; -using xmm_t = const Xbyak::Xmm; -using ymm_t = const Xbyak::Ymm; -using zmm_t = const Xbyak::Zmm; -using Label = Xbyak::Label; - -typedef enum { - mul = 0, - add, - sub, - relu, - exp, - sigmoid, - tanh, - identity -} operand_type; - extern const float exp_float_consts[]; extern const int exp_int_0x7f[]; extern int g_tmp_mem[]; @@ -79,94 +59,15 @@ extern int g_tmp_mem[]; #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) -// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) -class VXXJitCode : public JitCode { - public: - const char* name() const override { - std::string base = "VXXJitCode"; - if (scalar_index_ == 1) { - base += "_Scalar"; - } else { - base += "_Vec"; - } - if (type_ == operand_type::mul) { - base += "_Mul"; - } else if (type_ == operand_type::add) { - base += "_Add"; - } - if (scalar_index_ == 2) { - base += "_Scalar"; - } else { - base += "_Vec"; - } - base += (with_relu_ ? 
"_Relu" : ""); - return base.c_str(); - } - explicit VXXJitCode(int d, operand_type type, int scalar_index, - bool with_relu, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), - num_(d), - type_(type), - scalar_index_(scalar_index), - with_relu_(with_relu) {} - static bool init(int d, int scalar_index = 0); - void generate() override; - - private: - int num_; - operand_type type_; - int scalar_index_; - bool with_relu_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - - xmm_t xmm_src1 = xmm_t(0); - xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); - xmm_t xmm_zero = xmm_t(3); - - ymm_t ymm_src1 = ymm_t(0); - ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); - ymm_t ymm_zero = ymm_t(3); -}; - -class VActJitCode : public JitCode { +class VActFunc : public JitCode { public: - const char* name() const override { - std::string base = "VActJitCode"; - switch (type_) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - return base.c_str(); - } - - explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d), type_(type) {} - static bool init(int d, operand_type type); - void generate() override; + explicit VActFunc(size_t code_size, void* code_ptr) + : JitCode(code_size, code_ptr) {} + virtual const char* name() const = 0; + virtual void genCode() = 0; protected: - // compute relu with ymm, xmm + // compute RELU with ymm, xmm template void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT JMM zero = JMM(zero_idx); @@ -174,7 +75,7 @@ class VActJitCode : public JitCode { vmaxps(dst, src, zero); } - // compute exp with ymm, xmm 
+ // compute EXP with ymm, xmm template void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { @@ -258,7 +159,7 @@ class VActJitCode : public JitCode { pop(reg_ptr_global); } - // compute sigmoid with ymm, xmm + // compute SIGMOID with ymm, xmm template void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, @@ -283,7 +184,7 @@ class VActJitCode : public JitCode { pop(reg_ptr_global); } - // compute tanh with ymm, xmm + // compute TANH with ymm, xmm template void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, @@ -310,223 +211,109 @@ class VActJitCode : public JitCode { pop(reg_ptr_global); } + // compute IDENTITY with ymm, xmm + template + void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT + JMM zero = JMM(zero_idx); + vxorps(zero, zero, zero); + vaddps(dst, src, zero); + // TODO(TJ): use below + // dst.setIdx(src.getIdx()); + } + template void act(JMM& dst, JMM& src, operand_type type) { // NOLINT // use 11~15 switch (type) { - case operand_type::relu: + case operand_type::RELU: relu_jmm(dst, src, 15); break; - case operand_type::exp: + case operand_type::EXP: exp_jmm(dst, src, 11, 12, 13, 14, 15); break; - case operand_type::sigmoid: + case operand_type::SIGMOID: sigmoid_jmm(dst, src, 11, 12, 13, 14, 15); break; - case operand_type::tanh: + case operand_type::TANH: tanh_jmm(dst, src, 11, 12, 13, 14, 15); break; - case operand_type::identity: + case operand_type::IDENTITY: + identity_jmm(dst, src, 15); break; default: - // throw error + LOG(FATAL) << "Do not support this operand type: " << type; break; } } - - protected: - int num_; - operand_type type_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - - xmm_t xmm_src = xmm_t(0); - ymm_t ymm_src = ymm_t(0); - - xmm_t xmm_dst = xmm_t(1); - ymm_t ymm_dst = ymm_t(1); }; -class LSTMJitCode : 
public VActJitCode { +class VActJitCode : public VActFunc { public: - const char* name() const override { - std::string base = "LSTMJitCode"; - if (use_peephole_) { - base += "_Peephole"; - } - if (compute_c1h1_) { - base += "_C1H1"; + explicit VActJitCode(int d, operand_type type, size_t code_size, + void* code_ptr = nullptr) + : VActFunc(code_size, code_ptr), num_(d), type_(type) { + if (!(type_ == operand_type::RELU || type_ == operand_type::EXP || + type_ == operand_type::SIGMOID || type_ == operand_type::TANH || + type_ == operand_type::IDENTITY)) { + LOG(FATAL) << "Do not support this operand type: " << type_; } - auto AddTypeStr = [&](operand_type type) { - switch (type) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - }; - AddTypeStr(act_gate_); - AddTypeStr(act_cand_); - AddTypeStr(act_cell_); - return base.c_str(); - } - - explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - compute_c1h1_(compute_c1h1) { - auto typeExchange = [](const std::string& type) -> gen::operand_type { - if (type == "sigmoid") { - return operand_type::sigmoid; - } else if (type == "relu") { - return operand_type::relu; - } else if (type == "tanh") { - return operand_type::tanh; - } else if (type == "identity" || type == "") { - return operand_type::identity; - } // else throw error - return operand_type::identity; - }; - num_ = attr.d; - use_peephole_ = attr.use_peephole; - act_gate_ = typeExchange(attr.act_gate); - act_cand_ = typeExchange(attr.act_cand); - act_cell_ = typeExchange(attr.act_cell); + this->genCode(); } - static bool init(int d); - void 
generate() override; - - protected: - int num_; - bool compute_c1h1_; - bool use_peephole_; - operand_type act_gate_; - operand_type act_cand_; - operand_type act_cell_; - reg64_t param1{abi_param1}; -}; -class GRUJitCode : public VActJitCode { - public: const char* name() const override { - std::string base = "GRUJitCode"; - if (id_ == 0) { - base += "_H1"; - } else if (id_ == 1) { - base += "_HtPart1"; - } else if (id_ == 2) { - base += "_HtPart2"; + std::string base = "VActJitCode"; + switch (type_) { + case operand_type::RELU: + base += "_Relu"; + break; + case operand_type::EXP: + base += "_Exp"; + break; + case operand_type::SIGMOID: + base += "_Sigmoid"; + break; + case operand_type::TANH: + base += "_Tanh"; + break; + case operand_type::IDENTITY: + base += "_Identity"; + break; + default: + break; } - auto AddTypeStr = [&](operand_type type) { - switch (type) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - }; - AddTypeStr(act_gate_); - AddTypeStr(act_cand_); return base.c_str(); } - - explicit GRUJitCode(int id, const gru_attr_t& attr, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - id_(id) { - auto typeExchange = [](const std::string& type) -> gen::operand_type { - if (type == "sigmoid") { - return operand_type::sigmoid; - } else if (type == "relu") { - return operand_type::relu; - } else if (type == "tanh") { - return operand_type::tanh; - } else if (type == "identity" || type == "") { - return operand_type::identity; - } // else throw error - return operand_type::identity; - }; - num_ = attr.d; - act_gate_ = typeExchange(attr.act_gate); - act_cand_ = typeExchange(attr.act_cand); - } - static 
bool init(int d); - void generate() override; + void genCode() override; protected: - int id_; int num_; - operand_type act_gate_; - operand_type act_cand_; + operand_type type_; reg64_t param1{abi_param1}; -}; + reg64_t param2{abi_param2}; -#ifdef PADDLE_WITH_MKLDNN -struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { - explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) - : Xbyak::CodeGenerator(code_size) { - // RDI is ptr x_input - // RSI is ptr y_input - // RDX is ptr output - // RCX is height - // r8 is width + xmm_t xmm_src = xmm_t(0); + ymm_t ymm_src = ymm_t(0); - push(rbx); + xmm_t xmm_dst = xmm_t(1); + ymm_t ymm_dst = ymm_t(1); +}; - xor_(rax, rax); - xor_(r10, r10); - vmovups(zmm3, ptr[rsi]); +#define DECLARE_ACT_JITCODE(name, op_type) \ + class name##JitCode : public VActJitCode { \ + public: \ + explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \ + : VActJitCode(d, op_type, code_size, code_ptr) {} \ + }; - L("h_loop"); - xor_(rbx, rbx); - L("w_loop"); - vmovups(zmm2, ptr[rdi + rax]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx + rax], zmm1); - add(rax, 64); - inc(rbx); - cmp(r8, rbx); - jnz("w_loop"); - inc(r10); - cmp(r10, rcx); - jnz("h_loop"); +DECLARE_ACT_JITCODE(VRelu, operand_type::RELU); +DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY); +DECLARE_ACT_JITCODE(VExp, operand_type::EXP); +DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID); +DECLARE_ACT_JITCODE(VTanh, operand_type::TANH); - pop(rbx); - ret(); - } -}; -#endif +#undef DECLARE_ACT_JITCODE } // namespace gen -} // namespace jitkernel -} // namespace math +} // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1198773088faa594bac0714dd8449b240b3ce4d --- /dev/null +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. 
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/blas.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void VXXJitCode::genCode() { + // do not need push stack, and do not need save avx512reg if do not use avx512 + int offset = 0; + if (with_relu_) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } + if (scalar_index_ == 1) { + vbroadcastss(ymm_src1, ptr[param1]); + } else if (scalar_index_ == 2) { + vbroadcastss(ymm_src2, ptr[param2]); + } + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + if (scalar_index_ != 1) { + vmovups(ymm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(ymm_src2, ptr[param2 + offset]); + } + if (type_ == operand_type::MUL) { + vmulps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::ADD) { + vaddps(ymm_dst, ymm_src1, ymm_src2); + } + if (with_relu_) { + vmaxps(ymm_dst, ymm_zero, ymm_dst); + } + vmovups(ptr[param3 + offset], ymm_dst); + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + int rest = num_ % YMM_FLOAT_BLOCK; + while (rest > 0) { + int block = XMM_FLOAT_BLOCK; + if (rest >= 4) { + block = 4; + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } + } else if (rest >= 2) { + block = 2; + if 
(scalar_index_ != 1) { + vmovq(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + } + } else { + block = 1; + if (scalar_index_ != 1) { + vmovss(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovss(xmm_src2, ptr[param2 + offset]); + } + } + switch (type_) { + case operand_type::MUL: + vmulps(xmm_dst, xmm_src1, xmm_src2); + break; + case operand_type::ADD: + vaddps(xmm_dst, xmm_src1, xmm_src2); + break; + default: + break; + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } + if (rest >= 4) { + vmovups(ptr[param3 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param3 + offset], xmm_dst); + } else { + vmovss(ptr[param3 + offset], xmm_dst); + } + offset += sizeof(float) * block; + rest -= block; + } + ret(); +} + +void NCHW16CMulNCJitCode::genCode() { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); +} + +class NCHW16CMulNCCreator : public JitCodeCreator { + public: + bool UseMe(const int& attr) const override { + return platform::MayIUse(platform::avx512f); + } + size_t CodeSize(const int& d) const override { return 256 * 1024; } + std::unique_ptr CreateJitCode(const int& attr) const override { + return make_unique(attr, CodeSize(attr)); + } +}; + +#define DECLARE_BLAS_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + bool UseMe(const int& attr) const override { \ + return platform::MayIUse(platform::avx); \ + } \ + size_t CodeSize(const int& d) const override { \ + return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ + } \ + std::unique_ptr 
CreateJitCode(const int& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ + } + +DECLARE_BLAS_CREATOR(VMul); +DECLARE_BLAS_CREATOR(VAdd); +DECLARE_BLAS_CREATOR(VSub); +DECLARE_BLAS_CREATOR(VAddRelu); +DECLARE_BLAS_CREATOR(VScal); +DECLARE_BLAS_CREATOR(VAddBias); + +#undef DECLARE_BLAS_CREATOR + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); +REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); +// TODO(TJ): enable sub +// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); +REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); +REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); +REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); +REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h new file mode 100644 index 0000000000000000000000000000000000000000..c46ec15fb788c0c7a90cfc8732aad375a9e226a1 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -0,0 +1,117 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) +class VXXJitCode : public JitCode { + public: + explicit VXXJitCode(int d, operand_type type, int scalar_index, + bool with_relu, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), + num_(d), + type_(type), + scalar_index_(scalar_index), + with_relu_(with_relu) { + if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) { + LOG(FATAL) << "Do not support this operand type: " << type_; + } + this->genCode(); + } + + virtual const char* name() const { + std::string base = "VXXJitCode"; + if (scalar_index_ == 1) { + base += "_Scalar"; + } else { + base += "_Vec"; + } + if (type_ == operand_type::MUL) { + base += "_Mul"; + } else if (type_ == operand_type::ADD) { + base += "_Add"; + } + if (scalar_index_ == 2) { + base += "_Scalar"; + } else { + base += "_Vec"; + } + base += (with_relu_ ? 
"_Relu" : ""); + return base.c_str(); + } + void genCode() override; + + private: + int num_; + operand_type type_; + int scalar_index_; + bool with_relu_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + + xmm_t xmm_src1 = xmm_t(0); + xmm_t xmm_src2 = xmm_t(1); + xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_zero = xmm_t(3); + + ymm_t ymm_src1 = ymm_t(0); + ymm_t ymm_src2 = ymm_t(1); + ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_zero = ymm_t(3); +}; + +#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \ + class name##JitCode : public VXXJitCode { \ + public: \ + explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \ + : VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \ + } \ + }; + +DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false); +DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false); +DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false); +DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true); +DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false); +DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false); + +#undef DECLARE_BLAS_JITCODE + +// nChw16c = nChw16c .* NC +class NCHW16CMulNCJitCode : public JitCode { + public: + DECLARE_JIT_CODE(NCHW16CMulNCJitCode); + explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr) { + this->genCode(); + } + void genCode() override; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc new file mode 100644 index 0000000000000000000000000000000000000000..13f7a14111a80632a06c7fc632da47c0802828f7 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/gru.cc @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/gru.h" +#include // offsetof +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void GRUJitCode::genCode() { + reg64_t reg_ptr_gates = rax; + reg64_t reg_ptr_ht_1 = r9; + reg64_t reg_ptr_ht = r10; + mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); + mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); + mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); + ymm_t ymm_one = ymm_t(0); + + if (id_ == 2) { + reg64_t reg_ptr_tmp = r11; + mov(reg_ptr_tmp, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); + } + int offset = 0; + int d = num_ * sizeof(float); + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + ymm_t ymm_u = ymm_t(1); + ymm_t ymm_r = ymm_t(2); + ymm_t ymm_s = ymm_t(3); + ymm_t ymm_ht_1 = ymm_t(4); + // W: {W_update, W_reset; W_state} + if (id_ == 0 || id_ == 2) { + vmovups(ymm_u, ptr[reg_ptr_gates + offset]); + vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); + } + if (id_ == 1) { + vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); + } + if (id_ == 1 || id_ == 2) { + vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); + } + + if (id_ == 0) { + // ht = act_gate(u) * act_cand(s) + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + 
vmovups(ptr[reg_ptr_ht + offset], ymm_s); + } else if (id_ == 1) { + // ht = act_gate(r) * ht_1 + act(ymm_r, ymm_r, act_gate_); + vmulps(ymm_r, ymm_r, ymm_ht_1); + vmovups(ptr[reg_ptr_ht + offset], ymm_r); + } else if (id_ == 2) { + // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 + ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); + act(ymm_u, ymm_u, act_gate_); + act(ymm_s, ymm_s, act_cand_); + vmulps(ymm_s, ymm_s, ymm_u); + vsubps(ymm_u, ymm_one_inner, ymm_u); + vmulps(ymm_u, ymm_ht_1, ymm_u); + vaddps(ymm_u, ymm_s, ymm_u); + vmovups(ptr[reg_ptr_ht + offset], ymm_u); + } + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + ret(); +} + +#define DECLARE_GRU_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + /* TODO(TJ): enable more */ \ + bool UseMe(const gru_attr_t& attr) const override { \ + return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ + } \ + size_t CodeSize(const gru_attr_t& attr) const override { \ + return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \ + } \ + std::unique_ptr CreateJitCode( \ + const gru_attr_t& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ + } + +DECLARE_GRU_CREATOR(GRUH1); +DECLARE_GRU_CREATOR(GRUHtPart1); +DECLARE_GRU_CREATOR(GRUHtPart2); + +#undef DECLARE_GRU_CREATOR + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator); +REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator); +REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator); diff --git a/paddle/fluid/operators/jit/gen/gru.h b/paddle/fluid/operators/jit/gen/gru.h new file mode 100644 index 0000000000000000000000000000000000000000..a4d7222a3459d175fc5eaf5cdf0e7a1a610f8b0c --- /dev/null +++ b/paddle/fluid/operators/jit/gen/gru.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/act.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class GRUJitCode : public VActFunc { + public: + explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size, + void* code_ptr = nullptr) + : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) { + auto typeExchange = [](KernelType type) -> gen::operand_type { + if (type == KernelType::kVSigmoid) { + return operand_type::SIGMOID; + } else if (type == KernelType::kVRelu) { + return operand_type::RELU; + } else if (type == KernelType::kVTanh) { + return operand_type::TANH; + } else if (type == KernelType::kVIdentity) { + return operand_type::IDENTITY; + } else { + LOG(FATAL) << "Do not support this jit::KernelType: " << type; + } + return operand_type::IDENTITY; + }; + act_gate_ = typeExchange(attr.act_gate); + act_cand_ = typeExchange(attr.act_cand); + + this->genCode(); + } + + const char* name() const override { + std::string base = "GRUJitCode"; + if (id_ == 0) { + base += "_H1"; + } else if (id_ == 1) { + base += "_HtPart1"; + } else if (id_ == 2) { + base += "_HtPart2"; + } + auto AddTypeStr = [&](operand_type type) { + switch (type) { + case operand_type::RELU: + base += "_Relu"; + break; + case operand_type::EXP: + base += "_Exp"; + break; + case 
operand_type::SIGMOID: + base += "_Sigmoid"; + break; + case operand_type::TANH: + base += "_Tanh"; + break; + case operand_type::IDENTITY: + base += "_Identity"; + break; + default: + break; + } + }; + AddTypeStr(act_gate_); + AddTypeStr(act_cand_); + return base.c_str(); + } + void genCode() override; + + protected: + int id_; + int num_; + operand_type act_gate_; + operand_type act_cand_; + reg64_t param1{abi_param1}; +}; + +#define DECLARE_GRU_JITCODE(name, id) \ + class name##JitCode : public GRUJitCode { \ + public: \ + explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \ + void* code_ptr = nullptr) \ + : GRUJitCode(id, attr, code_size, code_ptr) {} \ + }; + +DECLARE_GRU_JITCODE(GRUH1, 0); +DECLARE_GRU_JITCODE(GRUHtPart1, 1); +DECLARE_GRU_JITCODE(GRUHtPart2, 2); + +#undef DECLARE_GRU_JITCODE + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h new file mode 100644 index 0000000000000000000000000000000000000000..5b7234c1cb5d15d290685a3dceb3b757be1ef0c6 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include "paddle/fluid/operators/jit/gen_base.h" +#include "paddle/fluid/platform/cpu_info.h" + +#define XBYAK_USE_MMAP_ALLOCATOR +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +// Application Binary Interface +constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), + abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), + abi_param4(Xbyak::Operand::RCX); + +constexpr Xbyak::Operand::Code g_abi_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15}; + +constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]); + +using reg64_t = const Xbyak::Reg64; +using reg32_t = const Xbyak::Reg32; +using xmm_t = const Xbyak::Xmm; +using ymm_t = const Xbyak::Ymm; +using zmm_t = const Xbyak::Zmm; +using Label = Xbyak::Label; + +typedef enum { + MUL = 0, + ADD, + SUB, + RELU, + EXP, + SIGMOID, + TANH, + IDENTITY +} operand_type; + +#define DECLARE_JIT_CODE(codename) \ + const char* name() const override { return #codename; } + +class JitCode : public GenBase, public Xbyak::CodeGenerator { + public: + explicit JitCode(size_t code_size, void* code_ptr = nullptr) + : Xbyak::CodeGenerator( + (code_size % 4096 != 0 ? 
(code_size / 4096 + 1) * 4096 : code_size), + code_ptr) {} + + virtual const char* name() const = 0; + virtual void genCode() = 0; + + size_t getSize() const override { return CodeGenerator::getSize(); } + const unsigned char* getCodeInternal() override { + const Xbyak::uint8* code = CodeGenerator::getCode(); + return code; + } + + protected: + Xbyak::Reg64 param1{abi_param1}; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + virtual void preCode() { + for (int i = 0; i < num_g_abi_regs; ++i) { + push(Xbyak::Reg64(g_abi_regs[i])); + } + if (platform::MayIUse(platform::avx512f)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + virtual void postCode() { + for (int i = 0; i < num_g_abi_regs; ++i) { + pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i])); + } + ret(); + } + void L(const char* label) { Xbyak::CodeGenerator::L(label); } + void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + // Enhanced vector extension + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt, + bool bcast = false) { + int scale = 0; + // Learn from https://github.com/intel/mkl-dnn + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + auto re = Xbyak::RegExp() + base + offt; + if (scale) { + re = re + reg_EVEX_max_8b_offt * scale; + } + if (bcast) { + return zword_b[re]; + } else { + return zword[re]; + } + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..08bafb5a81882072129a4bfa86d5aff2d33a79a1 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/lstm.cc @@ -0,0 +1,142 @@ +/* Copyright (c) 2018 PaddlePaddle 
Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/lstm.h" +#include // offsetof +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void LSTMJitCode::genCode() { + if (use_peephole_) { + preCode(); + } + reg64_t reg_ptr_gates = rax; + reg64_t reg_ptr_ct_1 = r9; + reg64_t reg_ptr_ct = r10; + reg64_t reg_ptr_ht = r11; + reg64_t reg_ptr_wp = r12; + mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); + mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); + mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); + mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); + if (use_peephole_) { + mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]); + } + + int offset = 0; + int d = num_ * sizeof(float); + for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { + /* gates: W_ch, W_ih, W_fh, W_oh */ + ymm_t ymm_c = ymm_t(0); + ymm_t ymm_i = ymm_t(1); + ymm_t ymm_f = ymm_t(2); + ymm_t ymm_o = ymm_t(3); + ymm_t ymm_ct_1 = ymm_t(4); + ymm_t ymm_wp0 = ymm_t(5); + ymm_t ymm_wp1 = ymm_t(6); + ymm_t ymm_wp2 = ymm_t(7); + vmovups(ymm_c, ptr[reg_ptr_gates + offset]); + vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]); + vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]); + vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]); + if (!compute_c1h1_) { + vmovups(ymm_ct_1, 
ptr[reg_ptr_ct_1 + offset]); + } + if (use_peephole_) { + vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]); + vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]); + vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]); + } + /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */ + // act_cand(c) + act(ymm_c, ymm_c, act_cand_); + // act_gate(i) or act_gate(ct_1 * wp0 + i) + if (!compute_c1h1_ && use_peephole_) { + vmulps(ymm_wp0, ymm_ct_1, ymm_wp0); + vaddps(ymm_i, ymm_i, ymm_wp0); + } + act(ymm_i, ymm_i, act_gate_); + vmulps(ymm_c, ymm_c, ymm_i); + if (!compute_c1h1_) { + // act_gate(f) or act_gate(ct_1 * wp1 + f) + if (use_peephole_) { + vmulps(ymm_wp1, ymm_ct_1, ymm_wp1); + vaddps(ymm_f, ymm_f, ymm_wp1); + } + act(ymm_f, ymm_f, act_gate_); + // ct + vmulps(ymm_f, ymm_f, ymm_ct_1); + vaddps(ymm_f, ymm_f, ymm_c); + } + /* H_t = act_cell(C_t) * act_gate(o) */ + // act_cell(C_t) + ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; + ymm_t ymm_tmp = ymm_i; + act(ymm_tmp, ymm_ct, act_cell_); + // act_gate(o) or act_gate(ct * wp2 + o) + if (use_peephole_) { + vmulps(ymm_wp2, ymm_ct, ymm_wp2); + vaddps(ymm_o, ymm_o, ymm_wp2); + } + act(ymm_o, ymm_o, act_gate_); + // ht + vmulps(ymm_o, ymm_o, ymm_tmp); + // save ct and ht + vmovups(ptr[reg_ptr_ct + offset], ymm_ct); + vmovups(ptr[reg_ptr_ht + offset], ymm_o); + offset += sizeof(float) * YMM_FLOAT_BLOCK; + } + + if (use_peephole_) { + postCode(); + } else { + ret(); + } +} + +#define DECLARE_LSTM_CREATOR(name) \ + class name##Creator : public JitCodeCreator { \ + public: \ + /* TODO(TJ): enable more */ \ + bool UseMe(const lstm_attr_t& attr) const override { \ + return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ + } \ + size_t CodeSize(const lstm_attr_t& attr) const override { \ + return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \ + } \ + std::unique_ptr CreateJitCode( \ + const lstm_attr_t& attr) const override { \ + return make_unique(attr, CodeSize(attr)); \ + } \ + } + +DECLARE_LSTM_CREATOR(LSTMCtHt); 
+DECLARE_LSTM_CREATOR(LSTMC1H1); + +#undef DECLARE_LSTM_CREATOR + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator); +REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator); diff --git a/paddle/fluid/operators/jit/gen/lstm.h b/paddle/fluid/operators/jit/gen/lstm.h new file mode 100644 index 0000000000000000000000000000000000000000..d4753bca23de91c74415d41c372cde1610712ef7 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/lstm.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/act.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class LSTMJitCode : public VActFunc { + public: + explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, + size_t code_size, void* code_ptr = nullptr) + : VActFunc(code_size, code_ptr), + num_(attr.d), + compute_c1h1_(compute_c1h1), + use_peephole_(attr.use_peephole) { + auto typeExchange = [](KernelType type) -> gen::operand_type { + if (type == KernelType::kVSigmoid) { + return operand_type::SIGMOID; + } else if (type == KernelType::kVRelu) { + return operand_type::RELU; + } else if (type == KernelType::kVTanh) { + return operand_type::TANH; + } else if (type == KernelType::kVIdentity) { + return operand_type::IDENTITY; + } else { + LOG(FATAL) << "Do not support this jit::KernelType: " << type; + } + return operand_type::IDENTITY; + }; + act_gate_ = typeExchange(attr.act_gate); + act_cand_ = typeExchange(attr.act_cand); + act_cell_ = typeExchange(attr.act_cell); + + this->genCode(); + } + + const char* name() const override { + std::string base = "LSTMJitCode"; + if (use_peephole_) { + base += "_Peephole"; + } + if (compute_c1h1_) { + base += "_C1H1"; + } + auto AddTypeStr = [&](operand_type type) { + switch (type) { + case operand_type::RELU: + base += "_Relu"; + break; + case operand_type::EXP: + base += "_Exp"; + break; + case operand_type::SIGMOID: + base += "_Sigmoid"; + break; + case operand_type::TANH: + base += "_Tanh"; + break; + case operand_type::IDENTITY: + base += "_Identity"; + break; + default: + break; + } + }; + AddTypeStr(act_gate_); + AddTypeStr(act_cand_); + AddTypeStr(act_cell_); + return base.c_str(); + } + void genCode() override; + + protected: + int num_; + bool compute_c1h1_; + bool use_peephole_; + operand_type act_gate_; + operand_type act_cand_; + operand_type act_cell_; + reg64_t 
param1{abi_param1}; +}; + +#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \ + class name##JitCode : public LSTMJitCode { \ + public: \ + explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \ + void* code_ptr = nullptr) \ + : LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \ + }; + +DECLARE_LSTM_JITCODE(LSTMCtHt, false); +DECLARE_LSTM_JITCODE(LSTMC1H1, true); + +#undef DECLARE_LSTM_JITCODE + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..310da0c76f1ab251d788e54f2305f375f3fb4838 --- /dev/null +++ b/paddle/fluid/operators/jit/gen_base.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen_base.h" +#include +#include +#include + +DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); + +namespace paddle { +namespace operators { +namespace jit { + +// refer do not need useme, it would be the last one. +void GenBase::dumpCode(const unsigned char* code) const { + if (code) { + static int counter = 0; + std::ostringstream filename; + filename << "paddle_jitcode_" << name() << "." 
<< counter << ".bin"; + counter++; + std::ofstream fout(filename.str(), std::ios::out); + if (fout.is_open()) { + fout.write(reinterpret_cast(code), this->getSize()); + fout.close(); + } + } +} + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h new file mode 100644 index 0000000000000000000000000000000000000000..48855abd267687b0f3c092279c1f29cc9fb1da40 --- /dev/null +++ b/paddle/fluid/operators/jit/gen_base.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include // for unique_ptr +#include "paddle/fluid/operators/jit/kernel_base.h" + +DECLARE_bool(dump_jitcode); + +namespace paddle { +namespace operators { +namespace jit { + +class GenBase : public Kernel { + public: + virtual ~GenBase() = default; + virtual const char* name() const = 0; + virtual size_t getSize() const = 0; + virtual const unsigned char* getCodeInternal() = 0; + template + Func getCode() { + const unsigned char* code = this->getCodeInternal(); + if (FLAGS_dump_jitcode) { + this->dumpCode(code); + } + return reinterpret_cast(const_cast(code)); + } + + protected: + void dumpCode(const unsigned char* code) const; +}; + +// Creator is used to creat the jitcode and save in pool. +// Every JitCode should have one creator. 
+// Empty polymorphic base class; JitCodeCreator below derives from it.
+class GenCreator {
+ public:
+  virtual ~GenCreator() = default;
+};
+
+// Interface implemented once per jitcode: decides whether the code applies
+// to a given attribute, estimates its size and creates it.
+// NOTE(review): the template parameter list after `template` and the
+// element type of std::unique_ptr below appear to have been stripped by text
+// extraction (presumably `<class Attr>` and `<GenBase>`) - confirm upstream.
+template
+class JitCodeCreator : public GenCreator {
+ public:
+  virtual ~JitCodeCreator() = default;
+
+  // condition when this jit code can be used.
+  virtual bool UseMe(const Attr& attr) const = 0;
+
+  // estimate this code size
+  virtual size_t CodeSize(const Attr& attr) const = 0;
+
+  // create this code
+  virtual std::unique_ptr CreateJitCode(const Attr& attr) const = 0;
+};
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d00584baa081c21762774aef4cbbc714d49cd012
--- /dev/null
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/ + +#include "paddle/fluid/operators/jit/helper.h" +#include // tolower +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace jit { + +#define ONE_CASE(key) \ + case key: \ + return #key + +const char* to_string(KernelType kt) { + switch (kt) { + ONE_CASE(kVMul); + ONE_CASE(kVAdd); + ONE_CASE(kVAddRelu); + ONE_CASE(kVSub); + ONE_CASE(kVScal); + ONE_CASE(kVAddBias); + ONE_CASE(kVRelu); + ONE_CASE(kVIdentity); + ONE_CASE(kVExp); + ONE_CASE(kVSigmoid); + ONE_CASE(kVTanh); + ONE_CASE(kLSTMCtHt); + ONE_CASE(kLSTMC1H1); + ONE_CASE(kGRUH1); + ONE_CASE(kGRUHtPart1); + ONE_CASE(kGRUHtPart2); + ONE_CASE(kCRFDecoding); + ONE_CASE(kLayerNorm); + ONE_CASE(kNCHW16CMulNC); + default: + PADDLE_THROW("Not support type: %d, or forget to add it.", kt); + return "NOT JITKernel"; + } + return nullptr; +} +#undef ONE_CASE + +KernelType to_kerneltype(const std::string& act) { + std::string lower = act; + std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + if (lower == "relu" || lower == "vrelu") { + return kVRelu; + } else if (lower == "identity" || lower == "videntity" || lower == "") { + return kVIdentity; + } else if (lower == "exp" || lower == "vexp") { + return kVExp; + } else if (lower == "sigmoid" || lower == "vsigmoid") { + return kVSigmoid; + } else if (lower == "tanh" || lower == "vtanh") { + return kVTanh; + } + PADDLE_THROW("Not support type: %s, or forget to add this case", act); + return kNone; +} + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..412df86aa1cd94871989aef25adef803f673812b --- /dev/null +++ b/paddle/fluid/operators/jit/helper.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/operators/jit/gen_base.h" +#include "paddle/fluid/operators/jit/kernel_base.h" +#include "paddle/fluid/operators/jit/kernel_key.h" +#include "paddle/fluid/operators/jit/kernel_pool.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { +namespace jit { + +template +inline typename std::enable_if< + std::is_same::value && + std::is_same::value, + typename KernelTuples::func_type>::type +GetJitCode(const typename KernelTuples::attr_type& attr) { + using Func = typename KernelTuples::func_type; + using Attr = typename KernelTuples::attr_type; + size_t key = JitCodeKey(attr); + auto& codes = JitCodePool().Instance(); + if (codes.Has(key)) { + return codes.AllKernels().at(key)->template getCode(); + } + + // creator is not related with attr, so can use KernelKey as key + KernelKey kkey(KT, PlaceType()); + // pool: (KernelKey(type, place), vector) + auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); + auto iter = creator_map.find(kkey); + if (iter != creator_map.end()) { + auto& creators = iter->second; + for (auto& cur : creators) { + auto i = dynamic_cast*>(cur.get()); + if (i && i->UseMe(attr)) { + auto p = i->CreateJitCode(attr); + if (p) { + auto f = p->template getCode(); + codes.Insert(key, std::move(p)); + return f; + } + } + } + } + return nullptr; +} + +template +inline typename 
std::enable_if< + !std::is_same::value || + !std::is_same::value, + typename KernelTuples::func_type>::type +GetJitCode(const typename KernelTuples::attr_type& attr) { + return nullptr; +} + +// Refer code do not related with attr, which is just for cast +// Refer is always on CPUPlace +template +inline typename KernelTuples::func_type GetRefer() { + auto& ref_pool = ReferKernelPool().Instance().AllKernels(); + KernelKey kkey(KT, platform::CPUPlace()); + auto ref_iter = ref_pool.find(kkey); + PADDLE_ENFORCE(ref_iter != ref_pool.end(), + "Every Kernel should have reference function."); + auto& ref_impls = ref_iter->second; + for (auto& impl : ref_impls) { + auto i = dynamic_cast*>(impl.get()); + if (i) { + return i->GetFunc(); + } + } + return nullptr; +} + +template +typename KernelTuples::func_type Get( + const typename KernelTuples::attr_type& attr) { + auto jitfunc = GetJitCode(attr); + if (jitfunc) { + return jitfunc; + } + + // pool: (KernelKey(type, place), vector) + KernelKey kkey(KT, PlaceType()); + auto& pool = KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast*>(impl.get()); + if (i && i->UseMe(attr)) { + return i->GetFunc(); + } + } + } + + // The last implementation should be reference function on CPUPlace. + return GetRefer(); +} + +const char* to_string(KernelType kt); + +KernelType to_kerneltype(const std::string& act); + +inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) { + os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) + << "],act_cand[" << to_string(attr.act_cand) << "],act_cell[" + << to_string(attr.act_cell) << "],use_peephole[" + << (attr.use_peephole ? 
"True" : "False") << "]"; + return os; +} +inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { + os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) + << "],act_cand[" << to_string(attr.act_cand) << "]"; + return os; +} + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b4a2d5d47301a2fd82bf27ddfaaa31ef23e431c2 --- /dev/null +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -0,0 +1,172 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/
+
+#pragma once
+#include "paddle/fluid/operators/jit/macro.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+// NOTE(review): throughout this file the text inside angle brackets
+// (template parameter lists after `template`, template arguments of base
+// classes) appears to have been stripped by text extraction - presumably
+// `template <typename T>` / `<class KernelTuples>`; confirm upstream.
+
+// Identifier of every jit kernel. Values are sequential starting at
+// kNone = 0, so the activation kinds kVRelu..kVTanh are 7..11.
+typedef enum {
+  kNone = 0,
+  kVMul = 1,
+  kVAdd = 2,
+  kVAddRelu,
+  kVSub,
+  kVScal,
+  kVAddBias,
+  kVRelu,
+  kVIdentity,
+  kVExp,
+  kVSigmoid,
+  kVTanh,
+  kLSTMCtHt,
+  kLSTMC1H1,
+  kGRUH1,
+  kGRUHtPart1,
+  kGRUHtPart2,
+  kCRFDecoding,
+  kLayerNorm,
+  kNCHW16CMulNC,
+} KernelType;
+
+// Each *Tuples struct bundles the element type, the attribute type used to
+// select an implementation, and the function-pointer signature of a kernel.
+
+// Two const inputs, one output, and an int.
+template
+struct XYZNTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, const T*, T*, int);
+};
+
+// Same signature as XYZNTuples.
+template
+struct AXYNTuples : public XYZNTuples {};
+
+// One const input, one output, and an int.
+template
+struct XYNTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, T*, int);
+};
+
+// Untyped argument block passed to the LSTM kernels.
+typedef struct {
+  void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
+  const void* ct_1;
+  void* ct;
+  void* ht;
+  /* weight_peephole and checked data are only used in peephole*/
+  const void* wp{nullptr};  // W_ic, W_fc, W_oc
+  void* checked{nullptr};   // size: 2 * d
+} lstm_t;
+
+// Untyped argument block passed to the GRU kernels.
+typedef struct {
+  void* gates;  // gates: {x_update, x_reset; x_state}
+  const void* ht_1;
+  void* ht;
+} gru_t;
+
+// Attributes common to the recurrent kernels: size `d` and two activation
+// kinds. The defaulted constructor leaves the members uninitialised.
+struct rnn_attr_s {
+  int d;
+  KernelType act_gate, act_cand;
+  rnn_attr_s() = default;
+  explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
+      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
+};
+
+// LSTM attributes: rnn_attr_s plus the cell activation and peephole flag.
+struct lstm_attr_s : public rnn_attr_s {
+  bool use_peephole;
+  KernelType act_cell;
+  lstm_attr_s() = default;
+  explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
+                       KernelType _act_cell, bool _use_peephole = false)
+      : rnn_attr_s(_d, _act_gate, _act_cand),
+        use_peephole(_use_peephole),
+        act_cell(_act_cell) {}
+};
+
+typedef struct rnn_attr_s gru_attr_t;
+typedef struct lstm_attr_s lstm_attr_t;
+
+// LSTM kernels take the argument block and the attributes by pointer.
+template
+struct LSTMTuples {
+  typedef T data_type;
+  typedef lstm_attr_t attr_type;
+  typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
+};
+
+// GRU kernels take the argument block and the attributes by pointer.
+template
+struct GRUTuples {
+  typedef T data_type;
+  typedef gru_attr_t attr_type;
+  typedef void (*func_type)(gru_t*, const gru_attr_t*);
+};
+
+// CRF decoding kernel signature; the attribute is a single int.
+template
+struct CRFDecodingTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
+};
+
+// Layer-norm kernel signature; the attribute is a single int.
+template
+struct LayerNormTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
+                            const float, int);
+};
+
+// nChw16c = nChw16c .* NC
+template
+struct NCHW16CMulNCTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, const T*, T*, int, int);
+};
+
+// Just for adding to kernel pool without template
+class Kernel {
+ public:
+  Kernel() = default;
+  virtual ~Kernel() = default;
+  DISABLE_COPY_AND_ASSIGN(Kernel);
+};
+
+// A concrete, non-generated implementation of one kernel: exposes the
+// function pointer, a predicate saying whether it applies to an attribute,
+// and a label for the implementation.
+template
+class KernelMore : public Kernel {
+ public:
+  using T = typename KernelTuples::data_type;
+  using Func = typename KernelTuples::func_type;
+  using Attr = typename KernelTuples::attr_type;
+  virtual Func GetFunc() const { return func; }
+  virtual bool UseMe(const Attr& attr) const = 0;
+  virtual const char* ImplType() const = 0;
+
+ protected:
+  // Set by the derived class; nullptr until then.
+  Func func{nullptr};
+};
+
+// Reference implementation: applicable to every attribute.
+template
+class ReferKernel : public KernelMore {
+ public:
+  // Refer code can always be used
+  bool UseMe(const typename KernelTuples::attr_type& attr) const override {
+    return true;
+  }
+  const char* ImplType() const override { return "Refer"; }
+};
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4e6a19f04fd425b920aeea49b63001941d800a73
--- /dev/null
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -0,0 +1,47 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/kernel_key.h" + +namespace paddle { +namespace operators { +namespace jit { + +template <> +size_t JitCodeKey(const int& d) { + return d; +} + +constexpr int act_type_shift = 3; // suppot 2^3 act types + +template <> +size_t JitCodeKey(const lstm_attr_t& attr) { + size_t key = attr.d; + int gate_key = static_cast(attr.act_gate) << 1; + int cand_key = static_cast(attr.act_cand) << (1 + act_type_shift); + int cell_key = static_cast(attr.act_cell) << (1 + act_type_shift * 2); + return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + + attr.use_peephole; +} + +template <> +size_t JitCodeKey(const gru_attr_t& attr) { + size_t key = attr.d; + return (key << (act_type_shift * 2)) + static_cast(attr.act_gate) + + (static_cast(attr.act_cand) << act_type_shift); +} + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_key.h b/paddle/fluid/operators/jit/kernel_key.h new file mode 100644 index 0000000000000000000000000000000000000000..611a0210d614196ad0b05d583303688c1d964e04 --- /dev/null +++ b/paddle/fluid/operators/jit/kernel_key.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/operators/jit/kernel_base.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+// Key of the kernel and creator pools: a kernel type plus a place.
+// NOTE(review): the target type of static_cast, the argument of std::hash
+// and the template parameter list of JitCodeKey below appear to have been
+// stripped by text extraction (presumably `<int>` / `<typename Attr>`) -
+// confirm upstream.
+struct KernelKey {
+  // Hash functor: combines place_.which() (low 8 bits) with the kernel
+  // type shifted left by 8.
+  struct Hash {
+    size_t operator()(const KernelKey& key) const {
+      int place = key.place_.which();  // less than 2^8
+      int type = static_cast(key.type_) << 8;  // less than 2^(32-8)
+      std::hash hasher;
+      return hasher(place + type);
+    }
+  };
+
+  KernelType type_;
+  platform::Place place_;
+
+  KernelKey(KernelType type, platform::Place place)
+      : type_(type), place_(place) {}
+  size_t hash_key() const { return Hash()(*this); }
+
+  // Keys are equal when the kernel types match and the places are of the
+  // same class (as decided by platform::places_are_same_class).
+  bool operator==(const KernelKey& o) const {
+    return platform::places_are_same_class(place_, o.place_) &&
+           type_ == o.type_;
+  }
+  bool operator!=(const KernelKey& o) const { return !(*this == o); }
+};
+
+// Every JitCode should have a method to get the key from attribution
+template
+size_t JitCodeKey(const Attr& attr);
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_pool.cc b/paddle/fluid/operators/jit/kernel_pool.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bc98c644fbee2cd54faf4dc9fe151b8be131bd7b
--- /dev/null
+++ b/paddle/fluid/operators/jit/kernel_pool.cc
@@ -0,0 +1,41 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/kernel_pool.h"
+// NOTE(review): the header names of the next three includes appear to have
+// been stripped by text extraction - confirm upstream.
+#include  // for shared_ptr
+#include
+#include
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+// Each Instance() returns a reference to a function-local static, i.e. one
+// pool object per process, constructed on first use.
+
+JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
+  static JitCodeCreatorPool g_creator_pool;
+  return g_creator_pool;
+}
+
+KernelPool& KernelPool::Instance() {
+  static KernelPool g_kernel_pool;
+  return g_kernel_pool;
+}
+
+ReferKernelPool& ReferKernelPool::Instance() {
+  static ReferKernelPool g_refer_kernel_pool;
+  return g_refer_kernel_pool;
+}
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e15242af28839ee0759e1a5b3930d6d6bfaa0ff
--- /dev/null
+++ b/paddle/fluid/operators/jit/kernel_pool.h
@@ -0,0 +1,119 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/
+
+#pragma once
+
+// NOTE(review): in this file the header names of the standard includes and
+// the template parameter/argument lists (after `template`, in the
+// unique_ptr / unordered_map / vector typedefs) appear to have been stripped
+// by text extraction - confirm upstream.
+#include  // for unique_ptr
+#include
+#include
+#include
+#include "paddle/fluid/operators/jit/gen_base.h"
+#include "paddle/fluid/operators/jit/kernel_base.h"
+#include "paddle/fluid/operators/jit/kernel_key.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+// Cache of generated code objects keyed by a size_t.
+// Instance() returns a thread_local object, so every thread has its own
+// cache.
+template
+class JitCodePool {
+  typedef std::unique_ptr GenBasePtr;
+  typedef std::unordered_map JitCodeMap;
+
+ public:
+  JitCodePool() = default;
+  static JitCodePool& Instance() {
+    static thread_local JitCodePool g_jit_codes;
+    return g_jit_codes;
+  }
+
+  const JitCodeMap& AllKernels() { return codes_; }
+
+  bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
+
+  // Takes ownership of `value`. emplace() leaves an existing entry with the
+  // same key untouched.
+  void Insert(size_t key, GenBasePtr value) {
+    codes_.emplace(key, std::move(value));
+  }
+
+ private:
+  JitCodeMap codes_;
+  DISABLE_COPY_AND_ASSIGN(JitCodePool);
+};
+
+// Registry of jitcode creators: each KernelKey maps to a list of creators
+// kept in insertion order.
+class JitCodeCreatorPool {
+  typedef std::unique_ptr GenCreatorPtr;
+  typedef std::unordered_map,
+                             KernelKey::Hash>
+      GenCreatorPtrMap;
+
+ public:
+  JitCodeCreatorPool() = default;
+  static JitCodeCreatorPool& Instance();
+  GenCreatorPtrMap& AllCreators() { return creators_; }
+  // Appends `value` to the list for `key`, creating the list if needed.
+  void Insert(const KernelKey& key, GenCreatorPtr value) {
+    if (creators_.find(key) == creators_.end()) {
+      creators_.emplace(key, std::vector());
+    }
+    creators_.at(key).emplace_back(std::move(value));
+  }
+
+ private:
+  GenCreatorPtrMap creators_;
+  DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
+};
+
+typedef std::unique_ptr KernelPtr;
+typedef std::unordered_map, KernelKey::Hash>
+    KernelMap;
+
+// Registry of non-generated kernel implementations: each KernelKey maps to
+// a list of kernels kept in insertion order.
+class KernelPool {
+ public:
+  static KernelPool& Instance();
+  KernelPool() = default;
+  KernelMap& AllKernels() { return pool_; }
+  // Appends `value` to the list for `key`, creating the list if needed.
+  void Insert(const KernelKey& key, KernelPtr value) {
+    if (pool_.find(key) == pool_.end()) {
+      pool_.emplace(key, std::vector());
+    }
+    pool_.at(key).emplace_back(std::move(value));
+  }
+
+ private:
+  KernelMap pool_;
+  DISABLE_COPY_AND_ASSIGN(KernelPool);
+};
+
+// Every kernel should have refer code and it should be used in unit tests,
+// so refer kernels should have their own independent kernel pool
+class ReferKernelPool {
+ public:
+  static ReferKernelPool& Instance();
+  ReferKernelPool() = default;
+  KernelMap& AllKernels() { return pool_; }
+  // Appends `value` to the list for `key`, creating the list if needed.
+  void Insert(const KernelKey& key, KernelPtr value) {
+    if (pool_.find(key) == pool_.end()) {
+      pool_.emplace(key, std::vector());
+    }
+    pool_.at(key).emplace_back(std::move(value));
+  }
+
+ private:
+  KernelMap pool_;
+  DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
+};
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/macro.h b/paddle/fluid/operators/jit/macro.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2622eba8b70cc553a2da44638d577c9d7751b25
--- /dev/null
+++ b/paddle/fluid/operators/jit/macro.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/
+
+#pragma once
+// NOTE(review): the header name of this include appears to have been
+// stripped by text extraction - confirm upstream.
+#include
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+// Preprocessor macros are not scoped by the enclosing namespace; these
+// names are visible in every file that includes this header.
+
+// Numeric limits for sigmoid / exp kernels.
+// NOTE(review): presumably input clamp bounds - not used in this chunk,
+// verify against the kernels that reference them.
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
+// Block sizes in floats. The intrinsic CRF kernel uses YMM_FLOAT_BLOCK as
+// the step for __m256 and ZMM_FLOAT_BLOCK for __m512.
+#define XMM_FLOAT_BLOCK 4
+#define YMM_FLOAT_BLOCK 8
+#define ZMM_FLOAT_BLOCK 16
+
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/more/CMakeLists.txt b/paddle/fluid/operators/jit/more/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa503356baa73cb76e50ff19901a56d0c987ad99
--- /dev/null
+++ b/paddle/fluid/operators/jit/more/CMakeLists.txt
@@ -0,0 +1,17 @@
+
+# Appends one "USE_JITKERNEL_MORE(<TARGET> <TYPE>);" line to ${jit_file}.
+# Callers write the arguments with a comma (e.g. "kCRFDecoding, intrinsic");
+# CMake separates arguments by whitespace only, so the comma stays attached
+# to TARGET and ends up in the generated line.
+function(USE_JITKERNEL_MORE TARGET TYPE)
+  file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
+endfunction()
+
+if(WITH_MKLML)
+  add_subdirectory(mkl)
+endif()
+
+if(WITH_AVX)
+  add_subdirectory(intrinsic)
+endif()
+
+# mix should be last
+add_subdirectory(mix)
+
+# Propagate the dependency list collected by the subdirectories upwards.
+set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt b/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..468937a4f6b27ae525bfd0d8e99cc891eedbc353
--- /dev/null
+++ b/paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
@@ -0,0 +1,9 @@
+
+# Build every .cc file of this directory into jit_kernel_intrinsic.
+file(GLOB jit_kernel_cc_intrinsic RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
+cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_base)
+
+set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
+
+# use intrinsic kernels by name and type
+USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
+USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
new file mode 100644
index 0000000000000000000000000000000000000000..16c91f8246dda34b1436fd4edd507e9ff603de6b
--- /dev/null
+++
b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc @@ -0,0 +1,181 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h" +#include +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace intrinsic { +// Note: intrinsic code is not runtime build. +// For example, if you build code on AVX, and run on AVX512 it can only use AVX + +void CRFDecoding(const int seq_len, const float* x, const float* w, + float* alpha, int* track, int tag_num) { +#ifdef __AVX512F__ + const int step_size = ZMM_FLOAT_BLOCK; +#else + const int step_size = YMM_FLOAT_BLOCK; +#endif + const int end = tag_num / step_size; + const int rest = tag_num % step_size; + /* Setup the alpha initial value.*/ + int i_offset = 0; + int last_offset = rest - step_size; + for (int i = 0; i <= end; ++i) { +#ifdef __AVX512F__ + // Declare the variable for the content of weights, input and alpha values. + __m512 w_content, x_content, alpha_content; + // Load the relevant data into the variables from un-aligned address. + w_content = _mm512_loadu_ps(w + i_offset); + x_content = _mm512_loadu_ps(x + i_offset); + alpha_content = _mm512_add_ps(w_content, x_content); + // Save the alpha value. 
+ /* Seed the first-step scores in alpha itself: the AVX branch writes alpha and the recurrence below reads alpha. */ + _mm512_storeu_ps(alpha + i_offset, alpha_content); +#else + // AVX or AVX2 + // weights, input and alpha values. + __m256 w_content, x_content, alpha_content; + // Load the relevant data into the variables from un-aligned address. + w_content = _mm256_loadu_ps(w + i_offset); + x_content = _mm256_loadu_ps(x + i_offset); + alpha_content = _mm256_add_ps(w_content, x_content); + _mm256_storeu_ps(alpha + i_offset, alpha_content); +#endif + i_offset += step_size; + if (i == end - 1) { + if (rest > 0) { + i_offset += last_offset; + } else { + break; + } + } + } + // Use the column-major strategy to get the location of maximum score. + int seq_offset = 0; + constexpr int state_trans_base_idx = 2; + for (int k = 1; k < seq_len; ++k) { + int j_offset = 0; + for (int j = 0; j <= end; ++j) { +/* Initialize the variables of maximum score and location.*/ +#ifdef __AVX512F__ + __m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); + __m512i max_j = _mm512_setzero_si512(); +#else + __m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); + __m256i max_j = _mm256_set1_epi32(0); +#endif + /* Calculate the offset of transition_weights.*/ + int trans_offset = state_trans_base_idx * tag_num + j_offset; + for (int i = 0; i < tag_num; ++i) { +/* Initialize the content of alpha variable with related offset.*/ +#ifdef __AVX512F__ + __m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); + /* Obtain the content of weights from un-aligned address.*/ + __m512 w_content = _mm512_loadu_ps(w + trans_offset); + __m512 score_v = _mm512_add_ps(alpha_content, w_content); + __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); + /* AVX512 instructions.*/ + max_j = _mm512_mask_set1_epi32(max_j, mask, i); + /* Update the max_score value.*/ + max_score = _mm512_max_ps(max_score, score_v); + +#else + __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); + /* Obtain the content of weights from un-aligned address.*/ + __m256 w_content = 
_mm256_loadu_ps(w + trans_offset); + __m256 score_v = _mm256_add_ps(alpha_content, w_content); + __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); +/* According to the mask value, update the index of the max_score.*/ +#ifdef __AVX2__ + max_j = _mm256_or_si256( + _mm256_andnot_si256((__m256i)mask, max_j), + _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); +#else + __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); + __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); + __m128i lo_mask = + _mm256_extractf128_si256(*(__m256i*)&mask, 0); // NOLINT + __m128i hi_mask = + _mm256_extractf128_si256(*(__m256i*)&mask, 1); // NOLINT + lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); + hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); + lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); + hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); + lo_max_j = _mm_or_si128(lo_mask, lo_max_j); + hi_max_j = _mm_or_si128(hi_mask, hi_max_j); + max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); + max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); +#endif + /* Update the max_score value.*/ + max_score = _mm256_max_ps(max_score, score_v); + +#endif + + trans_offset += tag_num; + } +/* Update the alpha and track values. 
*/ +#ifdef __AVX512F__ + /* The per-step stride is the tag_num argument, exactly as in the AVX branch; this is a free function, so there is no this->num_. */ + __m512 x_content = + _mm512_loadu_ps(x + seq_offset + tag_num + j_offset); + max_score = _mm512_add_ps(max_score, x_content); + _mm512_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + + tag_num + j_offset), + max_j); +#else + __m256 x_content = _mm256_loadu_ps(x + seq_offset + tag_num + j_offset); + max_score = _mm256_add_ps(max_score, x_content); + _mm256_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(track + seq_offset + tag_num + j_offset), + max_j); +#endif + + /* Calculate the offset of next step*/ + j_offset += step_size; + if (j == end - 1) { + if (rest > 0) { + j_offset += last_offset; + } else { + break; + } + } + } + seq_offset += tag_num; + } +} + +bool CRFDecodingKernel::UseMe(const int& d) const { +#ifdef __AVX512F__ + constexpr int block = ZMM_FLOAT_BLOCK; +#else + constexpr int block = YMM_FLOAT_BLOCK; +#endif + return platform::MayIUse(platform::avx) && d >= block; +} + +} // namespace intrinsic +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace intrinsic = paddle::operators::jit::more::intrinsic; + +REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel); diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h new file mode 100644 index 0000000000000000000000000000000000000000..24179d90ddcc6e7f44ffa4b2ca0886fbca5c81bf --- /dev/null +++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/operators/jit/kernel_base.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace intrinsic { + /* Free function implemented in crf_decoding.cc with AVX/AVX512 intrinsics. It reads x and w, and writes alpha and track; tag_num is the stride between consecutive sequence steps in x, alpha and track. NOTE(review): presumably the per-step best-score/backpointer pass of CRF decoding - confirm against the caller. */ +void CRFDecoding(const int seq_len, const float* x, const float* w, + float* alpha, int* track, int tag_num); + /* Kernel wrapper: the constructor points func at CRFDecoding, UseMe (defined in the .cc) decides whether this implementation is selected, and ImplType reports "Intrinsic". */ +class CRFDecodingKernel : public KernelMore> { + public: + CRFDecodingKernel() { this->func = CRFDecoding; } + bool UseMe( + const typename CRFDecodingTuples::attr_type&) const override; + const char* ImplType() const override { return "Intrinsic"; } +}; + +} // namespace intrinsic +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9b6e401c6825b21191881d4e57fe09b48d2f4ee --- /dev/null +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc @@ -0,0 +1,168 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h" +#include +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace intrinsic { + /* Layer normalization over height rows of right floats each, using 256-bit AVX vectors and unaligned loads/stores. For row i: mean[i] = sum(x) / right, var[i] = sum((x - mean)^2) / right, out = (x - mean) / sqrt(var + epsilon); afterwards out is multiplied by scale[col] when scale is non-null and bias[col] is added when bias is non-null. A row tail (right % block) is processed by re-reading the last block floats of the row, so right must be at least block; UseMe below requires d >= YMM_FLOAT_BLOCK (d is presumably right - TODO confirm). */ +void LayerNorm(float* x, float* out, float* mean, float* var, + const float* scale, const float* bias, int height, + const float epsilon, int right) { + __m256 sum; + __m256 mean_vec, var_vec; + __m128 hi, lo; + __m256 tmp; + size_t offset; + size_t j; + int block = YMM_FLOAT_BLOCK; + const int rest = right % block; + const int end = right - rest; + + __m256 reverse_num_vec = + _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right)); + __m256 epsilon_vec = _mm256_set1_ps(epsilon); + /* Lane mask for the tail window: only the highest rest lanes are enabled, i.e. the row elements that the full-block loops did not cover; it is all zero when rest == 0. */ int rest_mask = + ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff; + __m256i mask_vec = _mm256_set_epi32( + rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, + rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, + rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, + rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 
0xffffffff : 0); + + for (int i = 0; i < height; ++i) { + offset = i * right; + + /* get mean: accumulate the full blocks, add the masked tail window, then reduce across lanes (swap-and-add the 128-bit halves, two horizontal adds) so every lane holds the row total before scaling by 1 / right */ + sum = _mm256_setzero_ps(); + for (j = offset; j < end + offset; j += block) { + sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); + } + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_loadu_ps((const float*)x + j); + tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, + *(__m256*)&mask_vec); // NOLINT + sum = _mm256_add_ps(sum, tmp); + } + hi = _mm256_extractf128_ps(sum, 1); + lo = _mm256_extractf128_ps(sum, 0); + sum = _mm256_add_ps( + sum, _mm256_insertf128_ps( + _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); + sum = _mm256_hadd_ps(sum, sum); + sum = _mm256_hadd_ps(sum, sum); + mean_vec = _mm256_mul_ps(sum, reverse_num_vec); + mean[i] = *reinterpret_cast(&mean_vec); + + /* get variance: same accumulation and lane reduction as the mean, applied to the squared deviations from mean_vec; the tail window is masked after squaring so overlapped elements are not counted twice */ + sum = _mm256_setzero_ps(); + for (j = offset; j < end + offset; j += block) { + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_mul_ps(tmp, tmp); + sum = _mm256_add_ps(sum, tmp); + } + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_mul_ps(tmp, tmp); + tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, + *(__m256*)&mask_vec); // NOLINT + sum = _mm256_add_ps(sum, tmp); + } + hi = _mm256_extractf128_ps(sum, 1); + lo = _mm256_extractf128_ps(sum, 0); + sum = _mm256_add_ps( + sum, _mm256_insertf128_ps( + _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); + sum = _mm256_hadd_ps(sum, sum); + sum = _mm256_hadd_ps(sum, sum); + var_vec = _mm256_mul_ps(sum, reverse_num_vec); + var[i] = *reinterpret_cast(&var_vec); + + /* get x_norm and calculate output: out = (x - mean) / sqrt(var + epsilon); the tail window is recomputed from x without a mask, which rewrites the overlapped elements with the same values */ + for (j = offset; j < end + offset; j += block) { + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_div_ps(tmp, + _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); + _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); + } + if (rest != 0) { 
+ j = offset + right - block; + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_div_ps(tmp, + _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); + _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); + } + + if (scale) { /* The tail window overlaps the last full block, so it is loaded into tmp before the loop scales that block in place; scaling the saved copy avoids applying scale twice to the overlapped elements. */ + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_loadu_ps((const float*)out + j); + } + for (j = offset; j < end + offset; j += block) { + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j), + _mm256_loadu_ps((const float*)scale + j - offset))); + } + if (rest != 0) { + j = offset + right - block; + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_mul_ps(tmp, + _mm256_loadu_ps((const float*)scale + j - offset))); + } + } + + if (bias) { /* Same save-the-tail-first ordering as the scale step, so bias is added exactly once to the overlapped elements. */ + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_loadu_ps((const float*)out + j); + } + for (j = offset; j < end + offset; j += block) { + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_add_ps(_mm256_loadu_ps((const float*)out + j), + _mm256_loadu_ps((const float*)bias + j - offset))); + } + if (rest != 0) { + j = offset + right - block; + _mm256_storeu_ps(reinterpret_cast(out) + j, + _mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias + + j - offset))); + } + } + } +} + +bool LayerNormKernel::UseMe(const int& d) const { + return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK; +} + +} // namespace intrinsic +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace intrinsic = paddle::operators::jit::more::intrinsic; + +REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel); diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..89da2940f4420c418f9bd5260c4b74606cc9168f --- /dev/null +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h @@ -0,0 +1,41 @@ +/* Copyright 
(c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/operators/jit/kernel_base.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace intrinsic { + +void LayerNorm(float* x, float* out, float* mean, float* var, + const float* scale, const float* bias, int height, + const float epsilon, int right); + +class LayerNormKernel : public KernelMore> { + public: + LayerNormKernel() { this->func = LayerNorm; } + bool UseMe(const typename LayerNormTuples::attr_type&) const override; + const char* ImplType() const override { return "Intrinsic"; } +}; + +} // namespace intrinsic +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/more/mix/CMakeLists.txt b/paddle/fluid/operators/jit/more/mix/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e05f204b1eebd03c7a00157d96d0482f4a44a7fb --- /dev/null +++ b/paddle/fluid/operators/jit/more/mix/CMakeLists.txt @@ -0,0 +1,14 @@ + + +file(GLOB jit_kernel_mix_cc RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") +cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base) + +set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE) + +USE_JITKERNEL_MORE(kVSigmoid, mix) +USE_JITKERNEL_MORE(kVTanh, mix) +USE_JITKERNEL_MORE(kLSTMCtHt, mix) 
+USE_JITKERNEL_MORE(kLSTMC1H1, mix) +USE_JITKERNEL_MORE(kGRUH1, mix) +USE_JITKERNEL_MORE(kGRUHtPart1, mix) +USE_JITKERNEL_MORE(kGRUHtPart2, mix) diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc new file mode 100644 index 0000000000000000000000000000000000000000..df0a85256b1f546d5f64be73925cf58b87a25bd7 --- /dev/null +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -0,0 +1,216 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/more/mix/mix.h" +#include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace mix { + +void VSigmoid(const T* x, T* y, int n) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); + y[i] = static_cast(0) - y[i]; + } + auto compute = Get, platform::CPUPlace>(n); + compute(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } +} + +void VTanh(const T* x, T* y, int n) { + const T a = 2, b = -1; + auto compute_scal = Get, platform::CPUPlace>(n); + auto compute_addbias = Get, platform::CPUPlace>(n); + auto compute_sigmoid = Get, platform::CPUPlace>(n); + compute_scal(&a, x, y, n); + compute_sigmoid(y, y, n); + compute_scal(&a, y, y, n); + compute_addbias(&b, y, y, n); +} + +void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT + if (type == kVSigmoid) { + return Get, platform::CPUPlace>(d); + } else if (type == kVRelu) { + return Get, platform::CPUPlace>(d); + } else if (type == kVTanh) { + return Get, platform::CPUPlace>(d); + } else if (type == kVIdentity) { + return Get, platform::CPUPlace>(d); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} + +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + const T* ct_1 = reinterpret_cast(step->ct_1); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + const T* wp = reinterpret_cast(step->wp); + T* checked = reinterpret_cast(step->checked); + const int d = attr->d; + const int d2 = d * 2; + const int d3 = d * 3; + auto vmul_d = Get, platform::CPUPlace>(d); + auto vadd_d = Get, platform::CPUPlace>(d); + auto vadd_d2 = Get, platform::CPUPlace>(d2); + auto act_gate_d = getActFunc(attr->act_gate, d); + auto act_gate_d2 = getActFunc(attr->act_gate, d2); + auto act_gate_d3 = getActFunc(attr->act_gate, d3); + auto act_cand_d = getActFunc(attr->act_cand, d); + auto act_cell_d = getActFunc(attr->act_cell, d); + + if (attr->use_peephole) { + vmul_d(wp, ct_1, checked, d); + vmul_d(wp + d, ct_1, checked + d, d); + vadd_d2(checked, gates + d, gates + d, d2); + act_gate_d2(gates + d, gates + d, d2); + } else { + act_gate_d3(gates + d, gates + d, 
d3); + } + + // C_t = C_t-1 * fgated + cand_gated * igated + act_cand_d(gates, gates, d); + vmul_d(gates, gates + d, gates + d, d); + vmul_d(ct_1, gates + d2, gates + d2, d); + vadd_d(gates + d, gates + d2, ct, d); + + if (attr->use_peephole) { + // get ogated + vmul_d(wp + d2, ct, gates + d, d); + vadd_d(gates + d, gates + d3, gates + d3, d); + act_gate_d(gates + d3, gates + d3, d); + } + // H_t = act_cell(C_t) * ogated + act_cell_d(ct, gates + d2, d); + vmul_d(gates + d2, gates + d3, ht, d); +} + +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ct = reinterpret_cast(step->ct); + T* ht = reinterpret_cast(step->ht); + int d = attr->d; + int d2 = d * 2; + int d3 = d * 3; + auto vmul_d = Get, platform::CPUPlace>(d); + auto vadd_d = Get, platform::CPUPlace>(d); + auto act_gate_d = getActFunc(attr->act_gate, d); + auto act_cand_d = getActFunc(attr->act_cand, d); + auto act_cell_d = getActFunc(attr->act_cell, d); + /* C_t = igated * cgated*/ + act_gate_d(gates + d, gates + d, d); + act_cand_d(gates, gates, d); + vmul_d(gates, gates + d, ct, d); + if (attr->use_peephole) { + // get outgated, put W_oc * C_t on igated + const T* wp = reinterpret_cast(step->wp); + vmul_d(wp + d2, ct, gates + d, d); + vadd_d(gates + d, gates + d3, gates + d3, d); + } + /* H_t = act_cell(C_t) * ogated */ + act_gate_d(gates + d3, gates + d3, d); + act_cell_d(ct, gates + d2, d); + vmul_d(gates + d2, gates + d3, ht, d); +} + +// compute h1 without h0 +void GRUH1(gru_t* step, const gru_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + int d = attr->d; + int d2 = d * 2; + auto act_gate = getActFunc(attr->act_gate, d); + auto act_cand = getActFunc(attr->act_cand, d); + auto vmul_d = Get, platform::CPUPlace>(d); + act_gate(gates, gates, d); + act_cand(gates + d2, gates + d2, d); + vmul_d(gates, gates + d2, ht, d); +} + +// compute the first part of GRU: ht = act_gate(r) * ht_1 +void 
GRUHtPart1(gru_t* step, const gru_attr_t* attr) { + // W: {W_update, W_reset; W_state} + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + const T* ht_1 = reinterpret_cast(step->ht_1); + auto act_gate = getActFunc(attr->act_gate, attr->d); + auto vmul_d = Get, platform::CPUPlace>(attr->d); + act_gate(gates + attr->d, gates + attr->d, attr->d); + vmul_d(ht_1, gates + attr->d, ht, attr->d); +} + +// compute the second part of GRU: +// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 +void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { + T* gates = reinterpret_cast(step->gates); + T* ht = reinterpret_cast(step->ht); + const T* ht_1 = reinterpret_cast(step->ht_1); + int d = attr->d; + auto act_gate = getActFunc(attr->act_gate, d); + auto act_cand = getActFunc(attr->act_cand, d); + T* y = gates + d * 2; + act_gate(gates, gates, d); + act_cand(y, y, d); + // out = zt*ht~ + (1-zt)*ht_1 + for (int i = 0; i < d; ++i) { + ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; + } +} + +// TODO(TJ): tuning me +bool VSigmoidKernel::UseMe(const int& d) const { return true; } + +bool VTanhKernel::UseMe(const int& d) const { return true; } + +bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; } + +bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; } + +bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; } + +bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; } + +bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } + +} // namespace mix +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace mix = paddle::operators::jit::more::mix; + +#define REGISTER_MORE_KERNEL(key, func) \ + REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) + +REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid); +REGISTER_MORE_KERNEL(kVTanh, VTanh); +REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt); 
+REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1); +REGISTER_MORE_KERNEL(kGRUH1, GRUH1); +REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1); +REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2); + +#undef REGISTER_MORE_KERNEL diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h new file mode 100644 index 0000000000000000000000000000000000000000..a70ecdf9348f511311307b4c27bb4506222a7439 --- /dev/null +++ b/paddle/fluid/operators/jit/more/mix/mix.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include "paddle/fluid/operators/jit/kernel_base.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace mix { +using T = float; + +void VSigmoid(const T* x, T* y, int n); +void VTanh(const T* x, T* y, int n); + +void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr); +void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr); +void GRUH1(gru_t* step, const gru_attr_t* attr); +void GRUHtPart1(gru_t* step, const gru_attr_t* attr); +void GRUHtPart2(gru_t* step, const gru_attr_t* attr); + +#define DECLARE_MORE_KERNEL(name, tuples) \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename tuples::attr_type&) const override; \ + const char* ImplType() const override { return "Mixed"; } \ + } + +// XYN +DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); +DECLARE_MORE_KERNEL(VTanh, XYNTuples); + +DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); +DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); + +DECLARE_MORE_KERNEL(GRUH1, GRUTuples); +DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); +DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); + +#undef DECLARE_MORE_KERNEL + +} // namespace mix +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..863cc720d68ce3dcfe045aa11c559a06a50909f3 --- /dev/null +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -0,0 +1,11 @@ + +cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) +set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) + +# use mkl kernels by name and type +USE_JITKERNEL_MORE(kVMul, mkl) +USE_JITKERNEL_MORE(kVAdd, mkl) +USE_JITKERNEL_MORE(kVScal, mkl) +USE_JITKERNEL_MORE(kVExp, mkl) +USE_JITKERNEL_MORE(kVSigmoid, mkl) +USE_JITKERNEL_MORE(kVTanh, mkl) 
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5b088d4812b8a54e3b4fb1cb83d9e8bc7501994 --- /dev/null +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/more/mkl/mkl.h" +#include "paddle/fluid/operators/jit/refer/refer.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/dynload/mklml.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace mkl { + +template <> +void VMul(const float* x, const float* y, float* z, int n) { + platform::dynload::vsMul(n, x, y, z); +} + +template <> +void VMul(const double* x, const double* y, double* z, int n) { + platform::dynload::vdMul(n, x, y, z); +} + +template <> +void VAdd(const float* x, const float* y, float* z, int n) { + platform::dynload::vsAdd(n, x, y, z); +} + +template <> +void VAdd(const double* x, const double* y, double* z, int n) { + platform::dynload::vdAdd(n, x, y, z); +} + +template <> +void VScal(const float* a, const float* x, float* y, int n) { + if (x == y) { + platform::dynload::cblas_sscal(n, *a, y, 1); + } else { + refer::VScal(a, x, y, n); + } +} + +template <> +void VScal(const double* a, 
const double* x, double* y, int n) { + if (x == y) { + platform::dynload::cblas_dscal(n, *a, y, 1); + } else { + refer::VScal(a, x, y, n); + } +} + +template <> +void VExp(const float* x, float* y, int n) { + platform::dynload::vsExp(n, x, y); +} + +template <> +void VExp(const double* x, double* y, int n) { + platform::dynload::vdExp(n, x, y); +} + +// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 +template <> +bool VMulKernel::UseMe(const int& d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +template <> +bool VAddKernel::UseMe(const int& d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +template <> +bool VScalKernel::UseMe(const int& d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +template <> +bool VExpKernel::UseMe(const int& d) const { + return d > 7; +} + +template <> +bool VSigmoidKernel::UseMe(const int& d) const { + return d > 7; +} + +template <> +bool VTanhKernel::UseMe(const int& d) const { + return d > 7; +} + +#define AWALYS_USE_ME_WITH_DOUBLE(func) \ + template <> \ + bool func##Kernel::UseMe(const int& d) const { \ + return true; \ + } + +AWALYS_USE_ME_WITH_DOUBLE(VMul); +AWALYS_USE_ME_WITH_DOUBLE(VAdd); +AWALYS_USE_ME_WITH_DOUBLE(VScal); +AWALYS_USE_ME_WITH_DOUBLE(VExp); +AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); +AWALYS_USE_ME_WITH_DOUBLE(VTanh); + +#undef AWALYS_USE_ME_WITH_DOUBLE +} // namespace mkl +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace mkl = paddle::operators::jit::more::mkl; + +#define REGISTER_MKL_KERNEL(key, func) \ + REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel, \ + mkl::func##Kernel) + +REGISTER_MKL_KERNEL(kVMul, VMul); +REGISTER_MKL_KERNEL(kVAdd, VAdd); +REGISTER_MKL_KERNEL(kVScal, VScal); +REGISTER_MKL_KERNEL(kVExp, VExp); +REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); +REGISTER_MKL_KERNEL(kVTanh, VTanh); + +#undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h 
b/paddle/fluid/operators/jit/more/mkl/mkl.h new file mode 100644 index 0000000000000000000000000000000000000000..ee1031c028ff72181f504004b7cbeb9f7ee578f1 --- /dev/null +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/operators/jit/kernel_base.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace more { +namespace mkl { + +template +void VMul(const T* x, const T* y, T* z, int n); + +template +void VAdd(const T* x, const T* y, T* z, int n); + +template +void VScal(const T* a, const T* x, T* y, int n); + +template +void VExp(const T* x, T* y, int n); + +template +void VSigmoid(const T* x, T* y, int n) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); + y[i] = static_cast(0) - y[i]; + } + VExp(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } +} + +template +void VTanh(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * x[i]; + } + VSigmoid(y, y, n); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(2) * y[i] - static_cast(1); + } +} + +#define DECLARE_MKL_KERNEL(name, tuples) \ + template \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename tuples::attr_type&) const override; \ + const char* ImplType() const override { return "MKL"; } \ + } + +// XYZN +DECLARE_MKL_KERNEL(VMul, XYZNTuples); +DECLARE_MKL_KERNEL(VAdd, XYZNTuples); + +// AXYN +DECLARE_MKL_KERNEL(VScal, AXYNTuples); + +// XYN +DECLARE_MKL_KERNEL(VExp, XYNTuples); +DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); +DECLARE_MKL_KERNEL(VTanh, XYNTuples); + +#undef DECLARE_MKL_KERNEL + +} // namespace mkl +} // namespace more +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..07497b732050a7299e224531db37eb56e60ef605 --- /dev/null +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -0,0 +1,28 @@ + +cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base) +set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE) + +function(USE_JITKERNEL_REFER TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n") +endfunction() + +# use refer kernel by name +USE_JITKERNEL_REFER(kVMul) +USE_JITKERNEL_REFER(kVAdd) +USE_JITKERNEL_REFER(kVAddRelu) +USE_JITKERNEL_REFER(kVSub) +USE_JITKERNEL_REFER(kVScal) +USE_JITKERNEL_REFER(kVAddBias) +USE_JITKERNEL_REFER(kVRelu) +USE_JITKERNEL_REFER(kVIdentity) +USE_JITKERNEL_REFER(kVExp) +USE_JITKERNEL_REFER(kVSigmoid) 
+USE_JITKERNEL_REFER(kVTanh) +USE_JITKERNEL_REFER(kLSTMCtHt) +USE_JITKERNEL_REFER(kLSTMC1H1) +USE_JITKERNEL_REFER(kGRUH1) +USE_JITKERNEL_REFER(kGRUHtPart1) +USE_JITKERNEL_REFER(kGRUHtPart2) +USE_JITKERNEL_REFER(kCRFDecoding) +USE_JITKERNEL_REFER(kLayerNorm) +USE_JITKERNEL_REFER(kNCHW16CMulNC) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc new file mode 100644 index 0000000000000000000000000000000000000000..d196266326b4ee668f647fa51032f6344d26e5c6 --- /dev/null +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/jit/refer/refer.h" +#include "paddle/fluid/operators/jit/registry.h" + +namespace refer = paddle::operators::jit::refer; + +#define REGISTER_REFER_KERNEL(key, func) \ + REGISTER_JITKERNEL_REFER(key, refer::func##Kernel, \ + refer::func##Kernel) + +REGISTER_REFER_KERNEL(kVMul, VMul); +REGISTER_REFER_KERNEL(kVAdd, VAdd); +REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu); +REGISTER_REFER_KERNEL(kVSub, VSub); + +REGISTER_REFER_KERNEL(kVScal, VScal); +REGISTER_REFER_KERNEL(kVAddBias, VAddBias); + +REGISTER_REFER_KERNEL(kVRelu, VRelu); +REGISTER_REFER_KERNEL(kVIdentity, VIdentity); +REGISTER_REFER_KERNEL(kVExp, VExp); +REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); +REGISTER_REFER_KERNEL(kVTanh, VTanh); + +REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt); +REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1); + +REGISTER_REFER_KERNEL(kGRUH1, GRUH1); +REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1); +REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2); + +REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding); +REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); + +REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); + +#undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/jit/refer/refer.h similarity index 54% rename from paddle/fluid/operators/math/jit_kernel_refer.h rename to paddle/fluid/operators/jit/refer/refer.h index e0b2e3c7fada6b422318c68a42fd6d103c99af5a..0fd1b89dfdba9f4655f649fa6d32604188c78da3 100644 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -1,30 +1,31 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once + #include -#include -#include "paddle/fluid/operators/math/jit_kernel_impl.h" +#include +#include "paddle/fluid/operators/jit/helper.h" +#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { -namespace math { -namespace jitkernel { +namespace jit { namespace refer { -/* Refer code only focus on correctness */ +// Refer code only focus on correctness template void VMul(const T* x, const T* y, T* z, int n) { for (int i = 0; i < n; ++i) { @@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) { } } +template +void VSub(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] - y[i]; + } +} + template void VScal(const T* a, const T* x, T* y, int n) { for (int i = 0; i < n; ++i) { @@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) { } template -inline void VIdentity(const T* x, T* y, int n) {} +inline void VIdentity(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + 
y[i] = x[i]; + } +} template void VExp(const T* x, T* y, int n) { @@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) { } template -void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT - if (type == "sigmoid") { +void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT + if (type == kVSigmoid) { return VSigmoid; - } else if (type == "relu") { + } else if (type == kVRelu) { return VRelu; - } else if (type == "tanh") { + } else if (type == kVTanh) { return VTanh; - } else if (type == "identity" || type == "") { + } else if (type == kVIdentity) { return VIdentity; } PADDLE_THROW("Not support type: %s", type); return nullptr; } +// TODO(TJ): add refer gemm and make LSTM kernels combine as same GRU kernels + // compute ct and ht template void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { @@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { } } +template +void CRFDecoding(const int seq_len, const T* x, const T* w, T* alpha, + int* track, int right) { + constexpr int state_trans_base_idx = 2; + for (int i = 0; i < right; ++i) { + alpha[i] = w[i] + x[i]; + } + for (int k = 1; k < seq_len; ++k) { + for (int i = 0; i < right; ++i) { + T max_score = -std::numeric_limits::max(); + int max_j = 0; + for (int j = 0; j < right; ++j) { + T score = alpha[(k - 1) * right + j] + + w[(j + state_trans_base_idx) * right + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + alpha[k * right + i] = max_score + x[k * right + i]; + track[k * right + i] = max_j; + } + } +} + +template +void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias, + int height, const float epsilon, int right) { + // get mean + for (int i = 0; i < height; i++) { + T sum = 0.0; + int offset = i * right; + for (int j = 0; j < right; j++) { + sum += x[offset + j]; + } + mean[i] = sum / right; + } + + // get variance + for (int i = 0; i < height; i++) { + T sum = 0.0; + int offset = i * right; + for (int j = 0; j 
< right; j++) { + sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]); + } + var[i] = sum / right; + } + + for (int i = 0; i < height; i++) { + int offset = i * right; + T sqrt_var = std::sqrt(var[i] + (T)epsilon); + for (int j = 0; j < right; j++) { + out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var; + } + } + if (scale) { + for (int i = 0; i < height; i++) { + int offset = i * right; + for (int j = 0; j < right; j++) { + out[offset + j] *= scale[j]; + } + } + } + + if (bias) { + for (int i = 0; i < height; i++) { + int offset = i * right; + for (int j = 0; j < right; j++) { + out[offset + j] += bias[j]; + } + } + } +} + +template +void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { + int offset = 0; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + for (int i = 0; i < 16; ++i) { + z[i + offset] = y[i] * x[i + offset]; + } + offset += ZMM_FLOAT_BLOCK; + } + } +} + +#define DECLARE_REFER_KERNEL(name, tuples) \ + template \ + class name##Kernel : public ReferKernel> { \ + public: \ + name##Kernel() { this->func = name; } \ + } + +// const T* x, const T* y, T* z, int n +DECLARE_REFER_KERNEL(VMul, XYZNTuples); +DECLARE_REFER_KERNEL(VAdd, XYZNTuples); +DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples); +DECLARE_REFER_KERNEL(VSub, XYZNTuples); + +// const T* a, const T* x, T* y, int n +DECLARE_REFER_KERNEL(VScal, AXYNTuples); +DECLARE_REFER_KERNEL(VAddBias, AXYNTuples); + +// const T* x, T* y, int n +DECLARE_REFER_KERNEL(VRelu, XYNTuples); +DECLARE_REFER_KERNEL(VIdentity, XYNTuples); +DECLARE_REFER_KERNEL(VExp, XYNTuples); +DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); +DECLARE_REFER_KERNEL(VTanh, XYNTuples); + +// lstm_t*, const lstm_attr_t* +DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); +DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); + +// gru_t*, const gru_attr_t* +DECLARE_REFER_KERNEL(GRUH1, GRUTuples); +DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples); +DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); + 
+DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); +DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); + +DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); + +#undef DECLARE_REFER_KERNEL + } // namespace refer -} // namespace jitkernel -} // namespace math +} // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/registry.h b/paddle/fluid/operators/jit/registry.h new file mode 100644 index 0000000000000000000000000000000000000000..cb32c487208fe8fe9e72c069db8833c736316aec --- /dev/null +++ b/paddle/fluid/operators/jit/registry.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/operators/jit/kernel_base.h" +#include "paddle/fluid/operators/jit/kernel_pool.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/variant.h" // for UNUSED + +namespace paddle { +namespace operators { +namespace jit { + +// make_unique is supported since c++14 +template +inline std::unique_ptr make_unique(Args&&... 
args) { + static_assert(!std::is_array::value, "T must not be array"); + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +struct JitKernelRegistrarFunctor; + +template +struct JitKernelRegistrarFunctor { + void operator()(KernelType kt) const {} +}; + +template +struct JitKernelRegistrarFunctor { + using KERNEL_IMPL_TYPE = + typename std::tuple_element>::type; + + void operator()(KernelType kt) const { + KernelKey kkey(kt, PlaceType()); + Pool().Instance().Insert(kkey, + std::move(make_unique())); + constexpr auto size = std::tuple_size>::value; + JitKernelRegistrarFunctor + func; + func(kt); + } +}; + +template +class JitKernelRegistrar { + public: + explicit JitKernelRegistrar(KernelType kt) { + JitKernelRegistrarFunctor func; + func(kt); + } + void Touch() {} +}; + +#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +// Refer always on CPUPlace +#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_##kernel_type##_refer_CPUPlace, \ + "REGISTER_KERNEL_REFER must be called in global namespace"); \ + static ::paddle::operators::jit::JitKernelRegistrar< \ + ::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \ + __VA_ARGS__> \ + __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \ + ::paddle::operators::jit::KernelType::kernel_type); \ + int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \ + __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \ + return 0; \ + } + +// kernel_type: should be in paddle::operators::jit::KernelType +// place_type: should be one of CPUPlace and GPUPlace in paddle::platform +#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) 
\ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \ + "REGISTER_KERNEL_MORE must be called in global namespace"); \ + extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \ + UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static ::paddle::operators::jit::JitKernelRegistrar< \ + ::paddle::operators::jit::KernelPool, ::paddle::platform::place_type, \ + __VA_ARGS__> \ + __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \ + ::paddle::operators::jit::KernelType::kernel_type); \ + int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \ + __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \ + .Touch(); \ + return 0; \ + } + +#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \ + REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__) + +#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ + REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) + +#define REGISTER_JITKERNEL_GEN(kernel_type, ...) 
\ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ + "REGISTER_JITKERNEL_GEN must be called in global namespace"); \ + extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \ + TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static ::paddle::operators::jit::JitKernelRegistrar< \ + ::paddle::operators::jit::JitCodeCreatorPool, \ + ::paddle::platform::CPUPlace, __VA_ARGS__> \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \ + ::paddle::operators::jit::KernelType::kernel_type); \ + int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \ + return 0; \ + } + +#define USE_JITKERNEL_GEN(kernel_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ + "USE_JITKERNEL_GEN must be called in global namespace"); \ + extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \ + static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \ + TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() + +#define USE_JITKERNEL_REFER(kernel_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_##kernel_type##_refer_CPUPlace_, \ + "USE_JITKERNEL_REFER must be called in global namespace"); \ + extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \ + TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() + +#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ + __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \ + "USE_JITKERNEL_MORE must be called in global namespace"); \ + extern int \ + TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \ + static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \ + UNUSED = \ + 
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() + +#define USE_JITKERNEL_MORE(kernel_type, impl_type) \ + USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace) + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a73e2a60aeb0c1594b5072b2bffbd11cccfcdc7d --- /dev/null +++ b/paddle/fluid/operators/jit/test.cc @@ -0,0 +1,584 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/place.h" + +template +void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), + const T upper = static_cast(20.f)) { + static unsigned int seed = 100; + std::mt19937 rng(seed++); + std::uniform_real_distribution uniform_dist(0, 1); + for (int i = 0; i < n; ++i) { + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); + } +} + +template +void ExpectEQ(const T* target, const T* refer, int n) { + if (std::is_floating_point::value) { + for (int i = 0; i < n; ++i) { + EXPECT_NEAR(target[i], refer[i], 1e-5); + } + } else { + for (int i = 0; i < n; ++i) { + EXPECT_EQ(target[i], refer[i]); + } + } +} + +std::vector TestSizes() { + std::vector s; + for (int i = 1; i < 32; ++i) { + s.push_back(i); + } + // test some large size + s.push_back(100); + s.push_back(1000); + s.push_back(2000); + return s; +} + +namespace jit = paddle::operators::jit; + +template +struct TestFuncWithRefer { + void operator()(const typename KernelTuples::func_type tgt, Args... 
args) {} +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector> { + void operator()(const typename jit::XYZNTuples::func_type tgt, + const std::vector& x, const std::vector& y, + const std::vector& zref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(zref.size(), x.size()); + EXPECT_EQ(zref.size(), y.size()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* zref_data = zref.data(); + const int d = zref.size(); + + std::vector ztgt(d); + T* ztgt_data = ztgt.data(); + // test normal + tgt(x_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ztgt.begin()); + tgt(ztgt_data, y_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + // test inplace y + std::copy(y.begin(), y.end(), ztgt.begin()); + tgt(x_data, ztgt_data, ztgt_data, d); + ExpectEQ(ztgt_data, zref_data, d); + } +}; + +template +struct TestFuncWithRefer, T, std::vector, + std::vector> { + void operator()(const typename jit::AXYNTuples::func_type tgt, const T a, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(&a, x_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(&a, ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector> { + void operator()(const typename jit::XYNTuples::func_type tgt, + const std::vector& x, const std::vector& yref) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(yref.size(), x.size()); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + const int d = yref.size(); + std::vector ytgt(d); + T* ytgt_data = ytgt.data(); + // test normal + tgt(x_data, ytgt_data, d); + 
ExpectEQ(ytgt_data, yref_data, d); + // test inplace x + std::copy(x.begin(), x.end(), ytgt.begin()); + tgt(ytgt_data, ytgt_data, d); + ExpectEQ(ytgt_data, yref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector, std::vector, std::vector> { + void operator()(const typename jit::LSTMTuples::func_type tgt, + const std::vector& xsrc, const std::vector& wp, + const std::vector& ct_1, const std::vector& ct_ref, + const std::vector& ht_ref, + const typename jit::LSTMTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ct_ref.size(), ht_ref.size()); + EXPECT_EQ(ct_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 4 * ht_ref.size()); + EXPECT_EQ(wp.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size()); + std::vector checked(2 * d); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + const T* ct_ref_data = ct_ref.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ct_data = ct.data(); + T* ht_data = ht.data(); + T* checked_data = checked.data(); + + paddle::operators::jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_data; + step.ht = ht_data; + if (attr.use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + + tgt(&step, &attr); + ExpectEQ(ct_data, ct_ref_data, d); + ExpectEQ(ht_data, ht_ref_data, d); + } +}; + +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector> { + void operator()(const typename jit::GRUTuples::func_type tgt, + const std::vector& xsrc, const std::vector& ht_1, + const std::vector& ht_ref, + const typename jit::GRUTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(ht_1.size(), ht_ref.size()); + EXPECT_EQ(xsrc.size(), 3 * ht_ref.size()); + + // x could be changed after compute, so 
copy to save src + int d = ht_ref.size(); + std::vector x(xsrc.size()), ht(ht_ref.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ht_1_data = ht_1.data(); + const T* ht_ref_data = ht_ref.data(); + T* x_data = x.data(); + T* ht_data = ht.data(); + paddle::operators::jit::gru_t step; + step.gates = x_data; + step.ht_1 = ht_1_data; + step.ht = ht_data; + tgt(&step, &attr); + ExpectEQ(ht_data, ht_ref_data, d); + } +}; + +template +void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { + TestFuncWithRefer test; + // test jitcode + auto jitcode = jit::GetJitCode(attr); + if (jitcode) { + VLOG(10) << "Test Jitcode Kernel "; + test(jitcode, args...); + } + // test all impls in more + jit::KernelKey kkey(KT, PlaceType()); + auto& pool = jit::KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast*>(impl.get()); + if (i && i->UseMe(attr)) { + auto more = i->GetFunc(); + VLOG(10) << "Test More Kernel : " << i->ImplType(); + test(more, args...); + } + } + } + // test result from Get function + // VLOG(10) << "Test Get function "; + auto tgt = jit::Get(attr); + test(tgt, args...); +} + +template +void TestXYZNKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + for (int d : TestSizes()) { + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + + std::vector x(d), y(d), zref(d); + RandomVec(d, x.data()); + RandomVec(d, y.data()); + + std::vector xinp(d), yinp(d); // inplace test + std::copy(x.begin(), x.end(), xinp.begin()); + std::copy(y.begin(), y.end(), yinp.begin()); + + const T* x_data = x.data(); + const T* y_data = y.data(); + T* zref_data = zref.data(); + T* xinp_data = xinp.data(); + T* yinp_data = yinp.data(); + + // test refer code inplace + ref(x_data, y_data, zref_data, d); + ref(x_data, yinp_data, yinp_data, d); + 
ref(xinp_data, y_data, xinp_data, d); + ExpectEQ(xinp_data, zref_data, d); + ExpectEQ(yinp_data, zref_data, d); + + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(d, x, y, zref); + } +} + +template +void TestAXYNKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + for (int d : TestSizes()) { + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + + const T a = static_cast(3); + std::vector x(d), yref(d); + std::vector xinp(d); // inplace test + RandomVec(d, x.data()); + std::copy(x.begin(), x.end(), xinp.begin()); + + const T* x_data = x.data(); + T* yref_data = yref.data(); + T* xinp_data = xinp.data(); + // test refer code inplace + ref(&a, x_data, yref_data, d); + ref(&a, xinp_data, xinp_data, d); + ExpectEQ(xinp_data, yref_data, d); + + TestAllImpls, PlaceType, T, std::vector, + std::vector>(d, a, x, yref); + } +} + +template +void TestXYNKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + for (int d : TestSizes()) { + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + + std::vector x(d), yref(d); + std::vector xinp(d); // inplace test + RandomVec(d, x.data(), -2.f, 2.f); + std::copy(x.begin(), x.end(), xinp.begin()); + + const T* x_data = x.data(); + T* yref_data = yref.data(); + T* xinp_data = xinp.data(); + // test refer code inplace + ref(x_data, yref_data, d); + ref(xinp_data, xinp_data, d); + ExpectEQ(xinp_data, yref_data, d); + + TestAllImpls, PlaceType, std::vector, + std::vector>(d, x, yref); + } +} + +template +void TestLSTMKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; + for (int d : TestSizes()) { + for (bool use_peephole : {true, false}) { + for (auto& act_gate : all_acts) { + for (auto& act_cand : all_acts) { + for (auto& act_cell : all_acts) { + const 
jit::lstm_attr_t attr( + d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand), + jit::to_kerneltype(act_cell), use_peephole); + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector xsrc(4 * d), wp(3 * d), ct_1(d); + std::vector ct_ref(d), ht_ref(d), checked(2 * d); + RandomVec(4 * d, xsrc.data(), -2.f, 2.f); + RandomVec(3 * d, wp.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + // x could be changed after compute, so copy to save src + std::vector x(xsrc.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ct_1_data = ct_1.data(); + const T* wp_data = wp.data(); + T* x_data = x.data(); + T* checked_data = checked.data(); + T* ct_ref_data = ct_ref.data(); + T* ht_ref_data = ht_ref.data(); + jit::lstm_t step; + step.gates = x_data; + step.ct_1 = ct_1_data; + step.ct = ct_ref_data; + step.ht = ht_ref_data; + if (use_peephole) { + step.wp = wp_data; + step.checked = checked_data; + } + ref(&step, &attr); + VLOG(10) << attr; + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector, std::vector, + std::vector>(attr, xsrc, wp, ct_1, ct_ref, ht_ref, + attr); + } + } + } + } + } +} + +template +void TestGRUKernel() { + namespace jit = paddle::operators::jit; + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; + for (int d : TestSizes()) { + for (auto& act_gate : all_acts) { + for (auto& act_cand : all_acts) { + const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), + jit::to_kerneltype(act_cand)); + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector xsrc(3 * d), ht_1(d), ht_ref(d); + RandomVec(3 * d, xsrc.data(), -2.f, 2.f); + RandomVec(d, ht_1.data(), -2.f, 2.f); + // x could be changed after compute, so copy to save src + std::vector x(xsrc.size()); + std::copy(xsrc.begin(), xsrc.end(), x.begin()); + const T* ht_1_data = ht_1.data(); + T* x_data = x.data(); + T* ht_ref_data = ht_ref.data(); + 
jit::gru_t step; + step.gates = x_data; + step.ht_1 = ht_1_data; + step.ht = ht_ref_data; + ref(&step, &attr); + VLOG(10) << attr; + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(attr, xsrc, ht_1, ht_ref, + attr); + } + } + } +} + +template +void TestNCHW16CMulNCKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + const int n = 3, c = 16 * 4, h = 10, w = 10; + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + int sz = n * c * h * w; + std::vector x(sz), y(n * c), zref(sz); + std::vector ztgt(sz), zjit(sz); + RandomVec(sz, x.data(), -2.f, 2.f); + RandomVec(n * c, y.data(), -2.f, 2.f); + + const T* x_data = x.data(); + const T* y_data = y.data(); + T* zref_data = zref.data(); + T* ztgt_data = ztgt.data(); + T* zjit_data = zjit.data(); + constexpr int simd_width = ZMM_FLOAT_BLOCK; + int C = c / simd_width; + auto tgt = jit::Get, PlaceType>(0); + auto jitcode = jit::GetJitCode, PlaceType>(0); + EXPECT_TRUE(tgt != nullptr); + + if (std::is_same::value && + paddle::platform::MayIUse(paddle::platform::avx512f)) { + EXPECT_TRUE(jitcode != nullptr); + } + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_zref = + zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_ztgt = + ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + ref(ptr_x, ptr_y, ptr_zref, h, w); + tgt(ptr_x, ptr_y, ptr_ztgt, h, w); + + if (jitcode) { + auto ptr_zjit = + zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + jitcode(ptr_x, ptr_y, ptr_zjit, h, w); + } + } + } + ExpectEQ(ztgt_data, zref_data, sz); + if (jitcode) { + ExpectEQ(zjit_data, zref_data, sz); + } +} + +// XYZNTuple +TEST(JITKernel, kVMul) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, kVAdd) { + 
namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, kVAddRelu) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +TEST(JITKernel, kVSub) { + namespace jit = paddle::operators::jit; + TestXYZNKernel(); + TestXYZNKernel(); +} + +// AXYNTuples +TEST(JITKernel, kVScal) { + namespace jit = paddle::operators::jit; + TestAXYNKernel(); + TestAXYNKernel(); +} + +TEST(JITKernel, kVAddBias) { + namespace jit = paddle::operators::jit; + TestAXYNKernel(); + TestAXYNKernel(); +} + +// XYNTuples +TEST(JITKernel, kVRelu) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, kVIdentity) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, kVExp) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, kVSigmoid) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +TEST(JITKernel, kVTanh) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + +// LSTM +TEST(JITKernel, kLSTMCtHt) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +TEST(JITKernel, kLSTMC1H1) { + namespace jit = paddle::operators::jit; + TestLSTMKernel(); + TestLSTMKernel(); +} + +// GRU +TEST(JITKernel, kGRUH1) { + namespace jit = paddle::operators::jit; + TestGRUKernel(); + TestGRUKernel(); +} + +TEST(JITKernel, kGRUHtPart1) { + namespace jit = paddle::operators::jit; + TestGRUKernel(); + TestGRUKernel(); +} + +TEST(JITKernel, kGRUHtPart2) { + namespace jit = paddle::operators::jit; + TestGRUKernel(); + TestGRUKernel(); +} + +TEST(JITKernel, kNCHW16CMulNC) { + namespace jit = paddle::operators::jit; + TestNCHW16CMulNCKernel(); + TestNCHW16CMulNCKernel(); +} + +// TODO(yihua/TJ): add crf decoding and layer norm unit tests + +TEST(JITKernel, pool) { + // TODO(TJ): add some test +} 
diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 78d20ddf5fd63b81fd5e7fba656d825897a67a11..f564a103963bd93732165596712230b0f37f7f26 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ !defined(__OSX__) -#include "paddle/fluid/operators/math/jit_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" #endif #include "paddle/fluid/operators/math/math_function.h" @@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right); - const auto& ker = math::jitkernel::KernelPool::Instance() - .template Get>( - static_cast(right)); - ker->Compute(x.data(), out.data(), mean->data(), var->data(), - scale->data(), bias->data(), static_cast(left), - static_cast(epsilon)); + auto ker = + jit::Get, platform::CPUPlace>( + right); + ker(x.data(), out.data(), mean->data(), var->data(), + scale->data(), bias->data(), static_cast(left), + static_cast(epsilon), right); #endif } }; diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index b3d2ea38eb1bfffadc1f68c5a34bc4d557bdea3b..ea6aebd291eee580a307aa112117434fa942005e 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -73,12 +73,3 @@ if(WITH_GPU) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) - -set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc) -set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce) -if(WITH_XBYAK) - list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc) - list(APPEND JIT_KERNEL_DEPS 
xbyak) -endif() -cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS}) -cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 5b9953a5aa9a29bd917d16a16c678fc32a32c18f..cddd0a18db53a7ddf9ca14d5f373180586ef6a31 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/jit_kernel.h" namespace paddle { namespace operators { @@ -30,22 +30,21 @@ inline void FCCompute(const BlasT& blas, const int M, return; } if (relu) { - const auto& vaddrelu = jitkernel::KernelPool::Instance() - .template Get>(N); + auto compute = + jit::Get, platform::CPUPlace>(N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vaddrelu->Compute(B, dst, dst, N); + compute(B, dst, dst, N); } } else { - const auto& vadd = jitkernel::KernelPool::Instance() - .template Get>(N); - + auto compute = + jit::Get, platform::CPUPlace>(N); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vadd->Compute(B, dst, dst, N); + compute(B, dst, dst, N); } } } diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc deleted file mode 100644 index 2b08c1059713fb9acd0cfdcf39ac2ad283172724..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_code.cc +++ /dev/null @@ -1,334 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_code.h" -#include // offsetof -#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -using namespace platform; // NOLINT - -bool VXXJitCode::init(int d, int scalar_index) { - // It's not necessary to use avx512 since it would slow down the frequency - // and this kernel is not compute bound. - return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2; -} - -void VXXJitCode::generate() { - // do not need push stack, and do not need save avx512reg if do not use avx512 - int offset = 0; - if (with_relu_) { - vxorps(ymm_zero, ymm_zero, ymm_zero); - } - if (scalar_index_ == 1) { - vbroadcastss(ymm_src1, ptr[param1]); - } else if (scalar_index_ == 2) { - vbroadcastss(ymm_src2, ptr[param2]); - } - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - if (scalar_index_ != 1) { - vmovups(ymm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(ymm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulps(ymm_dst, ymm_src1, ymm_src2); - } else if (type_ == operand_type::add) { - vaddps(ymm_dst, ymm_src1, ymm_src2); - } - if (with_relu_) { - vmaxps(ymm_dst, ymm_zero, ymm_dst); - } - vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - int rest = num_ % YMM_FLOAT_BLOCK; - while (rest > 0) { - int block = XMM_FLOAT_BLOCK; - if (rest >= 4) { - block = 4; - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); - } - if 
(scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); - } - } else if (rest >= 2) { - block = 2; - if (scalar_index_ != 1) { - vmovq(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovq(xmm_src2, ptr[param2 + offset]); - } - } else { - block = 1; - if (scalar_index_ != 1) { - vmovss(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovss(xmm_src2, ptr[param2 + offset]); - } - } - switch (type_) { - case operand_type::mul: - vmulps(xmm_dst, xmm_src1, xmm_src2); - break; - case operand_type::add: - vaddps(xmm_dst, xmm_src1, xmm_src2); - break; - default: - break; - } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - if (rest >= 4) { - vmovups(ptr[param3 + offset], xmm_dst); - } else if (rest >= 2) { - vmovq(ptr[param3 + offset], xmm_dst); - } else { - vmovss(ptr[param3 + offset], xmm_dst); - } - offset += sizeof(float) * block; - rest -= block; - } - ret(); -} - -const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = { - REPEAT_8TIMES(1.f), - REPEAT_8TIMES(2.f), - REPEAT_8TIMES(0.5f), - REPEAT_8TIMES(EXP_HIG), - REPEAT_8TIMES(EXP_LOW), - REPEAT_8TIMES(CEPHES_LOG2EF), - REPEAT_8TIMES(CEPHES_EXP_C1), - REPEAT_8TIMES(CEPHES_EXP_C2), - REPEAT_8TIMES(CEPHES_EXP_P0), - REPEAT_8TIMES(CEPHES_EXP_P1), - REPEAT_8TIMES(CEPHES_EXP_P2), - REPEAT_8TIMES(CEPHES_EXP_P3), - REPEAT_8TIMES(CEPHES_EXP_P4), - REPEAT_8TIMES(CEPHES_EXP_P5), - REPEAT_8TIMES(EXP_MAX_INPUT), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; - -const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)}; -int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0}; - -bool VActJitCode::init(int d, operand_type type) { - // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 - return MayIUse(avx); -} - -void VActJitCode::generate() { - int offset = 0; - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - vmovups(ymm_src, ptr[param1 + offset]); - act(ymm_dst, ymm_src, type_); - vmovups(ptr[param2 + 
offset], ymm_dst); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - int rest = num_ % YMM_FLOAT_BLOCK; - while (rest > 0) { - int block = XMM_FLOAT_BLOCK; - if (rest >= 4) { - block = 4; - vmovups(xmm_src, ptr[param1 + offset]); - } else if (rest >= 2) { - block = 2; - vmovq(xmm_src, ptr[param1 + offset]); - } else { - block = 1; - vmovss(xmm_src, ptr[param1 + offset]); - } - act(xmm_dst, xmm_src, type_); - if (rest >= 4) { - vmovups(ptr[param2 + offset], xmm_dst); - } else if (rest >= 2) { - vmovq(ptr[param2 + offset], xmm_dst); - } else { - vmovss(ptr[param2 + offset], xmm_dst); - } - offset += sizeof(float) * block; - rest -= block; - } - ret(); -} - -bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } - -void LSTMJitCode::generate() { - if (use_peephole_) { - preCode(); - } - reg64_t reg_ptr_gates = rax; - reg64_t reg_ptr_ct_1 = r9; - reg64_t reg_ptr_ct = r10; - reg64_t reg_ptr_ht = r11; - reg64_t reg_ptr_wp = r12; - mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); - mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); - mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); - mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); - if (use_peephole_) { - mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]); - } - - int offset = 0; - int d = num_ * sizeof(float); - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - /* gates: W_ch, W_ih, W_fh, W_oh */ - ymm_t ymm_c = ymm_t(0); - ymm_t ymm_i = ymm_t(1); - ymm_t ymm_f = ymm_t(2); - ymm_t ymm_o = ymm_t(3); - ymm_t ymm_ct_1 = ymm_t(4); - ymm_t ymm_wp0 = ymm_t(5); - ymm_t ymm_wp1 = ymm_t(6); - ymm_t ymm_wp2 = ymm_t(7); - vmovups(ymm_c, ptr[reg_ptr_gates + offset]); - vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]); - vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]); - vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]); - if (!compute_c1h1_) { - vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]); - } - if (use_peephole_) { - vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]); - vmovups(ymm_wp1, 
ptr[reg_ptr_wp + offset + d]); - vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]); - } - /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */ - // act_cand(c) - act(ymm_c, ymm_c, act_cand_); - // act_gate(i) or act_gate(ct_1 * wp0 + i) - if (!compute_c1h1_ && use_peephole_) { - vmulps(ymm_wp0, ymm_ct_1, ymm_wp0); - vaddps(ymm_i, ymm_i, ymm_wp0); - } - act(ymm_i, ymm_i, act_gate_); - vmulps(ymm_c, ymm_c, ymm_i); - if (!compute_c1h1_) { - // act_gate(f) or act_gate(ct_1 * wp1 + f) - if (use_peephole_) { - vmulps(ymm_wp1, ymm_ct_1, ymm_wp1); - vaddps(ymm_f, ymm_f, ymm_wp1); - } - act(ymm_f, ymm_f, act_gate_); - // ct - vmulps(ymm_f, ymm_f, ymm_ct_1); - vaddps(ymm_f, ymm_f, ymm_c); - } - /* H_t = act_cell(C_t) * act_gate(o) */ - // act_cell(C_t) - ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; - ymm_t ymm_tmp = ymm_i; - act(ymm_tmp, ymm_ct, act_cell_); - // act_gate(o) or act_gate(ct * wp2 + o) - if (use_peephole_) { - vmulps(ymm_wp2, ymm_ct, ymm_wp2); - vaddps(ymm_o, ymm_o, ymm_wp2); - } - act(ymm_o, ymm_o, act_gate_); - // ht - vmulps(ymm_o, ymm_o, ymm_tmp); - // save ct and ht - vmovups(ptr[reg_ptr_ct + offset], ymm_ct); - vmovups(ptr[reg_ptr_ht + offset], ymm_o); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - - if (use_peephole_) { - postCode(); - } else { - ret(); - } -} - -bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } - -void GRUJitCode::generate() { - reg64_t reg_ptr_gates = rax; - reg64_t reg_ptr_ht_1 = r9; - reg64_t reg_ptr_ht = r10; - mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); - mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); - mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); - ymm_t ymm_one = ymm_t(0); - - if (id_ == 2) { - reg64_t reg_ptr_tmp = r11; - mov(reg_ptr_tmp, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); - } - int offset = 0; - int d = num_ * sizeof(float); - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - ymm_t ymm_u = ymm_t(1); - ymm_t 
ymm_r = ymm_t(2); - ymm_t ymm_s = ymm_t(3); - ymm_t ymm_ht_1 = ymm_t(4); - // W: {W_update, W_reset; W_state} - if (id_ == 0 || id_ == 2) { - vmovups(ymm_u, ptr[reg_ptr_gates + offset]); - vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); - } - if (id_ == 1) { - vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); - } - if (id_ == 1 || id_ == 2) { - vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); - } - - if (id_ == 0) { - // ht = act_gate(u) * act_cand(s) - act(ymm_u, ymm_u, act_gate_); - act(ymm_s, ymm_s, act_cand_); - vmulps(ymm_s, ymm_s, ymm_u); - vmovups(ptr[reg_ptr_ht + offset], ymm_s); - } else if (id_ == 1) { - // ht = act_gate(r) * ht_1 - act(ymm_r, ymm_r, act_gate_); - vmulps(ymm_r, ymm_r, ymm_ht_1); - vmovups(ptr[reg_ptr_ht + offset], ymm_r); - } else if (id_ == 2) { - // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 - ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); - act(ymm_u, ymm_u, act_gate_); - act(ymm_s, ymm_s, act_cand_); - vmulps(ymm_s, ymm_s, ymm_u); - vsubps(ymm_u, ymm_one_inner, ymm_u); - vmulps(ymm_u, ymm_ht_1, ymm_u); - vaddps(ymm_u, ymm_s, ymm_u); - vmovups(ptr[reg_ptr_ht + offset], ymm_u); - } - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - - ret(); -} -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_gen.cc b/paddle/fluid/operators/math/jit_gen.cc deleted file mode 100644 index 5c6672928e8c03ccb1920bd828f785084e422fc2..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_gen.cc +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_gen.h" -#include -#include -#include -#include "paddle/fluid/platform/cpu_info.h" - -DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -constexpr Xbyak::Operand::Code g_abi_regs[] = { - Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, - Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15}; - -constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]); - -void JitCode::preCode() { - for (int i = 0; i < num_g_abi_regs; ++i) { - push(Xbyak::Reg64(g_abi_regs[i])); - } - if (platform::MayIUse(platform::avx512f)) { - mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); - } -} - -void JitCode::postCode() { - for (int i = 0; i < num_g_abi_regs; ++i) { - pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i])); - } - ret(); -} - -void JitCode::dumpCode(const Xbyak::uint8 *code) const { - if (code) { - static int counter = 0; - std::ostringstream filename; - filename << "paddle_jitcode_" << name() << "." 
<< counter << ".bin"; - counter++; - std::ofstream fout(filename.str(), std::ios::out); - if (fout.is_open()) { - fout.write(reinterpret_cast(code), getSize()); - fout.close(); - } - } -} - -Xbyak::Address JitCode::EVEX_compress_addr(Xbyak::Reg64 base, int offt, - bool bcast) { - int scale = 0; - if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { - offt = offt - 2 * EVEX_max_8b_offt; - scale = 1; - } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { - offt = offt - 4 * EVEX_max_8b_offt; - scale = 2; - } - auto re = Xbyak::RegExp() + base + offt; - if (scale) { - re = re + reg_EVEX_max_8b_offt * scale; - } - if (bcast) { - return zword_b[re]; - } else { - return zword[re]; - } -} - -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_gen.h b/paddle/fluid/operators/math/jit_gen.h deleted file mode 100644 index 6abf3434cc8d8f6ab2838ef822a4f6b948331802..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_gen.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include "paddle/fluid/platform/macros.h" - -#define XBYAK_USE_MMAP_ALLOCATOR -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -DECLARE_bool(dump_jitcode); - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -#define DECLARE_JIT_CODE(codename) \ - const char *name() const override { return #codename; } - -// Application Binary Interface -constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), - abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), - abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX); - -class JitCode : public Xbyak::CodeGenerator { - public: - explicit JitCode(size_t code_size = 256 * 1024, void *code_ptr = nullptr) - : Xbyak::CodeGenerator(code_size, code_ptr) {} - - virtual ~JitCode() {} - virtual const char *name() const = 0; - virtual void generate() = 0; - - template - const FUNC getCode() { - this->generate(); - const Xbyak::uint8 *code = CodeGenerator::getCode(); - if (FLAGS_dump_jitcode) { - this->dumpCode(code); - } - return reinterpret_cast(code); - } - DISABLE_COPY_AND_ASSIGN(JitCode); - - protected: - Xbyak::Reg64 param1{abi_param1}; - const int EVEX_max_8b_offt = 0x200; - const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; - - void preCode(); - void postCode(); - void dumpCode(const Xbyak::uint8 *code) const; - void L(const char *label) { Xbyak::CodeGenerator::L(label); } - void L(const Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); } - // Enhanced vector extension - Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt, - bool bcast = false); -}; - -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc deleted file mode 100644 index 118696ba47986e2dbf97535333c9817b7c264a54..0000000000000000000000000000000000000000 --- 
a/paddle/fluid/operators/math/jit_kernel.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -KernelPool& KernelPool::Instance() { - static thread_local KernelPool g_jit_kernels; - return g_jit_kernels; -} - -std::shared_ptr KernelPool::Get(const std::string& key) const { - if (kers_.find(key) == kers_.end()) { - return nullptr; - } - return kers_.at(key); -} - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h deleted file mode 100644 index b78b92b4f97b761654a5b9b178f96c1dc99f7789..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel.h +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include // for shared_ptr -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_impl.h" -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/macros.h" - -// Note: Only support on CPU yet. -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -// TODO(TJ): remove me -typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; - -class Kernel { - public: - Kernel() = default; - virtual ~Kernel() = default; - // TODO(TJ): below members should be deprecated. - int num_{0}; - int end_{0}; - int rest_{0}; - DISABLE_COPY_AND_ASSIGN(Kernel); -}; - -class KernelPool { - public: - static KernelPool &Instance(); - - template - std::shared_ptr Get(ARGS... args); - - std::shared_ptr Get(const std::string &key) const; - - private: - KernelPool() = default; - std::unordered_map> kers_; - - DISABLE_COPY_AND_ASSIGN(KernelPool); -}; - -template -class VMulKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddReluKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VScalKernel : public Kernel { - public: - // y = a.*x - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddBiasKernel : public Kernel { - public: - // y = a.+x - void (*Compute)(const T *, const T *, T *, int); -}; - -#ifdef PADDLE_WITH_MKLDNN -template -class EltwiseMulnChw16cNCKernel : public Kernel { - public: - // nChw16c = nChw16c .* NC - void (*Compute)(const float *, const float *, float *, int, int); -}; -#endif - -template -class VActKernel : public Kernel { - public: - void (*Compute)(const T *, T *, int); -}; - -template -class VReluKernel : public VActKernel 
{}; - -template -class VIdentityKernel : public VActKernel {}; - -template -class VExpKernel : public VActKernel {}; - -template -class VSigmoidKernel : public VActKernel {}; - -template -class VTanhKernel : public VActKernel {}; - -template -class LSTMKernel : public Kernel { - public: - // compute c1 and h1 without c0 or h0 - void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *); - void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *); -}; - -template -class GRUKernel : public Kernel { - public: - // compute h1 without h0 - void (*ComputeH1)(gru_t *, const gru_attr_t *); - void (*ComputeHtPart1)(gru_t *, const gru_attr_t *); - void (*ComputeHtPart2)(gru_t *, const gru_attr_t *); -}; - -template -class CRFDecodeKernel : public Kernel { - public: - virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha, - int *track) const = 0; -}; - -template -class LayerNormKernel : public Kernel { - public: - virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale, - const T *bias, int height, - const float epsilon) const = 0; -}; - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc deleted file mode 100644 index 8cf588efba52314650bfd376b95b10e6d4336b2e..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/enforce.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#ifdef PADDLE_WITH_MKLML -template -void VMulMKL(const T* x, const T* y, T* z, int n); - -template <> -void VMulMKL(const float* x, const float* y, float* z, int n) { - platform::dynload::vsMul(n, x, y, z); -} - -template <> -void VMulMKL(const double* x, const double* y, double* z, int n) { - platform::dynload::vdMul(n, x, y, z); -} - -template -void VAddMKL(const T* x, const T* y, T* z, int n); - -template <> -void VAddMKL(const float* x, const float* y, float* z, int n) { - platform::dynload::vsAdd(n, x, y, z); -} - -template <> -void VAddMKL(const double* x, const double* y, double* z, int n) { - platform::dynload::vdAdd(n, x, y, z); -} - -template -void VScalMKL(const T* a, const T* x, T* y, int n); - -template <> -void VScalMKL(const float* a, const float* x, float* y, int n) { - if (x == y) { - platform::dynload::cblas_sscal(n, *a, y, 1); - } else { - refer::VScal(a, x, y, n); - } -} - -template <> -void VScalMKL(const double* a, const double* x, double* y, int n) { - if (x == y) { - platform::dynload::cblas_dscal(n, *a, y, 1); - } else { - refer::VScal(a, x, y, n); - } -} - -#endif - -/* VMUL JitKernel */ -template -class VMulKernelImpl : public VMulKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VMulKernelImpl(int d) : VMulKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - // roughly estimate the size of code - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 
8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VMulMKL; - return; - } -#endif - this->Compute = refer::VMul; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VMulKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VMulKernelImpl::useMKL(int d) { - return platform::MayIUse(platform::avx512f) && d > 512; -} - -template <> -bool VMulKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VAdd JitKernel */ -template -class VAddKernelImpl : public VAddKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddKernelImpl(int d) : VAddKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, - sz > 4096 ? 
sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VAddMKL; - return; - } -#endif - this->Compute = refer::VAdd; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VAddKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VAddKernelImpl::useMKL(int d) { - return true; -} -#endif - -#ifdef PADDLE_WITH_MKLDNN -/* EltwiseMul for nChw16c & NC inputs JitKernel */ -template -class EltwiseMulnChw16cNCKernelImpl - : public math::jitkernel::EltwiseMulnChw16cNCKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit EltwiseMulnChw16cNCKernelImpl(int d) - : EltwiseMulnChw16cNCKernel() { - using mul_func_t = void (*)(const float*, const float*, float*, int, int); -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - // roughly estimate the size of code - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - sz = sz > 4096 ? sz : 4096; - jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz)); - this->Compute = (mul_func_t)jitcode_->getCode(); - return; - } -#endif - PADDLE_THROW( - "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN " - "environemnt"); - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -}; - -template <> -bool EltwiseMulnChw16cNCKernelImpl::useJIT(int d) { - return true; -} -#endif -#endif - -/* VAddRelu JitKernel */ -template -class VAddReluKernelImpl : public VAddReluKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddReluKernelImpl(int d) : VAddReluKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, - sz > 4096 ? 
sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif - this->Compute = refer::VAddRelu; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddReluKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -/* VScal JitKernel */ -template -class VScalKernelImpl : public VScalKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VScalKernelImpl(int d) : VScalKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VScalMKL; - return; - } -#endif - this->Compute = refer::VScal; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VScalKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d, 1); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VScalKernelImpl::useMKL(int d) { - return d > 512; -} -template <> -bool VScalKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VAddBias JitKernel */ -template -class VAddBiasKernelImpl : public VAddBiasKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, - sz > 4096 ? 
sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif - - this->Compute = refer::VAddBias; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddBiasKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d, 1); -} -#endif - -/* VRelu JitKernel */ -template -class VReluKernelImpl : public VReluKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VReluKernelImpl(int d) : VReluKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 /* init size */ + - d / YMM_FLOAT_BLOCK * 4 /* instructions */ * - 8 /* average bytes for each instruction */; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - - this->Compute = refer::VRelu; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VReluKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::relu); -} -#endif - -/* An empty JitKernel */ -template -class VIdentityKernelImpl : public VIdentityKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VIdentityKernelImpl(int d) : VIdentityKernel() { - this->Compute = refer::VIdentity; - } -}; - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); -REGISTER_JITKERNEL(vscal, VScalKernel); -REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); -REGISTER_JITKERNEL(vrelu, VReluKernel); -REGISTER_JITKERNEL(videntity, VIdentityKernel); -#ifdef PADDLE_WITH_MKLDNN -REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel); -#endif - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc deleted file 
mode 100644 index ac2d29f1c18392ebf917cc097e63670e06b1eded..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* CRF Decode JitKernel */ -template -class CRFDecodeKernelImpl : public CRFDecodeKernel { - public: - explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel() { - this->num_ = tag_num; - } - void Compute(const int seq_len, const T* x, const T* w, T* alpha, - int* track) const override { - constexpr int state_trans_base_idx = 2; - for (int i = 0; i < this->num_; ++i) { - alpha[i] = w[i] + x[i]; - } - for (int k = 1; k < seq_len; ++k) { - for (int i = 0; i < this->num_; ++i) { - T max_score = -std::numeric_limits::max(); - int max_j = 0; - for (int j = 0; j < this->num_; ++j) { - T score = alpha[(k - 1) * this->num_ + j] + - w[(j + state_trans_base_idx) * this->num_ + i]; - if (score > max_score) { - max_score = score; - max_j = j; - } - } - alpha[k * this->num_ + i] = max_score + x[k * this->num_ + i]; - track[k * this->num_ + i] = max_j; - } - } - } -}; - -#define INIT_ALPHA(step_size) \ - /* Setup the alpha initial value.*/ \ - int i_offset = 0; \ - int 
last_offset = this->rest_ - step_size; \ - for (int i = 0; i <= this->end_; ++i) { \ - /* weights, input and alpha values. */ \ - __m256 w_content, x_content, alpha_content; \ - /* Load the relevant data into the variables from un-aligned address.*/ \ - w_content = _mm256_loadu_ps(w + i_offset); \ - x_content = _mm256_loadu_ps(x + i_offset); \ - alpha_content = _mm256_add_ps(w_content, x_content); \ - _mm256_storeu_ps(alpha + i_offset, alpha_content); \ - i_offset += step_size; \ - if (i == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - i_offset += last_offset; \ - } else { \ - break; \ - } \ - } \ - } - -#define UPDATE_ALPHA(step_size) \ - /* Update the alpha and track values. */ \ - __m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \ - max_score = _mm256_add_ps(max_score, x_content); \ - _mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \ - _mm256_storeu_si256( \ - reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \ - max_j); \ - /* Calculate the offset of next step*/ \ - j_offset += step_size; \ - if (j == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - j_offset += last_offset; \ - } else { \ - break; \ - } \ - } - -#define INTRIAVX_FLOAT(block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ - int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(YMM_FLOAT_BLOCK) \ - /* Use the column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m256 max_score = 
_mm256_set1_ps(-std::numeric_limits::max()); \ - __m256i max_j = _mm256_set1_epi32(0); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ - __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ - __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ - /* According to the mask value, update the index of the max_score.*/ \ - /* AVX instructions.*/ \ - __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \ - __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \ - __m128i lo_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 0); \ - __m128i hi_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 1); \ - lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \ - hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \ - lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \ - hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \ - lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \ - hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \ - max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \ - max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \ - /* AVX done*/ \ - /* Update the max_score value.*/ \ - max_score = _mm256_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ - } \ - seq_offset += this->num_; \ - } \ - } - -#define INTRIAVX2_FLOAT(isa, block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl(int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, 
const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(YMM_FLOAT_BLOCK) \ - /* Use the column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); \ - __m256i max_j = _mm256_set1_epi32(0); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ - __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ - __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ - /* According to the mask value, update the index of the max_score.*/ \ - /* AVX2 instructions.*/ \ - max_j = _mm256_or_si256( \ - _mm256_andnot_si256((__m256i)mask, max_j), \ - _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \ - /* Update the max_score value.*/ \ - max_score = _mm256_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ - } \ - seq_offset += this->num_; \ - } \ - } - -#define INTRIAVX512_FLOAT(block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ - int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(ZMM_FLOAT_BLOCK) \ - /* Use the 
column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); \ - __m512i max_j = _mm512_setzero_si512(); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m512 w_content = _mm512_loadu_ps(w + trans_offset); \ - __m512 score_v = _mm512_add_ps(alpha_content, w_content); \ - __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \ - /* AVX512 instructions.*/ \ - max_j = _mm512_mask_set1_epi32(max_j, mask, i); \ - /* Update the max_score value.*/ \ - max_score = _mm512_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - /* Update the alpha and track values.*/ \ - __m512 x_content = \ - _mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \ - max_score = _mm512_add_ps(max_score, x_content); \ - _mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \ - max_score); \ - _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \ - this->num_ + j_offset), \ - max_j); \ - /* Calculate the offset of next step*/ \ - j_offset += ZMM_FLOAT_BLOCK; \ - if (j == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - j_offset += last_offset; \ - } else { \ - break; \ - } \ - } \ - } \ - seq_offset += this->num_; \ - } \ - } - -#ifdef __AVX__ -INTRIAVX_FLOAT(kEQ8); -INTRIAVX_FLOAT(kGT8LT16); -INTRIAVX_FLOAT(kEQ16); -INTRIAVX_FLOAT(kGT16); -#endif -#ifdef __AVX2__ -INTRIAVX2_FLOAT(platform::avx2, kEQ8); 
-INTRIAVX2_FLOAT(platform::avx2, kGT8LT16); -INTRIAVX2_FLOAT(platform::avx2, kEQ16); -INTRIAVX2_FLOAT(platform::avx2, kGT16); -#endif -#ifdef __AVX512F__ -INTRIAVX2_FLOAT(platform::avx512f, kEQ8); -INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16); -INTRIAVX512_FLOAT(kEQ16); -INTRIAVX512_FLOAT(kGT16); -#endif - -#undef INTRIAVX512_FLOAT -#undef INTRIAVX2_FLOAT -#undef INTRIAVX_FLOAT -#undef INIT_ALPHA -#undef UPDATE_ALPHA - -REGISTER_JITKERNEL_DEPRECATED(crf_decode, CRFDecodeKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc deleted file mode 100644 index 7945cfb253a61b7d1191c39537254126e2bb85dd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#ifdef PADDLE_WITH_MKLML -// try to use MKL to speedup -template -void VExpMKL(const T* x, T* y, int n); - -template <> -void VExpMKL(const float* x, float* y, int n) { - platform::dynload::vsExp(n, x, y); -} - -template <> -void VExpMKL(const double* x, double* y, int n) { - platform::dynload::vdExp(n, x, y); -} - -template -void VSigmoidMKL(const T* x, T* y, int n) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = static_cast(0) - y[i]; - } - VExpMKL(y, y, n); - for (int i = 0; i < n; ++i) { - y[i] = static_cast(1) / (static_cast(1) + y[i]); - } -} - -template -void VTanhMKL(const T* x, T* y, int n) { - for (int i = 0; i < n; ++i) { - y[i] = static_cast(2) * x[i]; - } - VSigmoidMKL(y, y, n); - for (int i = 0; i < n; ++i) { - y[i] = static_cast(2) * y[i] - static_cast(1); - } -} -#endif - -/* VExp JitKernel */ -template -class VExpKernelImpl : public VExpKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VExpKernelImpl(int d) : VExpKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, - sz > 4096 ? 
sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VExpMKL; - return; - } -#endif - this->Compute = refer::VExp; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VExpKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::exp); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VExpKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VExpKernelImpl::useMKL(int d) { - return true; -} - -#endif - -/* VSigmoid JitKernel */ -template -class VSigmoidKernelImpl : public VSigmoidKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - -#ifdef PADDLE_WITH_MKLML - // strictly it's a better impl with MKL, then is refer - if (useMKL(d)) { - this->Compute = VSigmoidMKL; - return; - } -#endif - this->Compute = refer::VSigmoid; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VSigmoidKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::sigmoid); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VSigmoidKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VSigmoidKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VTanh JitKernel */ -template -class VTanhKernelImpl : public VTanhKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VTanhKernelImpl(int d) : VTanhKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8; - jitcode_.reset(new 
gen::VActJitCode(d, gen::operand_type::tanh, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - -#ifdef PADDLE_WITH_MKLML - // strictly it's a better impl with MKL, then is refer - if (useMKL(d)) { - this->Compute = VTanhMKL; - return; - } -#endif - this->Compute = refer::VTanh; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VTanhKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::tanh); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VTanhKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VTanhKernelImpl::useMKL(int d) { - return true; -} -#endif - -REGISTER_JITKERNEL(vexp, VExpKernel); -REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); -REGISTER_JITKERNEL(vtanh, VTanhKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h deleted file mode 100644 index ba5f20e53383d3cafab4239f1a2d911addf1ae23..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_impl.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 -#define XMM_FLOAT_BLOCK 4 -#define YMM_FLOAT_BLOCK 8 -#define ZMM_FLOAT_BLOCK 16 - -typedef struct { - void* gates; // gates: W_ch, W_ih, W_fh, W_oh - const void* ct_1; - void* ct; - void* ht; - /* weight_peephole and checked data are only used in peephole*/ - const void* wp{nullptr}; - void* checked{nullptr}; -} lstm_t; - -typedef struct { - void* gates; // gates: {W_update, W_reset; W_state} - const void* ht_1; - void* ht; -} gru_t; - -struct rnn_attr_s { - int d; - std::string act_gate, act_cand; - rnn_attr_s() = default; - rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand) - : d(_d), act_gate(_act_gate), act_cand(_act_cand) {} -}; - -struct lstm_attr_s : public rnn_attr_s { - bool use_peephole; - std::string act_cell; - lstm_attr_s() = default; - lstm_attr_s(int _d, const std::string& _act_gate, - const std::string& _act_cand, const std::string& _act_cell, - bool _use_peephole = false) - : rnn_attr_s(_d, _act_gate, _act_cand), - use_peephole(_use_peephole), - act_cell(_act_cell) {} -}; - -typedef struct rnn_attr_s gru_attr_t; -typedef struct lstm_attr_s lstm_attr_t; - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc deleted file mode 100644 index e21092037a27d26cd31205b1b5d8e2f0cb8380cd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* Layer Norm JitKernel */ -template -class LayerNormKernelImpl : public LayerNormKernel { - public: - explicit LayerNormKernelImpl(int right) : LayerNormKernel() { - this->num_ = right; - } - - void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias, - int height, const float epsilon) const override { - // get mean - for (int i = 0; i < height; i++) { - T sum = 0.0; - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - sum += x[offset + j]; - } - mean[i] = sum / this->num_; - } - - // get variance - for (int i = 0; i < height; i++) { - T sum = 0.0; - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]); - } - var[i] = sum / this->num_; - } - - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - T sqrt_var = sqrt(var[i] + (T)epsilon); - for (int j = 0; j < this->num_; j++) { - out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var; - } - } - if (scale) { - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - out[offset + j] *= scale[j]; - } - } - } - - if (bias) { - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - out[offset + j] += bias[j]; - } - } - } - } -}; - -#define INTRIAVX_FLOAT(isa, jit_block) 
\ - template <> \ - LayerNormKernelImpl::LayerNormKernelImpl(int right) \ - : LayerNormKernel() { \ - this->num_ = right; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - this->end_ = this->num_ - this->rest_; \ - } \ - template <> \ - void LayerNormKernelImpl::Compute( \ - float* x, float* out, float* mean, float* var, const float* scale, \ - const float* bias, int height, const float epsilon) const { \ - __m256 sum; \ - __m256 mean_vec, var_vec; \ - __m128 hi, lo; \ - __m256 tmp; \ - size_t offset; \ - size_t j; \ - size_t block = YMM_FLOAT_BLOCK; \ - __m256 reverse_num_vec = \ - _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ - __m256 epsilon_vec = _mm256_set1_ps(epsilon); \ - int rest_mask = \ - ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \ - 0x0ff; \ - __m256i mask_vec = _mm256_set_epi32( \ - rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \ - rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \ - rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \ - rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 
0xffffffff : 0); \ - \ - for (int i = 0; i < height; ++i) { \ - offset = i * this->num_; \ - \ - /* get mean */ \ - sum = _mm256_setzero_ps(); \ - for (j = offset; j < end_ + offset; j += block) { \ - sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)x + j); \ - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - hi = _mm256_extractf128_ps(sum, 1); \ - lo = _mm256_extractf128_ps(sum, 0); \ - sum = _mm256_add_ps( \ - sum, _mm256_insertf128_ps( \ - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \ - sum = _mm256_hadd_ps(sum, sum); \ - sum = _mm256_hadd_ps(sum, sum); \ - mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \ - mean[i] = *reinterpret_cast(&mean_vec); \ - \ - /* get variance */ \ - sum = _mm256_setzero_ps(); \ - for (j = offset; j < end_ + offset; j += block) { \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - hi = _mm256_extractf128_ps(sum, 1); \ - lo = _mm256_extractf128_ps(sum, 0); \ - sum = _mm256_add_ps( \ - sum, _mm256_insertf128_ps( \ - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \ - sum = _mm256_hadd_ps(sum, sum); \ - sum = _mm256_hadd_ps(sum, sum); \ - var_vec = _mm256_mul_ps(sum, reverse_num_vec); \ - var[i] = *reinterpret_cast(&var_vec); \ - \ - /* get x_norm and calculate output*/ \ - for (j = offset; j < end_ + offset; j += block) { \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_div_ps( \ - tmp, 
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \ - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); \ - } \ - if (rest_ != 0) { \ - j = offset + num_ - block; \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_div_ps( \ - tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \ - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); \ - } \ - \ - if (scale) { \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)out + j); \ - } \ - for (j = offset; j < end_ + offset; j += block) { \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_mul_ps( \ - _mm256_loadu_ps((const float*)out + j), \ - _mm256_loadu_ps((const float*)scale + j - offset))); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_mul_ps( \ - tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \ - } \ - } \ - \ - if (bias) { \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)out + j); \ - } \ - for (j = offset; j < end_ + offset; j += block) { \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_add_ps( \ - _mm256_loadu_ps((const float*)out + j), \ - _mm256_loadu_ps((const float*)bias + j - offset))); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_add_ps( \ - tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \ - } \ - } \ - } \ - } - -#ifdef __AVX__ -INTRIAVX_FLOAT(platform::avx, kEQ8); -INTRIAVX_FLOAT(platform::avx, kGT8LT16); -INTRIAVX_FLOAT(platform::avx, kEQ16); -INTRIAVX_FLOAT(platform::avx, kGT16); -INTRIAVX_FLOAT(platform::avx2, kEQ8); -INTRIAVX_FLOAT(platform::avx2, kGT8LT16); -INTRIAVX_FLOAT(platform::avx2, kEQ16); -INTRIAVX_FLOAT(platform::avx2, kGT16); -INTRIAVX_FLOAT(platform::avx512f, kEQ8); -INTRIAVX_FLOAT(platform::avx512f, kGT8LT16); 
-INTRIAVX_FLOAT(platform::avx512f, kEQ16); -INTRIAVX_FLOAT(platform::avx512f, kGT16); -#endif - -#undef INTRIAVX_FLOAT - -REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h deleted file mode 100644 index 4dba3b56810794cb4839d26386ae77a8f4507977..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#define JITKERNEL_DECLARE_STATIC_FUNC \ - static inline std::string name(int d) { \ - PADDLE_THROW("DType should be either float or double"); \ - } \ - static inline bool useJIT(int d) { return false; } \ - static inline bool useMKL(int d) { return false; } - -#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(int d) { \ - std::string key(#ker_key "f"); \ - if (useJIT(d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(d); \ - } else if (useMKL(d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(int d) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, int>(int d) - -#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(d) - -#define JITKERNEL_IMPL(ker_class, ker_dtype) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(d)) - -#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \ - macro_find_key, macro_impl) \ - marco_declare(ker_class, ker_dtype) { \ - macro_find_key(ker_class, ker_dtype); \ - if (kers_.find(key) == kers_.end()) { \ - std::shared_ptr> p; \ - macro_impl(ker_class, ker_dtype); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>( \ - kers_.at(key)); \ - } - -#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ - marco_declare, macro_find_key, macro_impl) \ - marco_define_name(ker_key, ker_class); \ - 
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \ - macro_find_key, macro_impl); \ - REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \ - macro_find_key, macro_impl) - -#define REGISTER_JITKERNEL(ker_key, ker_class) \ - REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ - JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \ - JITKERNEL_IMPL) - -// TODO(TJ): below defines are deprecated, would be remove recently -#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ - if (d < YMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kLT8); \ - } else if (d == YMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ8); \ - } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kGT8LT16); \ - } else if (d == ZMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ16); \ - } else { \ - macro_(ker, dtype, isa, kGT16); \ - } - -#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ - if (platform::MayIUse(platform::avx512f)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \ - } else if (platform::MayIUse(platform::avx2)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \ - } else if (platform::MayIUse(platform::avx)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \ - } else { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \ - } - -#define JITKERNEL_KEY(ker_key, dtype_key) \ - #ker_key #dtype_key + std::to_string(d) - -#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(d)) - -#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \ - dtype_key, marco_declare, macro_key, \ - macro_impl) \ - marco_declare(ker_class, ker_dtype) { \ - std::string key = macro_key(ker_key, dtype_key); \ - if (kers_.find(key) == kers_.end()) { \ - std::shared_ptr> p; \ - SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>( \ - 
kers_.at(key)); \ - } - -#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \ - JITKERNEL_DECLARE, JITKERNEL_KEY, \ - JITKERNEL_NEW_IMPL_DEPRECATED); \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \ - JITKERNEL_DECLARE, JITKERNEL_KEY, \ - JITKERNEL_NEW_IMPL_DEPRECATED) - -#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \ - macro_key, macro_impl) \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \ - macro_key, macro_impl); \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \ - marco_declare, macro_key, macro_impl) - -#define FOR_EACH_ISA(macro_, block) \ - macro_(platform::avx512f, block); \ - macro_(platform::avx2, block); \ - macro_(platform::avx, block); \ - macro_(platform::isa_any, block) - -#define FOR_EACH_BLOCK(macro_, isa) \ - macro_(isa, kLT8); \ - macro_(isa, kEQ8); \ - macro_(isa, kGT8LT16); \ - macro_(isa, kEQ16); \ - macro_(isa, kGT16) - -#define FOR_EACH_ISA_BLOCK(macro_) \ - FOR_EACH_BLOCK(macro_, platform::avx512f); \ - FOR_EACH_BLOCK(macro_, platform::avx2); \ - FOR_EACH_BLOCK(macro_, platform::avx); \ - FOR_EACH_BLOCK(macro_, platform::isa_any) - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc deleted file mode 100644 index 2db3274a45610aedea385baf650b8efb42ac39d0..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/macros.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* LSTM JitKernel */ -template -class LSTMKernelImpl : public LSTMKernel { - public: - static inline std::string name(const lstm_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; - jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); - this->ComputeCtHt = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? 
sz : 4096)); - this->ComputeC1H1 = - jitcode1_->getCode(); - return; - } -#endif - - this->ComputeCtHt = refer::LSTMCtHt; - this->ComputeC1H1 = refer::LSTMC1H1; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool LSTMKernelImpl::useJIT(int d) { - return gen::LSTMJitCode::init(d); -} -#endif - -/* Peephole JitKernel */ -template -class PeepholeKernelImpl : public LSTMKernel { - public: - static inline std::string name(const lstm_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8; - jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); - this->ComputeCtHt = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096)); - this->ComputeC1H1 = - jitcode1_->getCode(); - return; - } -#endif - - this->ComputeCtHt = refer::LSTMCtHt; - this->ComputeC1H1 = refer::LSTMC1H1; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool PeepholeKernelImpl::useJIT(int d) { - return gen::LSTMJitCode::init(d); -} -#endif - -#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ - std::string key(#ker_key "f"); \ - key += (attr.act_gate + attr.act_cand + attr.act_cell + \ - (attr.use_peephole ? 
"p" : "n")); \ - if (useJIT(attr.d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(attr.d); \ - } else if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, const lstm_attr_t&>( \ - const lstm_attr_t& attr) - -#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(attr) - -#define JITKERNEL_LSTM_IMPL(ker, dtype) \ - if (attr.use_peephole) { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); \ - } else { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); \ - } - -REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM, - JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM, - JITKERNEL_LSTM_IMPL); - -#undef JITKERNEL_LSTM_IMPL -#undef JITKERNEL_FIND_KEY_LSTM -#undef JITKERNEL_DECLARE_LSTM -#undef JITKERNEL_DEFINE_NAME_LSTM - -/* GRU JitKernel */ -template -class GRUKernelImpl : public GRUKernel { - public: - static inline std::string name(const gru_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; - jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096)); - this->ComputeH1 = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::GRUJitCode(1, attr, sz > 4096 ? 
sz : 4096)); - this->ComputeHtPart1 = - jitcode1_->getCode(); - - jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096)); - this->ComputeHtPart2 = - jitcode2_->getCode(); - return; - } -#endif - this->ComputeH1 = refer::GRUH1; - this->ComputeHtPart1 = refer::GRUHtPart1; - this->ComputeHtPart2 = refer::GRUHtPart2; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}, - jitcode2_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool GRUKernelImpl::useJIT(int d) { - return gen::GRUJitCode::init(d); -} -#endif - -#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(const gru_attr_t& attr) { \ - std::string key(#ker_key "f"); \ - key += (attr.act_gate + attr.act_cand); \ - if (useJIT(attr.d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(attr.d); \ - } else if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(const gru_attr_t& attr) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, const gru_attr_t&>( \ - const gru_attr_t& attr) - -#define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(attr) - -#define JITKERNEL_GRU_IMPL(ker, dtype) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); - -REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DEFINE_NAME_GRU, - JITKERNEL_DECLARE_GRU, JITKERNEL_FIND_KEY_GRU, - JITKERNEL_GRU_IMPL); - -#undef JITKERNEL_GRU_IMPL -#undef JITKERNEL_FIND_KEY_GRU -#undef JITKERNEL_DECLARE_GRU -#undef JITKERNEL_DEFINE_NAME_GRU -} // namespace jitkernel -} // namespace math -} // namespace operators -} // 
namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc deleted file mode 100644 index 19f7bd8909499c12fd5bee4db0d0a71a632e7f19..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ /dev/null @@ -1,742 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include // for exp -#include // for memcpy -#include -#include -#include -#include "gflags/gflags.h" -#include "glog/logging.h" -#include "gtest/gtest.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/port.h" - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -#ifdef __AVX__ -#include -#endif - -constexpr int repeat = 20000; - -// TODO(TJ): benchmark and test should be seperated, -// benchmark should verify more sizes - -inline double GetCurrentUS() { - struct timeval time; - gettimeofday(&time, NULL); - return 1e+6 * time.tv_sec + time.tv_usec; -} - -template -void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), - const T upper = static_cast(20.f)) { - static unsigned int seed = 100; - std::mt19937 rng(seed++); - std::uniform_real_distribution uniform_dist(0, 1); - for (int i = 0; i < n; ++i) { - a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); - } -} - -#if defined __AVX__ || defined 
__AVX2__ -void vrelu_intri8(const int n, const float* x, float* y) { - __m256 tmp = _mm256_loadu_ps(x); - tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); - _mm256_storeu_ps(y, tmp); -} -#endif - -TEST(JitKernel, vrelu) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -10.f, 1.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VRelu(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vrelu_intri8(d, x_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us"; - } -#endif - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -TEST(JitKernel, vaddbias) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float a = 2.f; - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - 
refer::VAddBias(&a, x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#ifdef PADDLE_WITH_MKLML -void vexp_mkl(const int n, const float* x, float* y) { - paddle::platform::dynload::vsExp(n, x, y); -} -#endif - -TEST(JitKernel, vexp) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VExp(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vexp_mkl(d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - // ker->Compute(x_data, ztgt_data); - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vsigmoid_better( - const std::shared_ptr< - const 
paddle::operators::math::jitkernel::VExpKernel>& vexp, - const int n, const float* x, float* y) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = 0.f - y[i]; - } - vexp->Compute(y, y, n); - for (int i = 0; i < n; ++i) { - y[i] = 1.f / (1.f + y[i]); - } -} - -TEST(JitKernel, vsigmoid) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vexp = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vsigmoid_better(vexp, d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VSigmoid(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vtanh_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VScalKernel>& vscal, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VSigmoidKernel>& - vsigmoid, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VAddBiasKernel>& - vaddbias, - const int n, 
const float* x, float* y) { - const float a = 2.f, b = -1.f; - vscal->Compute(&a, x, y, n); - vsigmoid->Compute(y, y, n); - vscal->Compute(&a, y, y, n); - vaddbias->Compute(&b, y, y, n); -} - -TEST(JitKernel, vtanh) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vscal = - jit::KernelPool::Instance().template Get>(d); - const auto& vsigmoid = - jit::KernelPool::Instance().template Get>(d); - const auto& vaddbias = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VTanh(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void lstm_ctht_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VSigmoidKernel>& - vsigmoid_3d, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VMulKernel>& vmul_d, - const std::shared_ptr< - const 
paddle::operators::math::jitkernel::VAddKernel>& vadd_d, - const int d, float* gates, const float* ct_1, float* ct, float* ht) { - int d2 = d * 2; - vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); - vtanh_d->Compute(gates, gates, d); - vmul_d->Compute(gates, gates + d, gates + d, d); - vmul_d->Compute(ct_1, gates + d2, gates + d2, d); - vadd_d->Compute(gates + d, gates + d2, ct, d); - /* H_t = act_cell(C_t) * ogated */ - vtanh_d->Compute(ct, gates + d2, d); - vmul_d->Compute(gates + d2, gates + d * 3, ht, d); -} - -TEST(JitKernel, lstm) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) { - int d4 = d * 4; - int d3 = d * 3; - std::vector x(d4), xref(d4); - std::vector ct_1(d), ct_tgt(d), ht_tgt(d); - std::vector ct_ref(d), ht_ref(d); - RandomVec(d4, x.data(), -2.f, 2.f); - RandomVec(d, ct_1.data(), -2.f, 2.f); - memcpy(xref.data(), x.data(), sizeof(float) * d4); - std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; - const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false); - const auto& ker = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>( - attr); - // below kernels are used to compute refer - const auto& vsigmoid_3d = - jit::KernelPool::Instance().template Get>( - d3); - const auto& vtanh_d = - jit::KernelPool::Instance().template Get>(d); - const auto& vmul_d = - jit::KernelPool::Instance().template Get>(d); - const auto& vadd_d = - jit::KernelPool::Instance().template Get>(d); - - float* x_data = x.data(); - float* xref_data = xref.data(); - const float* ct_1_data = ct_1.data(); - float* ct_tgt_data = ct_tgt.data(); - float* ht_tgt_data = ht_tgt.data(); - float* ct_ref_data = ct_ref.data(); - float* ht_ref_data = ht_ref.data(); - // compute once to check correctness - jit::lstm_t step; - step.gates = xref_data; - step.ct_1 = ct_1_data; - step.ct = ct_ref_data; - step.ht 
= ht_ref_data; - refer::LSTMCtHt(&step, &attr); - - step.gates = x_data; - step.ct = ct_tgt_data; - step.ht = ht_tgt_data; - ker->ComputeCtHt(&step, &attr); - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); - EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); - } - - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data, - ct_1_data, ct_ref_data, ht_ref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::LSTMCtHt(&step, &attr); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->ComputeCtHt(&step, &attr); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - } -} - -#if defined __AVX__ || defined __AVX2__ -void vscal_intri8(const int n, const float a, const float* x, float* y) { - __m256 tmp; - __m256 scalar = _mm256_set1_ps(a); - tmp = _mm256_loadu_ps(x); - tmp = _mm256_mul_ps(tmp, scalar); - _mm256_storeu_ps(y, tmp); -} -void vscal_inp_intri8(const int n, const float a, float* x) { - __m256 tmp; - __m256 scalar = _mm256_set1_ps(a); - tmp = _mm256_loadu_ps(x); - tmp = _mm256_mul_ps(tmp, scalar); - _mm256_storeu_ps(x, tmp); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vscal_inp_mkl(const int n, const float a, float* x) { - paddle::platform::dynload::cblas_sscal(n, a, x, 1); -} -#endif - -TEST(JitKernel, vscal) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - std::memcpy(y.data(), x.data(), sizeof(float) * d); - float a = 2.f; - const auto& ker = - 
jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VScal(&a, x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto trefs1 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VScal(&a, y_data, y_data, d); - } - auto trefe1 = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_inp_mkl(d, a, y_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_intri8(d, a, x_data, zref_data); - } - auto si1 = GetCurrentUS(); - auto si2 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_inp_intri8(d, a, y_data); - } - auto si3 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat - << " us, inplace: " << (si3 - si2) / repeat << " us"; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - auto ttgts1 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, y_data, y_data, d); - } - auto ttgte1 = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, inplace takes: " << (trefe1 - trefs1) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat - << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#if defined __AVX__ || defined __AVX2__ -void vmul_intri8(const int n, const float* x, const float* y, float* z) { - __m256 tmpx, tmpy; 
- tmpx = _mm256_loadu_ps(x); - tmpy = _mm256_loadu_ps(y); - tmpx = _mm256_mul_ps(tmpx, tmpy); - _mm256_storeu_ps(z, tmpx); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vmul_mkl(const int n, const float* x, const float* y, float* z) { - paddle::platform::dynload::vsMul(n, x, y, z); -} -#endif - -TEST(JitKernel, vmul) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VMul(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vmul_mkl(d, x_data, y_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vmul_intri8(d, x_data, y_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#if defined __AVX__ || defined __AVX2__ -void vadd_intri8(const int n, const 
float* x, const float* y, float* z) { - __m256 tmpx, tmpy; - tmpx = _mm256_loadu_ps(x); - tmpy = _mm256_loadu_ps(y); - tmpx = _mm256_add_ps(tmpx, tmpy); - _mm256_storeu_ps(z, tmpx); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vadd_mkl(const int n, const float* x, const float* y, float* z) { - paddle::platform::dynload::vsAdd(n, x, y, z); -} -#endif - -TEST(JitKernel, vadd) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VAdd(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vadd_mkl(d, x_data, y_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vadd_intri8(d, x_data, y_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vaddrelu_better( - const 
std::shared_ptr< - const paddle::operators::math::jitkernel::VAddKernel>& vadd, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VReluKernel>& vrelu, - const float* x, const float* y, float* z, int d) { - vadd->Compute(x, y, z, d); - vrelu->Compute(z, z, d); -} - -TEST(JitKernel, vaddrelu) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vadd = - jit::KernelPool::Instance().template Get>(d); - const auto& vrelu = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VAddRelu(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d); - } - auto tmkle = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better takes: " << (tmkle - tmkls) / repeat << " us, " - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -TEST(JitKernel, pool) { - namespace jit = paddle::operators::math::jitkernel; - const int frame_size = 4; - std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; - jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); - - // empty call it to avoid unknown flag 
'use_pinned_memory' on Mac - paddle::platform::MayIUse(paddle::platform::avx); - const auto& plstm1 = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>(attr); - - const auto& plstm2 = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>(attr); - EXPECT_EQ(plstm1, plstm2); - - const auto& peephole = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>( - jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true)); - EXPECT_TRUE(plstm1 != peephole); - - const auto& pvmul_f = - jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(plstm2) != - std::dynamic_pointer_cast(pvmul_f)); - - const auto& pvmul_d = - jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != - std::dynamic_pointer_cast(pvmul_d)); - - const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4"); -#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32) - EXPECT_EQ(pvmul_from_key, nullptr); -#else - EXPECT_EQ(pvmul_from_key, pvmul_f); -#endif - const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit"); - EXPECT_TRUE(pvmul_from_key2 == nullptr); -}