From 6aa3670f67c3b37749578d608f9c3e7224fee94d Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Fri, 17 Mar 2023 11:31:34 +0800 Subject: [PATCH] [phi][jit] clean paddle/phi/kernels/jit Unused methods (#51446) * [phi][jit] rm Softmax StrideScal * [phi][jit] rm kStrideScal * [phi][jit] fix Softmax clean omission * [phi][jit] fix Softmax clean omission * [phi][jit] fix StrideScal clean omission * [phi][jit] fix mkl SoftmaxKernel clean omission * [phi][jit] fix test error * [phi][jit] fix test error * [phi][jit] rm NCHW16CMulNC * [phi][jit] fix test error * [phi][jit] rm HSum HMax * [phi][jit] fix test error * [phi][jit] rm StrideASum * add AUTHORS.md * [phi][jit] fix test error --- AUTHORS.md | 1 + .../phi/kernels/cpu/layer_norm_grad_kernel.cc | 8 +- paddle/phi/kernels/funcs/jit/benchmark.cc | 35 --- .../phi/kernels/funcs/jit/gen/CMakeLists.txt | 3 - paddle/phi/kernels/funcs/jit/gen/blas.cc | 43 ---- paddle/phi/kernels/funcs/jit/gen/blas.h | 13 - paddle/phi/kernels/funcs/jit/gen/hopv.cc | 102 -------- paddle/phi/kernels/funcs/jit/gen/hopv.h | 93 ------- paddle/phi/kernels/funcs/jit/helper.cc | 6 - paddle/phi/kernels/funcs/jit/kernel_base.h | 34 --- .../kernels/funcs/jit/more/mix/CMakeLists.txt | 1 - paddle/phi/kernels/funcs/jit/more/mix/mix.cc | 38 --- paddle/phi/kernels/funcs/jit/more/mix/mix.h | 3 - .../kernels/funcs/jit/more/mkl/CMakeLists.txt | 2 - paddle/phi/kernels/funcs/jit/more/mkl/mkl.cc | 45 ---- paddle/phi/kernels/funcs/jit/more/mkl/mkl.h | 38 --- .../kernels/funcs/jit/refer/CMakeLists.txt | 6 - paddle/phi/kernels/funcs/jit/refer/refer.cc | 6 - paddle/phi/kernels/funcs/jit/refer/refer.h | 85 ------- paddle/phi/kernels/funcs/jit/test.cc | 240 +----------------- 20 files changed, 9 insertions(+), 793 deletions(-) delete mode 100644 paddle/phi/kernels/funcs/jit/gen/hopv.cc delete mode 100644 paddle/phi/kernels/funcs/jit/gen/hopv.h diff --git a/AUTHORS.md b/AUTHORS.md index b199f5bae37..fa6cb9dc99d 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -122,3 +122,4 @@ This is an incomplete list of authors of [Paddle](https://github.com/PaddlePaddl | yiakwy, yiakwy-xpu-ml-framework-team | Yi Wang (Graphcore) | | [Yulv-git](https://github.com/Yulv-git) | Shuangchi He | | [zrr1999](https://github.com/zrr1999) | Rongrui Zhan | +| [gouzil](https://github.com/gouzil) | Chuan Tian | diff --git a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc index 020dd15cc57..c42e423ba2d 100644 --- a/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc @@ -14,18 +14,14 @@ #include "paddle/phi/kernels/layer_norm_grad_kernel.h" -#include "paddle/phi/kernels/cpu/elementwise.h" -#include "paddle/phi/kernels/funcs/layer_norm_util.h" -#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) -#include "paddle/phi/kernels/funcs/jit/kernels.h" -#endif #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/cpu/elementwise.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/layer_norm_util.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { diff --git a/paddle/phi/kernels/funcs/jit/benchmark.cc b/paddle/phi/kernels/funcs/jit/benchmark.cc index c80864973b1..26552308cd8 100644 --- a/paddle/phi/kernels/funcs/jit/benchmark.cc +++ b/paddle/phi/kernels/funcs/jit/benchmark.cc @@ -171,17 +171,6 @@ void BenchKernelAXYN() { } } -template -void BenchKernelXRN() { - using T = typename KernelTuple::data_type; - for (int d : TestSizes()) { - phi::DenseTensor x; - RandomVec(d, x.mutable_data({d}, PlaceType())); - T res; - BenchAllImpls(d, x.data(), &res, d); - } -} - template void BenchKernelXYN() { using T = typename KernelTuple::data_type; @@ -390,22 +379,6 @@ void BenchKernelMatMul() { } } -template -void BenchKernelSoftmax() { - using T = typename KernelTuple::data_type; - for (int bs : {1, 2, 10}) { - for (int n : TestSizes()) { - phi::DenseTensor x, y; - x.Resize({bs, n}); - y.Resize({bs, n}); - RandomVec(bs * n, x.mutable_data(PlaceType()), -2.f, 2.f); - const T* x_data = x.data(); - T* y_data = y.mutable_data(PlaceType()); - BenchAllImpls(n, x_data, y_data, n, bs, 1); - } - } -} - template void BenchKernelLayerNorm() { using T = typename KernelTuple::data_type; @@ -514,9 +487,6 @@ void BenchKernelVBroadcast() { #define BenchKernelVTanh BenchKernelXYN #define BenchKernelVCopy BenchKernelXYN -#define BenchKernelHMax BenchKernelXRN -#define BenchKernelHSum BenchKernelXRN - #define BenchKernelLSTMCtHt BenchKernelLSTM #define BenchKernelLSTMC1H1 BenchKernelLSTM @@ -550,10 +520,6 @@ BENCH_FP32_CPU(VSigmoid); BENCH_FP32_CPU(VTanh); BENCH_FP32_CPU(VCopy); -// xrn -BENCH_FP32_CPU(HMax); -BENCH_FP32_CPU(HSum); - // LSTM BENCH_FP32_CPU(LSTMCtHt); BENCH_FP32_CPU(LSTMC1H1); @@ -569,7 +535,6 @@ BENCH_FP32_CPU(CRFDecoding); BENCH_FP32_CPU(SeqPool); BENCH_FP32_CPU(EmbSeqPool); BENCH_FP32_CPU(MatMul); -BENCH_FP32_CPU(Softmax); BENCH_FP32_CPU(Sgd); BENCH_FP32_CPU(VBroadcast); diff --git a/paddle/phi/kernels/funcs/jit/gen/CMakeLists.txt b/paddle/phi/kernels/funcs/jit/gen/CMakeLists.txt index 60e29ea81d5..e2b9b51590f 100644 --- a/paddle/phi/kernels/funcs/jit/gen/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/jit/gen/CMakeLists.txt @@ -34,10 +34,7 @@ use_jitkernel_gen(kLSTMC1H1) use_jitkernel_gen(kGRUH1) use_jitkernel_gen(kGRUHtPart1) use_jitkernel_gen(kGRUHtPart2) -use_jitkernel_gen(kNCHW16CMulNC) use_jitkernel_gen(kSeqPool) -use_jitkernel_gen(kHMax) -use_jitkernel_gen(kHSum) use_jitkernel_gen(kEmbSeqPool) use_jitkernel_gen(kAdam) use_jitkernel_gen(kAdamW) diff --git a/paddle/phi/kernels/funcs/jit/gen/blas.cc b/paddle/phi/kernels/funcs/jit/gen/blas.cc index 32b2b43af54..8c287efcf5d 100644 --- a/paddle/phi/kernels/funcs/jit/gen/blas.cc +++ b/paddle/phi/kernels/funcs/jit/gen/blas.cc @@ -110,48 +110,6 @@ void VXXJitCode::genCode() { ret(); } -void NCHW16CMulNCJitCode::genCode() { - // RDI is ptr x_input - // RSI is ptr y_input - // RDX is ptr output - // RCX is height - // r8 is width - - push(rbx); - - xor_(rax, rax); - xor_(r10, r10); - vmovups(zmm3, ptr[rsi]); - - L("h_loop"); - xor_(rbx, rbx); - L("w_loop"); - vmovups(zmm2, ptr[rdi + rax]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx + rax], zmm1); - add(rax, 64); - inc(rbx); - cmp(r8, rbx); - jnz("w_loop"); - inc(r10); - cmp(r10, rcx); - jnz("h_loop"); - - pop(rbx); - ret(); -} - -class NCHW16CMulNCCreator : public JitCodeCreator { - public: - bool CanBeUsed(const int& attr) const override { - return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f); - } - size_t CodeSize(const int& d) const override { return 256 * 1024; } - std::unique_ptr CreateJitCode(const int& attr) const override { - return make_unique(attr, CodeSize(attr)); - } -}; - #define DECLARE_BLAS_CREATOR(name) \ class name##Creator : public JitCodeCreator { \ public: \ @@ -188,4 +146,3 @@ REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); -REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator); diff --git a/paddle/phi/kernels/funcs/jit/gen/blas.h b/paddle/phi/kernels/funcs/jit/gen/blas.h index a1577b86e65..a046634440e 100644 --- a/paddle/phi/kernels/funcs/jit/gen/blas.h +++ b/paddle/phi/kernels/funcs/jit/gen/blas.h @@ -108,19 +108,6 @@ DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false); #undef DECLARE_BLAS_JITCODE -// nChw16c = nChw16c .* NC -class NCHW16CMulNCJitCode : public JitCode { - public: - DECLARE_JIT_CODE(NCHW16CMulNCJitCode); - explicit NCHW16CMulNCJitCode(int d /*unused*/, - size_t code_size, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr) { - this->genCode(); - } - void genCode() override; -}; - } // namespace gen } // namespace jit } // namespace phi diff --git a/paddle/phi/kernels/funcs/jit/gen/hopv.cc b/paddle/phi/kernels/funcs/jit/gen/hopv.cc deleted file mode 100644 index f6eb4f37986..00000000000 --- a/paddle/phi/kernels/funcs/jit/gen/hopv.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ - -#include "paddle/phi/kernels/funcs/jit/gen/hopv.h" - -#include "paddle/phi/backends/cpu/cpu_info.h" -#include "paddle/phi/kernels/funcs/jit/registry.h" - -namespace phi { -namespace jit { -namespace gen { - -void HOPVJitCode::genCode() { - const int num_blocks = num_ / YMM_FLOAT_BLOCK; - int offset = 0; - - if (num_blocks > 0) { - // load one firstly - vmovups(ymm_tmp, ptr[param_src]); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - for (int i = 1; i < num_blocks; ++i) { - vmovups(ymm_src, ptr[param_src + offset]); - process(ymm_tmp, ymm_src, ymm_tmp); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - vextractf128(xmm_dst, ymm_tmp, 1); - process(xmm_dst, xmm_dst, xmm_tmp); - } else { - if (type_ == operand_type::MAX) { - vbroadcastss(ymm_dst, ptr[param_src]); - } else if (type_ == operand_type::ADD) { - vxorps(ymm_dst, ymm_dst, ymm_dst); - } - } - - int rest = num_ % YMM_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src, ptr[param_src + offset]); - offset += sizeof(float) * 4; - rest -= 4; - process(xmm_dst, xmm_dst, xmm_src); - } - - vpermilps(xmm_tmp, xmm_dst, 16 + 8 + 3); - process(xmm_dst, xmm_dst, xmm_tmp); - - if (rest >= 2) { - vmovq(xmm_src, ptr[param_src + offset]); - offset += sizeof(float) * 2; - rest -= 2; - process(xmm_dst, xmm_dst, xmm_src); - } - - vpermilps(xmm_tmp, xmm_dst, 1); - process(xmm_dst, xmm_dst, xmm_tmp); - - if (rest >= 1) { - vmovss(xmm_src, ptr[param_src + offset]); - process(xmm_dst, xmm_dst, xmm_src); - } - vmovss(ptr[param_dst], xmm_dst); - ret(); -} - -#define DECLARE_HOP_CREATOR(name) \ - class name##Creator : public JitCodeCreator { \ - public: \ - bool CanBeUsed(const int& attr) const override { \ - return phi::backends::cpu::MayIUse(phi::backends::cpu::avx); \ - } \ - size_t CodeSize(const int& d) const override { \ - return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ - } \ - std::unique_ptr CreateJitCode(const int& attr) const override { \ - return make_unique(attr, CodeSize(attr)); \ - } \ - } - -DECLARE_HOP_CREATOR(HMax); -DECLARE_HOP_CREATOR(HSum); - -#undef DECLARE_HOP_CREATOR - -} // namespace gen -} // namespace jit -} // namespace phi - -namespace gen = phi::jit::gen; - -REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator); -REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator); diff --git a/paddle/phi/kernels/funcs/jit/gen/hopv.h b/paddle/phi/kernels/funcs/jit/gen/hopv.h deleted file mode 100644 index 71c7212578d..00000000000 --- a/paddle/phi/kernels/funcs/jit/gen/hopv.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ - -#pragma once - -#include - -#include "glog/logging.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h" - -namespace phi { -namespace jit { -namespace gen { - -// horizontal operand vector -class HOPVJitCode : public JitCode { - public: - explicit HOPVJitCode(int d, - operand_type type, - size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d), type_(type) { - if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) { - PADDLE_THROW(phi::errors::Unimplemented( - "Do not support operand type code: %d.", type)); - } - this->genCode(); - } - - std::string name() const override { - std::string base = "VXXJitCode"; - if (type_ == operand_type::MAX) { - base += "_MAX"; - } else { - base += "_SUM"; - } - return base; - } - void genCode() override; - - protected: - template - void process(JMM& dst, JMM& src1, JMM& src2) { // NOLINT - if (type_ == operand_type::MAX) { - vmaxps(dst, src1, src2); - } else if (type_ == operand_type::ADD) { - vaddps(dst, src1, src2); - } - } - - private: - int num_; - operand_type type_; - reg64_t param_src{abi_param1}; - reg64_t param_dst{abi_param2}; - reg64_t param_attr{abi_param3}; - - ymm_t ymm_tmp = ymm_t(0); - ymm_t ymm_src = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); - - xmm_t xmm_tmp = xmm_t(0); - xmm_t xmm_src = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); -}; - -#define DECLARE_HOP_JITCODE(name, op_type) \ - class name##JitCode : public HOPVJitCode { \ - public: \ - explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \ - : HOPVJitCode(d, op_type, code_size, code_ptr) {} \ - }; - -DECLARE_HOP_JITCODE(HMax, operand_type::MAX); -DECLARE_HOP_JITCODE(HSum, operand_type::ADD); - -#undef DECLARE_HOP_JITCODE - -} // namespace gen -} // namespace jit -} // namespace phi diff --git a/paddle/phi/kernels/funcs/jit/helper.cc b/paddle/phi/kernels/funcs/jit/helper.cc index fb34f47ed57..eb171a6aaeb 100644 --- a/paddle/phi/kernels/funcs/jit/helper.cc +++ b/paddle/phi/kernels/funcs/jit/helper.cc @@ -38,7 +38,6 @@ const char* to_string(KernelType kt) { ONE_CASE(kVAddRelu); ONE_CASE(kVSub); ONE_CASE(kVScal); - ONE_CASE(kStrideScal); ONE_CASE(kVAddBias); ONE_CASE(kVRelu); ONE_CASE(kVBroadcast); @@ -55,15 +54,10 @@ const char* to_string(KernelType kt) { ONE_CASE(kGRUHtPart2); ONE_CASE(kCRFDecoding); ONE_CASE(kLayerNorm); - ONE_CASE(kNCHW16CMulNC); ONE_CASE(kSeqPool); ONE_CASE(kMatMul); - ONE_CASE(kHMax); ONE_CASE(kAdam); ONE_CASE(kAdamW); - ONE_CASE(kHSum); - ONE_CASE(kStrideASum); - ONE_CASE(kSoftmax); ONE_CASE(kEmbSeqPool); ONE_CASE(kSgd); default: diff --git a/paddle/phi/kernels/funcs/jit/kernel_base.h b/paddle/phi/kernels/funcs/jit/kernel_base.h index bdb8132eec6..5f2e48076d9 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_base.h +++ b/paddle/phi/kernels/funcs/jit/kernel_base.h @@ -31,17 +31,11 @@ typedef enum { kGRUH1, kGRUHtPart1, kGRUHtPart2, - kHSum, // horizontal max - kHMax, // horizontal sum kLSTMCtHt, kLSTMC1H1, kLayerNorm, kMatMul, - kNCHW16CMulNC, kSeqPool, - kSoftmax, - kStrideASum, - kStrideScal, kVAdd, kVAddBias, kVAddRelu, @@ -94,10 +88,6 @@ struct XYNTuple { typedef void (*func_type)(const T*, T*, int); }; -// x, returned value, n -template -struct XRNTuple : public XYNTuple {}; - // x, returned value, n, stride template struct XRNSTuple { @@ -121,8 +111,6 @@ DECLARE_KERNELTUPLE(XYZNTuple, VSub); DECLARE_KERNELTUPLE(AXYNTuple, VScal); DECLARE_KERNELTUPLE(AXYNTuple, VAddBias); -DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal); - DECLARE_KERNELTUPLE(XYNTuple, VRelu); DECLARE_KERNELTUPLE(XYNTuple, VIdentity); DECLARE_KERNELTUPLE(XYNTuple, VSquare); @@ -131,11 +119,6 @@ DECLARE_KERNELTUPLE(XYNTuple, VSigmoid); DECLARE_KERNELTUPLE(XYNTuple, VTanh); DECLARE_KERNELTUPLE(XYNTuple, VCopy); -DECLARE_KERNELTUPLE(XRNTuple, HMax); -DECLARE_KERNELTUPLE(XRNTuple, HSum); - -DECLARE_KERNELTUPLE(XRNSTuple, StrideASum); - typedef struct { void* gates; // gates: x_ch, x_ih, x_fh, x_oh const void* ct_1; @@ -351,23 +334,6 @@ struct LayerNormTuple { T*, T*, T*, T*, const T*, const T*, int, const float, int); }; -template -struct SoftmaxTuple { - static constexpr KernelType kernel_type = kSoftmax; - typedef T data_type; - typedef int attr_type; - typedef void (*func_type)(const T*, T*, int, int, int); -}; - -// nChw16c = nChw16c .* NC -template -struct NCHW16CMulNCTuple { - static constexpr KernelType kernel_type = kNCHW16CMulNC; - typedef T data_type; - typedef int attr_type; - typedef void (*func_type)(const T*, const T*, T*, int, int); -}; - // Just for adding to kernel pool without template class Kernel { public: diff --git a/paddle/phi/kernels/funcs/jit/more/mix/CMakeLists.txt b/paddle/phi/kernels/funcs/jit/more/mix/CMakeLists.txt index b5bc6c84575..2fa8557c1d8 100644 --- a/paddle/phi/kernels/funcs/jit/more/mix/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/jit/more/mix/CMakeLists.txt @@ -18,4 +18,3 @@ use_jitkernel_more(kLSTMC1H1, mix) use_jitkernel_more(kGRUH1, mix) use_jitkernel_more(kGRUHtPart1, mix) use_jitkernel_more(kGRUHtPart2, mix) -use_jitkernel_more(kSoftmax, mix) diff --git a/paddle/phi/kernels/funcs/jit/more/mix/mix.cc b/paddle/phi/kernels/funcs/jit/more/mix/mix.cc index 78a042cf469..7bb58a8b246 100644 --- a/paddle/phi/kernels/funcs/jit/more/mix/mix.cc +++ b/paddle/phi/kernels/funcs/jit/more/mix/mix.cc @@ -49,41 +49,6 @@ void VTanh(const T* x, T* y, int n) { compute_addbias(&b, y, y, n); } -// remain is the product of dimension shapes after the axis dimension -void Softmax(const T* x, T* y, int n, int bs, int remain) { - auto compute_hmax = KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_hsum = KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_vscal = KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_strideasum = - KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_stridescal = - KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_vaddbias = - KernelFuncs, CPUPlace>::Cache().At(n); - auto compute_vexp = KernelFuncs, CPUPlace>::Cache().At(n); - - for (int i = 0; i < bs; ++i) { - T scalar; - compute_hmax(x, &scalar, n); - scalar = static_cast(0) - scalar; - compute_vaddbias(&scalar, x, y, n); // x - max - compute_vexp(y, y, n); - if (remain == 1) { - compute_hsum(y, &scalar, n); - scalar = static_cast(1) / scalar; - compute_vscal(&scalar, y, y, n); - } else { - for (int j = 0; j < remain; ++j) { - compute_strideasum(&y[j], &scalar, n, remain); - scalar = static_cast(1) / scalar; - compute_stridescal(&scalar, &y[j], &y[j], n, remain); - } - } - x += n; - y += n; - } -} - void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT if (type == kVSigmoid) { return KernelFuncs, CPUPlace>::Cache().At(d); @@ -221,8 +186,6 @@ bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; } bool VTanhKernel::CanBeUsed(const int& d) const { return true; } -bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; } - bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; } bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; } @@ -245,7 +208,6 @@ namespace mix = phi::jit::more::mix; REGISTER_MORE_KERNEL(VSigmoid); REGISTER_MORE_KERNEL(VTanh); -REGISTER_MORE_KERNEL(Softmax); REGISTER_MORE_KERNEL(LSTMCtHt); REGISTER_MORE_KERNEL(LSTMC1H1); REGISTER_MORE_KERNEL(GRUH1); diff --git a/paddle/phi/kernels/funcs/jit/more/mix/mix.h b/paddle/phi/kernels/funcs/jit/more/mix/mix.h index f1024969743..c932cb065e5 100644 --- a/paddle/phi/kernels/funcs/jit/more/mix/mix.h +++ b/paddle/phi/kernels/funcs/jit/more/mix/mix.h @@ -26,7 +26,6 @@ using T = float; void VSigmoid(const T* x, T* y, int n); void VTanh(const T* x, T* y, int n); -void Softmax(const T* x, T* y, int n, int bs, int remain); void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr); void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr); @@ -47,8 +46,6 @@ DECLARE_MORE_KERNEL(VSigmoid); DECLARE_MORE_KERNEL(VTanh); // XRN -DECLARE_MORE_KERNEL(Softmax); - DECLARE_MORE_KERNEL(LSTMCtHt); DECLARE_MORE_KERNEL(LSTMC1H1); diff --git a/paddle/phi/kernels/funcs/jit/more/mkl/CMakeLists.txt b/paddle/phi/kernels/funcs/jit/more/mkl/CMakeLists.txt index 609ddd3c284..7f6df06f87a 100644 --- a/paddle/phi/kernels/funcs/jit/more/mkl/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/jit/more/mkl/CMakeLists.txt @@ -11,14 +11,12 @@ use_jitkernel_more(kMatMul, mkl) use_jitkernel_more(kVMul, mkl) use_jitkernel_more(kVAdd, mkl) use_jitkernel_more(kVScal, mkl) -use_jitkernel_more(kStrideScal, mkl) use_jitkernel_more(kVExp, mkl) use_jitkernel_more(kVSquare, mkl) use_jitkernel_more(kVCopy, mkl) use_jitkernel_more(kVSigmoid, mkl) use_jitkernel_more(kVTanh, mkl) use_jitkernel_more(kSeqPool, mkl) -use_jitkernel_more(kSoftmax, mkl) use_jitkernel_more(kEmbSeqPool, mkl) use_jitkernel_more(kSgd, mkl) use_jitkernel_more(kVBroadcast, mkl) diff --git a/paddle/phi/kernels/funcs/jit/more/mkl/mkl.cc b/paddle/phi/kernels/funcs/jit/more/mkl/mkl.cc index ebb60f24d56..daf9eac988a 100644 --- a/paddle/phi/kernels/funcs/jit/more/mkl/mkl.cc +++ b/paddle/phi/kernels/funcs/jit/more/mkl/mkl.cc @@ -104,26 +104,6 @@ void VScal(const double* a, const double* x, double* y, int n) { } } -template <> -void StrideScal( - const float* a, const float* x, float* y, int n, int stride) { - if (x == y) { - phi::dynload::cblas_sscal(n / stride, *a, y, stride); - } else { - refer::StrideScal(a, x, y, n, stride); - } -} - -template <> -void StrideScal( - const double* a, const double* x, double* y, int n, int stride) { - if (x == y) { - phi::dynload::cblas_dscal(n / stride, *a, y, stride); - } else { - refer::StrideScal(a, x, y, n, stride); - } -} - template <> void VExp(const float* x, float* y, int n) { phi::dynload::vsExp(n, x, y); @@ -174,16 +154,6 @@ void ASum(const double* x, double* res, int n) { res[0] = phi::dynload::cblas_dasum(n, x, 1); } -template <> -void StrideASum(const float* x, float* res, int n, int stride) { - res[0] = phi::dynload::cblas_sasum(n / stride, x, stride); -} - -template <> -void StrideASum(const double* x, double* res, int n, int stride) { - res[0] = phi::dynload::cblas_dasum(n / stride, x, stride); -} - // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> bool VMulKernel::CanBeUsed(const int& d) const { @@ -200,11 +170,6 @@ bool VScalKernel::CanBeUsed(const int& d) const { return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f) && d > 512; } -template <> -bool StrideScalKernel::CanBeUsed(const int& d) const { - return true; -} - template <> bool VExpKernel::CanBeUsed(const int& d) const { return d > 7; @@ -281,12 +246,6 @@ bool MatMulKernel::CanBeUsed(const matmul_attr_t& attr) const { return true; } -template <> -bool SoftmaxKernel::CanBeUsed(const int& d) const { - // tuned on avx2 - return phi::backends::cpu::MayIUse(phi::backends::cpu::avx) && d < 60; -} - #define AWALYS_USE_ME_WITH_DOUBLE(func) \ template <> \ bool func##Kernel::CanBeUsed(const int& d) const { \ @@ -296,13 +255,11 @@ bool SoftmaxKernel::CanBeUsed(const int& d) const { AWALYS_USE_ME_WITH_DOUBLE(VMul); AWALYS_USE_ME_WITH_DOUBLE(VAdd); AWALYS_USE_ME_WITH_DOUBLE(VScal); -AWALYS_USE_ME_WITH_DOUBLE(StrideScal); AWALYS_USE_ME_WITH_DOUBLE(VExp); AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); AWALYS_USE_ME_WITH_DOUBLE(VTanh); AWALYS_USE_ME_WITH_DOUBLE(VSquare); AWALYS_USE_ME_WITH_DOUBLE(VCopy); -AWALYS_USE_ME_WITH_DOUBLE(Softmax); #undef AWALYS_USE_ME_WITH_DOUBLE } // namespace mkl @@ -320,7 +277,6 @@ REGISTER_MKL_KERNEL(MatMul); REGISTER_MKL_KERNEL(VMul); REGISTER_MKL_KERNEL(VAdd); REGISTER_MKL_KERNEL(VScal); -REGISTER_MKL_KERNEL(StrideScal); REGISTER_MKL_KERNEL(VExp); REGISTER_MKL_KERNEL(VSquare); REGISTER_MKL_KERNEL(VCopy); @@ -329,7 +285,6 @@ REGISTER_MKL_KERNEL(VSigmoid); REGISTER_MKL_KERNEL(VTanh); REGISTER_MKL_KERNEL(SeqPool); REGISTER_MKL_KERNEL(EmbSeqPool); -REGISTER_MKL_KERNEL(Softmax); REGISTER_MKL_KERNEL(Sgd); #undef REGISTER_MKL_KERNEL diff --git a/paddle/phi/kernels/funcs/jit/more/mkl/mkl.h b/paddle/phi/kernels/funcs/jit/more/mkl/mkl.h index 20e50db6e06..017fd798003 100644 --- a/paddle/phi/kernels/funcs/jit/more/mkl/mkl.h +++ b/paddle/phi/kernels/funcs/jit/more/mkl/mkl.h @@ -154,42 +154,6 @@ void EmbSeqPool(const T* table, template void ASum(const T* x, T* res, int n); -template -void StrideASum(const T* x, T* res, int n, int stride); - -template -void StrideScal(const T* a, const T* x, T* y, int n, int stride); - -// remain is the product of dimension shapes after the axis dimension -template -void Softmax(const T* x, T* y, int n, int bs, int remain = 1) { - std::vector entities(bs); - for (int i = 0; i < bs; ++i) { - entities[i] = x[i * n]; - for (int c = 1; c < n; ++c) { - entities[i] = x[i * n + c] > entities[i] ? x[i * n + c] : entities[i]; - } - for (int c = 0; c < n; ++c) { - y[i * n + c] = x[i * n + c] - entities[i]; - } - } - VExp(y, y, n * bs); - for (int i = 0; i < bs; ++i) { - T sum; - if (remain == 1) { - ASum(&y[i * n], &sum, n); - sum = static_cast(1) / sum; - VScal(&sum, &y[i * n], &y[i * n], n); - } else { - for (int j = 0; j < remain; ++j) { - StrideASum(&y[i * n + j], &sum, n, remain); - sum = static_cast(1) / sum; - StrideScal(&sum, &y[i * n + j], &y[i * n + j], n, remain); - } - } - } -} - template void Sgd(const T* lr, const T* param, @@ -284,7 +248,6 @@ DECLARE_MKL_KERNEL(VAdd); // AXYN DECLARE_MKL_KERNEL(VScal); -DECLARE_MKL_KERNEL(StrideScal); // XYN DECLARE_MKL_KERNEL(VExp); @@ -296,7 +259,6 @@ DECLARE_MKL_KERNEL(VCopy); // others DECLARE_MKL_KERNEL(SeqPool); DECLARE_MKL_KERNEL(EmbSeqPool); -DECLARE_MKL_KERNEL(Softmax); DECLARE_MKL_KERNEL(Sgd); DECLARE_MKL_KERNEL(VBroadcast); diff --git a/paddle/phi/kernels/funcs/jit/refer/CMakeLists.txt b/paddle/phi/kernels/funcs/jit/refer/CMakeLists.txt index 5ef93f989df..632dc98eb71 100644 --- a/paddle/phi/kernels/funcs/jit/refer/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/jit/refer/CMakeLists.txt @@ -16,7 +16,6 @@ use_jitkernel_refer(kVAdd) use_jitkernel_refer(kVAddRelu) use_jitkernel_refer(kVSub) use_jitkernel_refer(kVScal) -use_jitkernel_refer(kStrideScal) use_jitkernel_refer(kVAddBias) use_jitkernel_refer(kVCopy) use_jitkernel_refer(kVRelu) @@ -31,14 +30,9 @@ use_jitkernel_refer(kGRUHtPart1) use_jitkernel_refer(kGRUHtPart2) use_jitkernel_refer(kCRFDecoding) use_jitkernel_refer(kLayerNorm) -use_jitkernel_refer(kNCHW16CMulNC) use_jitkernel_refer(kSeqPool) use_jitkernel_refer(kMatMul) use_jitkernel_refer(kVSquare) -use_jitkernel_refer(kHSum) -use_jitkernel_refer(kHMax) -use_jitkernel_refer(kStrideASum) -use_jitkernel_refer(kSoftmax) use_jitkernel_refer(kEmbSeqPool) use_jitkernel_refer(kAdam) use_jitkernel_refer(kAdamW) diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.cc b/paddle/phi/kernels/funcs/jit/refer/refer.cc index b6730610423..a9111155f93 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.cc +++ b/paddle/phi/kernels/funcs/jit/refer/refer.cc @@ -28,7 +28,6 @@ REGISTER_REFER_KERNEL(VAddRelu); REGISTER_REFER_KERNEL(VSub); REGISTER_REFER_KERNEL(VScal); -REGISTER_REFER_KERNEL(StrideScal); REGISTER_REFER_KERNEL(VAddBias); REGISTER_REFER_KERNEL(VRelu); @@ -48,13 +47,8 @@ REGISTER_REFER_KERNEL(GRUHtPart2); REGISTER_REFER_KERNEL(CRFDecoding); REGISTER_REFER_KERNEL(LayerNorm); -REGISTER_REFER_KERNEL(NCHW16CMulNC); REGISTER_REFER_KERNEL(SeqPool); REGISTER_REFER_KERNEL(MatMul); -REGISTER_REFER_KERNEL(HMax); -REGISTER_REFER_KERNEL(HSum); -REGISTER_REFER_KERNEL(StrideASum); -REGISTER_REFER_KERNEL(Softmax); REGISTER_REFER_KERNEL(EmbSeqPool); REGISTER_REFER_KERNEL(Adam); REGISTER_REFER_KERNEL(AdamW); diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index 6d682775313..c7c3835f890 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -353,19 +353,6 @@ void LayerNorm(T* x, } } -template -void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { - int offset = 0; - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - for (int i = 0; i < 16; ++i) { - z[i + offset] = y[i] * x[i + offset]; - } - offset += ZMM_FLOAT_BLOCK; - } - } -} - template void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { for (int w = 0; w < attr->w; ++w) { @@ -407,68 +394,6 @@ void MatMul(const T* A, const T* B, T* C, const matmul_attr_t* attr) { } } -template -void HMax(const T* x, T* res, int n) { - res[0] = x[0]; - for (int i = 1; i < n; ++i) { - res[0] = res[0] < x[i] ? x[i] : res[0]; - } -} - -template -void HSum(const T* x, T* res, int n) { - res[0] = x[0]; - for (int i = 1; i < n; ++i) { - res[0] += x[i]; - } -} - -template -void StrideASum(const T* x, T* res, int n, int stride) { - res[0] = x[0]; - for (int i = stride; i < n; i += stride) { - res[0] += std::abs(x[i]); - } -} - -template -void StrideScal(const T* a, const T* x, T* y, int n, int stride) { - for (int i = 0; i < n; ++i) { - if (i % stride == 0) { - y[i] = x[i] * a[0]; - } else { - y[i] = x[i]; - } - } -} - -// y = e^(x - max(x)) -// y = y / sum(y) -// remain is the product of dimension shapes after the axis dimension -template -void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) { - for (int i = 0; i < bs; ++i) { - T scalar; - HMax(x, &scalar, n); - scalar = static_cast(0) - scalar; - VAddBias(&scalar, x, y, n); // x - max - VExp(y, y, n); - if (remain == 1) { - HSum(y, &scalar, n); - scalar = static_cast(1) / scalar; - VScal(&scalar, y, y, n); - } else { - for (int j = 0; j < remain; j++) { - StrideASum(&y[j], &scalar, n, remain); - scalar = static_cast(1) / scalar; - StrideScal(&scalar, &y[j], &y[j], n, remain); - } - } - x += n; - y += n; - } -} - // embedding seq pool // table is a matrix with (tbl_h, tbl_w) // idx is a matrix with (idx_h, idx_w) @@ -654,9 +579,6 @@ DECLARE_REFER_KERNEL(VSub); DECLARE_REFER_KERNEL(VScal); DECLARE_REFER_KERNEL(VAddBias); -// const T* a, const T* x, T* y, int n, int stride -DECLARE_REFER_KERNEL(StrideScal); - // const T* x, T* y, int n DECLARE_REFER_KERNEL(VRelu); DECLARE_REFER_KERNEL(VIdentity); @@ -675,18 +597,11 @@ DECLARE_REFER_KERNEL(GRUH1); DECLARE_REFER_KERNEL(GRUHtPart1); DECLARE_REFER_KERNEL(GRUHtPart2); -DECLARE_REFER_KERNEL(HMax); -DECLARE_REFER_KERNEL(HSum); - -DECLARE_REFER_KERNEL(StrideASum); - // others DECLARE_REFER_KERNEL(CRFDecoding); DECLARE_REFER_KERNEL(LayerNorm); -DECLARE_REFER_KERNEL(NCHW16CMulNC); DECLARE_REFER_KERNEL(SeqPool); DECLARE_REFER_KERNEL(MatMul); -DECLARE_REFER_KERNEL(Softmax); DECLARE_REFER_KERNEL(EmbSeqPool); DECLARE_REFER_KERNEL(Adam); DECLARE_REFER_KERNEL(AdamW); diff --git a/paddle/phi/kernels/funcs/jit/test.cc b/paddle/phi/kernels/funcs/jit/test.cc index 56bd318f601..dedbfa74d1f 100644 --- a/paddle/phi/kernels/funcs/jit/test.cc +++ b/paddle/phi/kernels/funcs/jit/test.cc @@ -227,33 +227,6 @@ void TestKernelXYN() { } } -template -void TestKernelXRN() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - auto last_acc = FLAGS_acc; - FLAGS_acc = 1e-4; - for (int d : TestSizes()) { - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - std::vector x(d); - RandomVec(d, x.data()); - T ref_res; - ref(x.data(), &ref_res, d); - - auto verifier = [](const typename KernelTuple::func_type tgt, - const std::vector& x, - const T ref_res) { - EXPECT_TRUE(tgt != nullptr); - T tgt_res; - tgt(x.data(), &tgt_res, x.size()); - ExpectEQ(&tgt_res, &ref_res, 1); - }; - TestAllImpls(d, verifier, x, ref_res); - } - FLAGS_acc = last_acc; -} - template void TestKernelLSTM() { using T = typename KernelTuple::data_type; @@ -411,62 +384,6 @@ void TestKernelGRU() { } } -template -void TestKernelNCHW16CMulNC() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - const int n = 3, c = 16 * 4, h = 10, w = 10; - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - int sz = n * c * h * w; - std::vector x(sz), y(n * c), zref(sz); - std::vector ztgt(sz), zjit(sz); - RandomVec(sz, x.data()); - RandomVec(n * c, y.data()); - - const T* x_data = x.data(); - const T* y_data = y.data(); - T* zref_data = zref.data(); - T* ztgt_data = ztgt.data(); - T* zjit_data = zjit.data(); - constexpr int simd_width = ZMM_FLOAT_BLOCK; - int C = c / simd_width; - auto tgt = jit::KernelFuncs::Cache().At(0); - auto funcs = jit::GetAllCandidateFuncs(0); - EXPECT_GT(funcs.size(), 0UL); - auto jitcode = funcs[0]; - EXPECT_TRUE(tgt != nullptr); - - if (std::is_same::value && - phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { - EXPECT_TRUE(jitcode != nullptr); - } - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_zref = - zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - auto ptr_ztgt = - ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - ref(ptr_x, ptr_y, ptr_zref, h, w); - tgt(ptr_x, ptr_y, ptr_ztgt, h, w); - - if (jitcode) { - auto ptr_zjit = - zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - jitcode(ptr_x, ptr_y, ptr_zjit, h, w); - } - } - } - ExpectEQ(ztgt_data, zref_data, sz); - if (jitcode) { - ExpectEQ(zjit_data, zref_data, sz); - } -} - template void TestKernelLayerNorm() { using T = typename KernelTuple::data_type; @@ -770,138 +687,6 @@ void TestKernelMatMul() { FLAGS_acc = last_acc; } -template -void TestKernelSoftmax() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - for (int bs : {1, 2, 10}) { - for (int n : TestSizes()) { - for (int m : {1, 2, 3}) { // remain - if (m > n || n % m != 0) { - continue; - } - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - std::vector x(bs * n), y(bs * n); - RandomVec(bs * n, x.data()); - const T* x_data = x.data(); - T* y_data = y.data(); - - std::vector xinp(x.size()); // inplace test - std::copy(x.begin(), x.end(), xinp.begin()); - ref(x_data, y_data, n, bs, m); - T* xinp_data = xinp.data(); - ref(xinp_data, xinp_data, n, bs, m); - ExpectEQ(xinp_data, y_data, n * bs); - - auto verifier = [](const typename KernelTuple::func_type tgt, - const std::vector& x, - const std::vector& yref, - int n, - int bs, - int m) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - EXPECT_EQ(x.size(), static_cast(n * bs)); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - std::vector ytgt(n * bs); - T* ytgt_data = ytgt.data(); - // test normal - tgt(x_data, ytgt_data, n, bs, m); - ExpectEQ(ytgt_data, yref_data, n * bs); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(ytgt_data, ytgt_data, n, bs, m); - ExpectEQ(ytgt_data, yref_data, n * bs); - }; - TestAllImpls(n, verifier, x, y, n, bs, m); - } - } - } -} - -template -void TestKernelStrideASum() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - for (int d : TestSizes()) { - for (int m : {1, 2, 3}) { // stride - if (m > d || d % m != 0) { - continue; - } - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - std::vector x(d); - RandomVec(d, x.data()); - T ref_res; - ref(x.data(), &ref_res, d, m); - - auto verifier = [](const typename KernelTuple::func_type tgt, - const std::vector& x, - const T ref_res, - const int m) { - EXPECT_TRUE(tgt != nullptr); - T tgt_res; - tgt(x.data(), &tgt_res, x.size(), m); - ExpectEQ(&tgt_res, &ref_res, 1); - }; - TestAllImpls(d, verifier, x, ref_res, m); - } - } -} - -template -void TestKernelStrideScal() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - for (int d : TestSizes()) { - for (int m : {1, 2, 3}) { // stride - if (m > d || d % m != 0) { - continue; - } - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - - const T a = static_cast(3); - std::vector x(d), yref(d); - std::vector xinp(d); // inplace test - RandomVec(d, x.data()); - std::copy(x.begin(), x.end(), xinp.begin()); - - const T* x_data = x.data(); - T* yref_data = yref.data(); - T* xinp_data = xinp.data(); - // test refer code inplace - ref(&a, x_data, yref_data, d, m); - ref(&a, xinp_data, xinp_data, d, m); - ExpectEQ(xinp_data, yref_data, d); - - auto verifier = [](const typename KernelTuple::func_type tgt, - const T a, - const std::vector& x, - const std::vector& yref, - const int m) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(yref.size(), x.size()); - const T* x_data = x.data(); - const T* yref_data = yref.data(); - const int d = yref.size(); - std::vector ytgt(d); - T* ytgt_data = ytgt.data(); - // test normal - tgt(&a, x_data, ytgt_data, d, m); - ExpectEQ(ytgt_data, yref_data, d); - // test inplace x - std::copy(x.begin(), x.end(), ytgt.begin()); - tgt(&a, ytgt_data, ytgt_data, d, m); - ExpectEQ(ytgt_data, yref_data, d); - }; - TestAllImpls(d, verifier, a, x, yref, m); - } - } -} - template void TestKernelAdam() { using T = typename KernelTuple::data_type; @@ -1268,7 +1053,7 @@ TEST(JITKernel_pool, jitcreator) { #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) EXPECT_EQ(jitcreators.size(), 0UL); #else - EXPECT_EQ(jitcreators.size(), 27UL); + EXPECT_EQ(jitcreators.size(), 24UL); #endif } @@ -1287,14 +1072,14 @@ TEST(JITKernel_pool, jitpool) { TEST(JITKernel_pool, more) { const auto& kers = jit::KernelPool::Instance().AllKernels(); - size_t target_num = 8; + size_t target_num = 7; #ifdef __AVX__ target_num += 2; #endif #ifdef PADDLE_WITH_MKLML - target_num += 12; + target_num += 11; #endif EXPECT_EQ(kers.size(), target_num); @@ -1302,7 +1087,7 @@ TEST(JITKernel_pool, more) { TEST(JITKernel_pool, refer) { const auto& kers = jit::ReferKernelPool::Instance().AllKernels(); - EXPECT_EQ(kers.size(), 33UL); + EXPECT_EQ(kers.size(), 27UL); } // test helper @@ -1425,11 +1210,9 @@ TEST(JITKernel_helper, attr) { out << jit::to_string(jit::kNone) << jit::to_string(jit::kCRFDecoding) << jit::to_string(jit::kEmbSeqPool) << jit::to_string(jit::kGRUH1) << jit::to_string(jit::kGRUHtPart1) << jit::to_string(jit::kGRUHtPart2) - << jit::to_string(jit::kHSum) << jit::to_string(jit::kHMax) << jit::to_string(jit::kLSTMCtHt) << jit::to_string(jit::kLSTMC1H1) << jit::to_string(jit::kLayerNorm) << jit::to_string(jit::kMatMul) - << jit::to_string(jit::kNCHW16CMulNC) << jit::to_string(jit::kSeqPool) - << jit::to_string(jit::kSoftmax) << jit::to_string(jit::kVAdd) + << jit::to_string(jit::kSeqPool) << jit::to_string(jit::kVAdd) << jit::to_string(jit::kVAddBias) << jit::to_string(jit::kVAddRelu) << jit::to_string(jit::kVBroadcast) << jit::to_string(jit::kVCopy) << jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity) @@ -1438,7 +1221,7 @@ TEST(JITKernel_helper, attr) { << jit::to_string(jit::kAdam) << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare) << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh); - EXPECT_EQ(out.str().size(), 239UL); + EXPECT_EQ(out.str().size(), 208UL); // SeqPoolTypes out.str(""); @@ -1635,9 +1418,6 @@ TEST(JITKernel_key, sgd) { #define TestKernelVTanh TestKernelXYN #define TestKernelVCopy TestKernelXYN -#define TestKernelHMax TestKernelXRN -#define TestKernelHSum TestKernelXRN - #define TestKernelLSTMCtHt TestKernelLSTM #define TestKernelLSTMC1H1 TestKernelLSTM @@ -1667,9 +1447,6 @@ TEST_CPU_KERNEL(VSigmoid); TEST_CPU_KERNEL(VTanh); TEST_CPU_KERNEL(VCopy); -TEST_CPU_KERNEL(HMax); -TEST_CPU_KERNEL(HSum); - TEST_CPU_KERNEL(LSTMCtHt); TEST_CPU_KERNEL(LSTMC1H1); @@ -1677,18 +1454,13 @@ TEST_CPU_KERNEL(GRUH1); TEST_CPU_KERNEL(GRUHtPart1); TEST_CPU_KERNEL(GRUHtPart2); -TEST_CPU_KERNEL(NCHW16CMulNC); TEST_CPU_KERNEL(LayerNorm); TEST_CPU_KERNEL(CRFDecoding); TEST_CPU_KERNEL(SeqPool); TEST_CPU_KERNEL(EmbSeqPool); TEST_CPU_KERNEL(MatMul); -TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Adam); TEST_CPU_KERNEL(AdamW); TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(VBroadcast); - -TEST_CPU_KERNEL(StrideASum); -TEST_CPU_KERNEL(StrideScal); -- GitLab