From 2848145814f956505ff0aba5653ebe029f96594f Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Wed, 25 Dec 2019 09:18:40 +0800
Subject: [PATCH] [X86] Polish the implementation of fc and improve the
 unittest (#2656)

* Remove GEMM padding in fc_compute. test=develop
* Write a common ParallelFor function to run the for loop in parallel.
* Add the codes of padding GEMM back in fc.
* Refine the code of fc when padding_weight is false to avoid the definition
  of temporary Tensor.
* Refine the unit test of fc and add testing case of padding and parallel.
  test=develop
* Enable more test cases in common fc unittest, including padding and
  parallel for x86 target.
* Remove the fc test under kernels/x86. test=develop
* Disable relu in test of fc for non-x86 target. test=develop
* Change the eps of arm. test=develop
---
 lite/api/test_step_rnn_lite_x86.cc    |   9 +--
 lite/backends/x86/parallel.h          |  73 ++++++++++++++++++
 lite/kernels/x86/CMakeLists.txt       |   1 -
 lite/kernels/x86/fc_compute.h         | 105 ++++++++++++--------
 lite/kernels/x86/fc_compute_test.cc   |  95 -----------------------
 lite/operators/fc_op.cc               |  12 +--
 lite/tests/kernels/fc_compute_test.cc |  94 +++++++++++++++++++----
 7 files changed, 214 insertions(+), 175 deletions(-)
 create mode 100644 lite/backends/x86/parallel.h
 delete mode 100644 lite/kernels/x86/fc_compute_test.cc

diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc
index cb9360b16d..075d314df6 100644
--- a/lite/api/test_step_rnn_lite_x86.cc
+++ b/lite/api/test_step_rnn_lite_x86.cc
@@ -49,7 +49,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
                                           "micro_video_id",
                                           "vertical_type_id"};
 
-  for (int i = 0; i < target_names.size(); ++i) {
+  for (size_t i = 0; i < target_names.size(); ++i) {
     auto input_tensor = predictor->GetInput(i);
     int size = 0;
     if (i == 6 || i == 8) {
@@ -74,8 +74,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
     predictor->Run();
   }
 
-  // LOG(INFO) << "================== Speed Report ===================";
-  LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+  LOG(INFO) << "warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
             << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
 
@@ -86,8 +85,8 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
 
   std::vector<int64_t> out_shape = out->shape();
 
-  for (int i = 0; i < results.size(); ++i) {
-    for (int j = 0; j < results[i].size(); ++j) {
+  for (size_t i = 0; i < results.size(); ++i) {
+    for (size_t j = 0; j < results[i].size(); ++j) {
       EXPECT_NEAR(
           out->data<float>()[j + (out_shape[1] * i)], results[i][j], 1e-6);
     }
diff --git a/lite/backends/x86/parallel.h b/lite/backends/x86/parallel.h
new file mode 100644
index 0000000000..0689ec4c23
--- /dev/null
+++ b/lite/backends/x86/parallel.h
@@ -0,0 +1,73 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
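For reference, a minimal usage sketch of the helpers introduced by the new lite/backends/x86/parallel.h (shown below): SetNumThreads caps the MKL/OpenMP thread count and RunParallelFor hands each worker a [begin, end) sub-range. The function name Example and the thread count of 4 are illustrative assumptions, not part of the patch.

#include <cstdint>

#include "lite/backends/x86/parallel.h"

void Example(float* y, const float* x, int64_t m, int64_t n) {
  // No-op when built without PADDLE_WITH_MKLML.
  paddle::lite::x86::SetNumThreads(4);
  // The handler receives a [begin, end) sub-range of rows to process.
  paddle::lite::x86::RunParallelFor(0, m, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        y[i * n + j] = x[i * n + j] * 2.f;
      }
    }
  });
}
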
+ +#pragma once + +#include +#ifdef PADDLE_WITH_MKLML +#include +#include "lite/backends/x86/mklml.h" +#endif + +namespace paddle { +namespace lite { +namespace x86 { + +static void SetNumThreads(int num_threads) { +#ifdef PADDLE_WITH_MKLML + int real_num_threads = std::max(num_threads, 1); + x86::MKL_Set_Num_Threads(real_num_threads); + omp_set_num_threads(real_num_threads); +#endif +} + +static inline int64_t GetMaxThreads() { + int64_t num_threads = 1; +#ifdef PADDLE_WITH_MKLML + // Do not support nested omp parallem. + num_threads = omp_in_parallel() ? 1 : omp_get_max_threads(); +#endif + return std::max(num_threads, 1L); +} + +using ThreadHandler = + std::function; + +static inline void RunParallelFor(const int64_t begin, + const int64_t end, + const ThreadHandler& f) { + if (begin >= end) { + return; + } + +#ifdef PADDLE_WITH_MKLML + int64_t num_threads = std::min(GetMaxThreads(), end - begin); + if (num_threads > 1) { +#pragma omp parallel num_threads(num_threads) + { + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = (end - begin + num_threads - 1) / num_threads; + int64_t begin_tid = begin + tid * chunk_size; + f(begin_tid, std::min(end, chunk_size + begin_tid)); + } + return; + } +#endif + + f(begin, end); +} + +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 7bf131729a..75a95d1c91 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -100,4 +100,3 @@ lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test. lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86) #lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86) lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86) -lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) diff --git a/lite/kernels/x86/fc_compute.h b/lite/kernels/x86/fc_compute.h index 4f9f6c428a..971f5dfa2f 100644 --- a/lite/kernels/x86/fc_compute.h +++ b/lite/kernels/x86/fc_compute.h @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #pragma once #include @@ -18,6 +19,7 @@ #include "lite/backends/x86/jit/kernel_base.h" #include "lite/backends/x86/jit/kernels.h" #include "lite/backends/x86/math/blas.h" +#include "lite/backends/x86/parallel.h" #include "lite/core/kernel.h" #include "lite/core/op_lite.h" #include "lite/core/op_registry.h" @@ -57,34 +59,45 @@ class FCFunctor { bool relu = false, bool padding_weights = false) { auto blas = lite::x86::math::GetBlas(context); - lite::Tensor Y1; T* Y1_data = nullptr; - if (N % 128 == 0 && K % 128 == 0) { + + auto compute = + relu + ? jit::KernelFuncs, fluid::CPUPlace>::Cache() + .At(N) + : jit::KernelFuncs, fluid::CPUPlace>::Cache().At( + N); + auto parallel_compute = [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; i++) { + T* dst = Y + i * N; + T* src = Y1_data ? Y1_data + i * (N + 4) : dst; + compute(B, src, dst, N); + } + }; + + // Because of the overhead of memcpy, we only do padding for GEMM + // when weights is already padded in fc_fuse_pass. 
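The chunk arithmetic used inside RunParallelFor above can be seen in isolation in the following standalone sketch. It substitutes std::thread for the OpenMP parallel region so it compiles without MKL; the names ParallelForSketch and num_threads are illustrative only, and each worker gets a ceil-divided chunk of the [begin, end) range exactly as in the header.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

using ThreadHandler = std::function<void(int64_t, int64_t)>;

void ParallelForSketch(int64_t begin, int64_t end, int num_threads,
                       const ThreadHandler& f) {
  if (begin >= end) return;
  int64_t n = std::min<int64_t>(num_threads, end - begin);
  if (n <= 1) {
    f(begin, end);
    return;
  }
  // Same arithmetic as the patch: ceil((end - begin) / n) items per worker.
  int64_t chunk_size = (end - begin + n - 1) / n;
  std::vector<std::thread> workers;
  for (int64_t tid = 0; tid < n; ++tid) {
    int64_t b = begin + tid * chunk_size;
    int64_t e = std::min(end, b + chunk_size);
    if (b >= e) break;
    workers.emplace_back(f, b, e);
  }
  for (auto& t : workers) t.join();
}
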
+ if (padding_weights) { const int NN = N + 4; const int KK = K + 4; + + // NOTE: here need to mutable_data for temporary Tensor X1 and Y1, + // the overhead is unmeasured. lite::Tensor X1; X1.Resize({M * KK}); - Y1.Resize({M * (N + 4)}); T* X1_data = X1.mutable_data(); + + lite::Tensor Y1; + Y1.Resize({M * (N + 4)}); Y1_data = Y1.mutable_data(); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < M; i++) { - memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0])); - } - lite::Tensor W1; - T* W1_data = nullptr; - if (!padding_weights) { - W1.Resize({(K + 4) * (N + 4)}); - W1_data = W1.mutable_data(); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < K; i++) { - memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0])); + + auto parallel_memcpy_x = [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; i++) { + memcpy(X1_data + i * KK, X + i * K, K * sizeof(T)); } - } + }; + lite::x86::RunParallelFor(0, M, parallel_memcpy_x); + blas.GEMM(false, false, M, @@ -93,48 +106,30 @@ class FCFunctor { static_cast(1.0), X1_data, KK, - (padding_weights ? W : W1_data), + W, NN, static_cast(0.0), Y1_data, NN); - } else { - blas.MatMul(M, N, K, X, W, Y); - } - if (B == NULL) { - if (N % 128 == 0 && K % 128 == 0) { -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < M; i++) { - memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0])); - } - } - return; - } - if (relu) { - auto compute = - paddle::lite::jit::KernelFuncs, - lite::fluid::CPUPlace>::Cache() - .At(N); - for (int i = 0; i < M; i++) { - T* dst = Y + i * N; - T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; - compute(B, src, dst, N); + + if (!B) { + auto parallel_memcpy_y = [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; i++) { + memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T)); + } + }; + lite::x86::RunParallelFor(0, M, parallel_memcpy_y); + return; } + + lite::x86::RunParallelFor(0, M, parallel_compute); } else { - auto compute = - paddle::lite::jit::KernelFuncs, - lite::fluid::CPUPlace>::Cache() - .At(N); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < M; i++) { - T* dst = Y + i * N; - T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; - compute(B, src, dst, N); + blas.MatMul(M, N, K, X, W, Y); + if (!B) { + return; } + + lite::x86::RunParallelFor(0, M, parallel_compute); } } }; diff --git a/lite/kernels/x86/fc_compute_test.cc b/lite/kernels/x86/fc_compute_test.cc deleted file mode 100644 index 32394798b3..0000000000 --- a/lite/kernels/x86/fc_compute_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
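To make the role of the padded strides explicit: the blas.GEMM call above multiplies an M x K input stored with row stride KK = K + 4 by the already-padded weight (logical K x N block inside a (K + 4) x (N + 4) buffer, row stride NN = N + 4), writing an M x N result with row stride NN. A naive reference with the same leading-dimension layout follows; the helper names PadRows and NaiveGemmLd are illustrative, not from the patch.

#include <cstring>

// Copy a rows x cols row-major matrix into a buffer with row stride ld_dst,
// which is what the parallel_memcpy_x lambda does for X with ld_dst = K + 4.
void PadRows(const float* src, float* dst, int rows, int cols, int ld_dst) {
  for (int i = 0; i < rows; ++i) {
    std::memcpy(dst + i * ld_dst, src + i * cols, cols * sizeof(float));
  }
}

// Mirrors blas.GEMM(false, false, M, N, K, 1.0, X1, KK, W, NN, 0.0, Y1, NN):
// C(M x N, ldc) = A(M x K, lda) * B(K x N, ldb).
void NaiveGemmLd(int M, int N, int K, const float* A, int lda, const float* B,
                 int ldb, float* C, int ldc) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * lda + k] * B[k * ldb + j];
      }
      C[i * ldc + j] = acc;
    }
  }
}
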
- -#include "lite/kernels/x86/fc_compute.h" -#include -#include -#include -#include -#include "lite/core/op_registry.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace x86 { - -TEST(fc_x86, retrive_op) { - auto fc = - KernelRegistry::Global().Create("fc"); - ASSERT_FALSE(fc.empty()); - ASSERT_TRUE(fc.front()); -} - -TEST(fc_x86, init) { - FcCompute fc; - ASSERT_EQ(fc.precision(), PRECISION(kFloat)); - ASSERT_EQ(fc.target(), TARGET(kX86)); -} - -TEST(fc_x86, run_test) { - lite::Tensor x, w, b, out; - constexpr int batch_size = 2; - std::vector x_shape{batch_size, 3}; - x.Resize(lite::DDim(x_shape)); - std::vector w_shape{3, 4}; - w.Resize(lite::DDim(w_shape)); - std::vector b_shape{1, 4}; - b.Resize(lite::DDim(b_shape)); - std::vector out_shape{batch_size, 4}; - out.Resize(lite::DDim(out_shape)); - - auto x_data = x.mutable_data(); - auto w_data = w.mutable_data(); - auto b_data = b.mutable_data(); - auto out_data = out.mutable_data(); - - for (int64_t i = 0; i < x.dims().production(); i++) { - x_data[i] = static_cast(i); - } - for (int64_t i = 0; i < w.dims().production(); i++) { - w_data[i] = static_cast(2); - } - for (int64_t i = 0; i < b.dims().production(); i++) { - b_data[i] = static_cast(2); - } - - // FcCompute fc; - FcCompute fc; - operators::FcParam param; - - param.in_num_col_dims = 1; - param.input = &x; - param.w = &w; - param.bias = &b; - param.output = &out; - param.in_mat_dims = x.dims(); - param.activation_type = "relu"; - - std::unique_ptr ctx(new KernelContext); - ctx->As(); - fc.SetParam(param); - fc.SetContext(std::move(ctx)); - fc.Run(); - std::vector ref_data({8, 8, 8, 8, 26, 26, 26, 26}); - for (int i = 0; i < out.dims().production(); i++) { - EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); - } -} - -} // namespace x86 -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 141436fbf7..ad3fcf79a3 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -27,21 +27,21 @@ bool FcOpLite::CheckShape() const { const auto input_dims = param_.input->dims(); const auto w_dims = param_.w->dims(); + CHECK_EQ_OR_FALSE(w_dims.size(), 2UL); + int64_t w_dims_1 = param_.padding_weights ? 
w_dims[1] - 4 : w_dims[1]; if (param_.bias) { const auto bias_dims = param_.bias->dims(); if (bias_dims.size() == 2) { CHECK_EQ_OR_FALSE(bias_dims[0], 1); - CHECK_EQ_OR_FALSE(bias_dims[1], w_dims[1]); + CHECK_EQ_OR_FALSE(bias_dims[1], w_dims_1); } else if (bias_dims.size() == 1) { - CHECK_EQ_OR_FALSE(bias_dims[0], w_dims[1]); + CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1); } } - CHECK_EQ_OR_FALSE(w_dims.size(), 2UL); CHECK_GT_OR_FALSE(input_dims.size(), static_cast(param_.in_num_col_dims)); - param_.in_mat_dims = input_dims.Flatten2D(param_.in_num_col_dims); // CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]); @@ -91,7 +91,9 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.activation_type = op_desc.GetAttr("activation_type"); } if (op_desc.HasAttr("padding_weights")) { - param_.activation_type = op_desc.GetAttr("padding_weights"); + param_.padding_weights = op_desc.GetAttr("padding_weights"); + } else { + param_.padding_weights = false; } // For Int8 diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index 1dca6d41ed..de7bbc2158 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -18,11 +18,14 @@ #include "lite/core/arena/framework.h" #include "lite/tests/utils/fill_data.h" #include "lite/tests/utils/naive_math_impl.h" +#ifdef LITE_WITH_X86 +#include "lite/backends/x86/parallel.h" +#endif namespace paddle { namespace lite { -void fill_bias_fc(float* out, const float* bias, int num, int channel) { +void AddBias(float* out, const float* bias, int num, int channel) { int remain = channel; for (int j = 0; j < num; ++j) { const float* ptr_bias = bias; @@ -33,7 +36,15 @@ void fill_bias_fc(float* out, const float* bias, int num, int channel) { } } -DDim compute_out_dim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) { +void Relu(float* out, int num, int channel) { + for (int i = 0; i < num * channel; ++i) { + if (out[i] < 0) { + out[i] = 0; + } + } +} + +DDim ComputeOutDim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) { std::vector out_dim; out_dim.resize(in_num_col_dim + 1); auto in_mat_dims = dim_in.Flatten2D(in_num_col_dim); @@ -49,12 +60,16 @@ class FcOPTest : public arena::TestCase { // common attributes for this op. 
std::string input_ = "x"; std::string weight_ = "w"; + std::string weight_padding_ = "w_padding"; std::string bias_ = "b"; std::string out_ = "out"; DDim dims_{{1, 128}}; DDim wdims_{{128, 4}}; + DDim wdims_padding_; DDim bdims_{{4}}; int in_num_col_dims_{1}; + bool with_relu_{false}; + bool padding_weights_{false}; public: FcOPTest(const Place& place, @@ -62,12 +77,22 @@ class FcOPTest : public arena::TestCase { DDim dim_in, DDim dim_w, DDim dim_b, - int in_num_col_dims) + int in_num_col_dims, + bool with_relu, + bool padding) : TestCase(place, alias), dims_(std::move(dim_in)), wdims_(std::move(dim_w)), bdims_(dim_b), - in_num_col_dims_(in_num_col_dims) {} + in_num_col_dims_(in_num_col_dims), + with_relu_(with_relu) { +#ifdef LITE_WITH_X86 + if (padding && wdims_[0] % 128 == 0 && wdims_[1] % 128 == 0) { + padding_weights_ = true; + wdims_padding_ = DDim({wdims_[0] + 4, wdims_[1] + 4}); + } +#endif + } void RunBaseline(Scope* scope) override { auto x = scope->FindTensor(input_); @@ -76,11 +101,9 @@ class FcOPTest : public arena::TestCase { bool flag_bias = b; auto out = scope->NewTensor(out_); CHECK(out); - DDim out_dim = compute_out_dim(x->dims(), w->dims(), in_num_col_dims_); + DDim out_dim = ComputeOutDim(x->dims(), w->dims(), in_num_col_dims_); out->Resize(out_dim); - LOG(INFO) << "out dims: " << out_dim; - auto x_data = x->data(); auto w_data = w->data(); const float* b_data = nullptr; @@ -94,7 +117,9 @@ class FcOPTest : public arena::TestCase { int k = wdims_[0]; int n = wdims_[1]; - LOG(INFO) << "m: " << m << ", n: " << n << ", k: " << k; + LOG(INFO) << "M=" << m << ", N=" << n << ", K=" << k + << ", bias=" << flag_bias << ", with_relu=" << with_relu_ + << ", padding_weights=" << padding_weights_; if (m == 1) { basic_gemv(n, @@ -126,20 +151,34 @@ class FcOPTest : public arena::TestCase { false, false); if (flag_bias) { - fill_bias_fc(out_data, b_data, m, n); + AddBias(out_data, b_data, m, n); } } +#ifdef LITE_WITH_X86 + if (flag_bias && with_relu_) { + Relu(out_data, m, n); + } +#endif } void PrepareOpDesc(cpp::OpDesc* op_desc) { op_desc->SetType("fc"); op_desc->SetInput("Input", {input_}); - op_desc->SetInput("W", {weight_}); + if (padding_weights_) { + op_desc->SetInput("W", {weight_padding_}); + } else { + op_desc->SetInput("W", {weight_}); + } if (bdims_.production() > 0) { op_desc->SetInput("Bias", {bias_}); } op_desc->SetOutput("Out", {out_}); op_desc->SetAttr("in_num_col_dims", in_num_col_dims_); +#ifdef LITE_WITH_X86 + std::string activation_type = with_relu_ ? 
"relu" : ""; + op_desc->SetAttr("activation_type", activation_type); + op_desc->SetAttr("padding_weights", padding_weights_); +#endif } void PrepareData() override { @@ -155,22 +194,37 @@ class FcOPTest : public arena::TestCase { SetCommonTensor(input_, dims_, din.data()); SetCommonTensor(weight_, wdims_, win.data()); + if (padding_weights_) { + std::vector win_padding(wdims_padding_.production()); + for (int64_t i = 0; i < wdims_[0]; ++i) { + memcpy(&(win_padding[i * wdims_padding_[1]]), + &(win[i * wdims_[1]]), + wdims_[1] * sizeof(float)); + } + SetCommonTensor(weight_padding_, wdims_padding_, win_padding.data()); + } if (flag_bias) { SetCommonTensor(bias_, bdims_, bin.data()); } } }; -void test_fc(Place place, float abs_error) { +void TestFCMain(Place place, + float abs_error, + bool with_relu = false, + bool padding = false) { for (auto& m : {1, 3, 16}) { for (auto& n : {1, 4, 16, 128, 256, 1024}) { for (auto& k : {1, 16, 128, 1024}) { for (auto& bflag : {false, true}) { + if (!bflag && with_relu) { + continue; + } DDim dim_in{{m, k}}; DDim wdim{{k, n}}; DDim bdim{{bflag ? n : 0}}; - std::unique_ptr tester( - new FcOPTest(place, "def", dim_in, wdim, bdim, 1)); + std::unique_ptr tester(new FcOPTest( + place, "def", dim_in, wdim, bdim, 1, with_relu, padding)); #ifdef LITE_WITH_ARM if (place == TARGET(kARM)) { auto& ctx = tester->context()->As(); @@ -195,13 +249,25 @@ TEST(FcOP, precision) { #if defined(LITE_WITH_NPU) place = TARGET(kNPU); abs_error = 2e-1; // Using fp16 in NPU +#elif defined(LITE_WITH_X86) + place = TARGET(kX86); + abs_error = 1e-4; #elif defined(LITE_WITH_ARM) place = TARGET(kARM); #else return; #endif - test_fc(place, abs_error); + TestFCMain(place, abs_error); } +#ifdef LITE_WITH_X86 +TEST(FcOP, padding_and_parallel) { + Place place(TARGET(kX86)); + float abs_error = 1e-4; + x86::SetNumThreads(4); + TestFCMain(place, abs_error, true, true); +} +#endif + } // namespace lite } // namespace paddle -- GitLab