diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 7f79974248260c390c452659a349d90e6d2a332e..c1d4cc1b889700079091f859dbbdb46f626dbb0f 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -76,6 +76,6 @@ endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
-    SRCS jit_kernel.cc jit_gen.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
+    SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
     DEPS cpu_info cblas gflags enforce)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
new file mode 100644
index 0000000000000000000000000000000000000000..29a89bca9827bbc6e0fa2e76894cb2eed523abb2
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/jit_code.h"
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+namespace gen {
+
+using namespace platform::jit;  // NOLINT
+
+bool VMulJitCode::init(int d) {
+  // TODO(TJ): maybe AVX alone is enough; ISAs above AVX (avx2/avx512)
+  // can lower the CPU frequency, so benchmark them before enabling.
+  if (MayIUse(avx) || MayIUse(avx2)) {
+    return d % AVX_FLOAT_BLOCK == 0;
+  } else {
+    return false;
+  }
+}
+
+void VMulJitCode::generate() {
+  preCode();
+  int stride = sizeof(float) * AVX_FLOAT_BLOCK;
+  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+    vmovups(ymm_src1, ptr[param1 + i * stride]);
+    vmovups(ymm_src2, ptr[param2 + i * stride]);
+    vmulps(ymm_dst, ymm_src1, ymm_src2);
+    vmovups(ptr[param3 + i * stride], ymm_dst);
+  }
+  postCode();
+}
+
+}  // namespace gen
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
new file mode 100644
index 0000000000000000000000000000000000000000..db1a0cd0958d7047b83d1dd234eaf9f0fb748fa6
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -0,0 +1,63 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/math/jit_gen.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+namespace gen {
+
+using reg64_t = const Xbyak::Reg64;
+using reg32_t = const Xbyak::Reg32;
+using xmm_t = const Xbyak::Xmm;
+using ymm_t = const Xbyak::Ymm;
+using zmm_t = const Xbyak::Zmm;
+using Label = Xbyak::Label;
+
+class VMulJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(VMulJitCode);
+  explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+  reg64_t param3{abi_param3};
+
+  xmm_t xmm_src1 = xmm_t(0);
+  ymm_t ymm_src1 = ymm_t(0);
+  zmm_t zmm_src1 = zmm_t(0);
+  xmm_t xmm_src2 = xmm_t(1);
+  ymm_t ymm_src2 = ymm_t(1);
+  zmm_t zmm_src2 = zmm_t(1);
+
+  xmm_t xmm_dst = xmm_t(2);
+  ymm_t ymm_dst = ymm_t(2);
+  zmm_t zmm_dst = zmm_t(2);
+};
+
+}  // namespace gen
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index 7f92043b6f4d4d1046252965ba8c1082033fb6bf..cef21348e43c4342c92bd17f4c70ff106e35d99c 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
 #include <string>
-#include "paddle/fluid/operators/math/jit_gen.h"
+#include "paddle/fluid/operators/math/jit_code.h"
 #include "paddle/fluid/operators/math/jit_kernel_macro.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -30,30 +30,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-
-namespace jit = platform::jit;  // remove me
-
-using namespace platform::jit;  // NOLINT
-
-/* VMUL JitKernel */
-struct VMulJitCode : public gen::JitCode {
-  DECLARE_JIT_CODE(VMulJitCode);
-  explicit VMulJitCode(size_t code_size = 256 * 1024, void* code_ptr = nullptr)
-      : gen::JitCode(code_size, code_ptr) {}
-  static bool init(int d) {
-    if (MayIUse(avx) || MayIUse(avx2)) {
-      return d % AVX_FLOAT_BLOCK == 0;
-    } else if (MayIUse(avx512f)) {
-      return d % AVX512_FLOAT_BLOCK == 0;
-    } else {
-      return false;
-    }
-  }
-  void generate() override {
-    preCode();
-    postCode();
-  }
-};
+namespace jit = platform::jit;
 
 template <typename T>
 void VMulRefer(const T* x, const T* y, T* z, int n) {
@@ -76,6 +53,7 @@ void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
 }
 #endif
 
+/* VMUL JitKernel */
 template <typename T>
 class VMulKernelImpl : public VMulKernel<T> {
  public:
@@ -88,7 +66,7 @@ class VMulKernelImpl : public VMulKernel<T> {
   explicit VMulKernelImpl(int d) : VMulKernel<T>() {
     if (useJIT(d)) {
      constexpr size_t sz = 256 * 1024;  // TODO(TJ): should be related to d
-      jitcode_.reset(new VMulJitCode(sz));
+      jitcode_.reset(new gen::VMulJitCode(d, sz));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
       return;
@@ -103,12 +81,12 @@
   }
 
  private:
-  std::unique_ptr<VMulJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
 };
 
 template <>
 bool VMulKernelImpl<float>::useJIT(int d) {
-  return VMulJitCode::init(d);
+  return gen::VMulJitCode::init(d);
 }
 
 template <>
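
Usage sketch (illustration only, not part of the patch): the snippet below
assumes the KernelPool::Instance().Get<>() accessor and the Compute function
pointer declared in jit_kernel.h, following the pattern exercised by
jit_kernel_test. On an AVX-capable CPU with d % AVX_FLOAT_BLOCK == 0,
VMulKernelImpl JITs the vmovups/vmulps loop above; otherwise it falls back to
MKL or VMulRefer.

    #include <vector>
    #include "paddle/fluid/operators/math/jit_kernel.h"

    namespace jitkernel = paddle::operators::math::jitkernel;

    int main() {
      constexpr int d = 32;  // multiple of AVX_FLOAT_BLOCK (8), so init(d) holds
      std::vector<float> x(d, 2.0f), y(d, 3.0f), z(d, 0.0f);
      // The pool caches generated kernels, so repeated Get() calls for the
      // same type and size reuse the already-emitted code.
      const auto& ker =
          jitkernel::KernelPool::Instance().Get<jitkernel::VMulKernel<float>>(d);
      ker->Compute(x.data(), y.data(), z.data(), d);  // z[i] = x[i] * y[i]
      return 0;
    }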