From d76c529a48e6375cc203f8e46fb9aec3429096d1 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Sun, 8 Dec 2019 23:11:09 +0800 Subject: [PATCH] Add fc op on lite x86 platform (#2568) --- lite/kernels/x86/CMakeLists.txt | 4 +- lite/kernels/x86/fc_compute.h | 198 ++++++++++++++++++++-------- lite/kernels/x86/fc_compute_test.cc | 29 ++-- lite/operators/fc_op.cc | 3 + lite/operators/op_params.h | 1 + 5 files changed, 161 insertions(+), 74 deletions(-) diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index bf3a1685f0..c735563230 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -20,14 +20,13 @@ add_kernel(stack_compute_x86 X86 basic SRCS stack_compute.cc DEPS ${lite_kernel_ add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps}) add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function) add_kernel(layer_norm_compute_x86 X86 basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} jit_kernel_helper) -# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps} jit_kernel_helper) # lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} ) add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps} blas math_function sequence2batch gru_compute) #add_kernel(gru_compute_x86 X86 basic SRCS gru_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps}) -# lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) add_kernel(gather_compute_x86 X86 basic SRCS gather_compute.cc DEPS ${lite_kernel_deps} fluid_data_type) # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86) @@ -100,3 +99,4 @@ lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test. lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86) #lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86) lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86) +lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) diff --git a/lite/kernels/x86/fc_compute.h b/lite/kernels/x86/fc_compute.h index 620236a454..4f9f6c428a 100644 --- a/lite/kernels/x86/fc_compute.h +++ b/lite/kernels/x86/fc_compute.h @@ -13,66 +13,131 @@ // limitations under the License. #pragma once -#include +#include +#include "lite/backends/x86/jit/helper.h" +#include "lite/backends/x86/jit/kernel_base.h" +#include "lite/backends/x86/jit/kernels.h" +#include "lite/backends/x86/math/blas.h" #include "lite/core/kernel.h" #include "lite/core/op_lite.h" #include "lite/core/op_registry.h" #include "lite/core/type_system.h" #include "lite/operators/fc_op.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" namespace paddle { namespace lite { namespace kernels { namespace x86 { -template -void fc_compute_eigen(const T* x, - int x_h, - int x_w, // - const T* w, - int w_h, - int w_w, // - const T* b, // - T* out) { - using matrix_t = - Eigen::Matrix; - - Eigen::Map X(x, x_h, x_w); - Eigen::Map W(w, w_h, w_w); - Eigen::Map Out(out, x_h, w_w); - - Out = X * W; +inline void FCOutputSize(const lite::DDim& in_dims, + const lite::DDim& w_dims, + std::vector& out_dims, // NOLINT + int in_num_col_dims, + bool padding_weights) { + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; - if (b) { - Eigen::Map> B(b, w_w); - Out = Out.array().rowwise() + B.transpose().array(); + out_dims.reserve(static_cast(in_num_col_dims + 1)); + for (int i = 0; i < in_num_col_dims; ++i) { + out_dims.push_back(in_dims[i]); } + out_dims.push_back(w_dims1); } -template -void fc_compute_naive(const T* x, - int x_h, - int x_w, // - const T* w, - int w_h, - int w_w, // - const T* b, // - T* out) { - CHECK_EQ(x_w, w_h); - // out shape: (x_h, w_w) - memset(out, 0, x_h * w_w * sizeof(T)); - for (int i = 0; i < x_h; i++) { - for (int j = 0; j < w_w; j++) { - T tmp = static_cast(0); - for (int k = 0; k < x_w; k++) { - tmp += x[i * x_w + k] * w[k * w_w + j]; +template +class FCFunctor { + public: + void operator()(const lite::X86Context& context, + const int M, + const int N, + const int K, + const T* X, + const T* W, + T* Y, + const T* B = nullptr, + bool relu = false, + bool padding_weights = false) { + auto blas = lite::x86::math::GetBlas(context); + lite::Tensor Y1; + T* Y1_data = nullptr; + if (N % 128 == 0 && K % 128 == 0) { + const int NN = N + 4; + const int KK = K + 4; + lite::Tensor X1; + X1.Resize({M * KK}); + Y1.Resize({M * (N + 4)}); + T* X1_data = X1.mutable_data(); + Y1_data = Y1.mutable_data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < M; i++) { + memcpy(X1_data + i * KK, X + i * K, K * sizeof(X[0])); + } + lite::Tensor W1; + T* W1_data = nullptr; + if (!padding_weights) { + W1.Resize({(K + 4) * (N + 4)}); + W1_data = W1.mutable_data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < K; i++) { + memcpy(W1_data + i * NN, W + i * N, N * sizeof(W[0])); + } + } + blas.GEMM(false, + false, + M, + N, + K, + static_cast(1.0), + X1_data, + KK, + (padding_weights ? W : W1_data), + NN, + static_cast(0.0), + Y1_data, + NN); + } else { + blas.MatMul(M, N, K, X, W, Y); + } + if (B == NULL) { + if (N % 128 == 0 && K % 128 == 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < M; i++) { + memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(Y[0])); + } + } + return; + } + if (relu) { + auto compute = + paddle::lite::jit::KernelFuncs, + lite::fluid::CPUPlace>::Cache() + .At(N); + for (int i = 0; i < M; i++) { + T* dst = Y + i * N; + T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; + compute(B, src, dst, N); + } + } else { + auto compute = + paddle::lite::jit::KernelFuncs, + lite::fluid::CPUPlace>::Cache() + .At(N); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < M; i++) { + T* dst = Y + i * N; + T* src = (N % 128 == 0 && K % 128 == 0) ? Y1_data + i * (N + 4) : dst; + compute(B, src, dst, N); } - out[i * w_w + j] = tmp + b[j]; } } -} +}; template class FcCompute : public KernelLite { @@ -81,20 +146,43 @@ class FcCompute : public KernelLite { void Run() override { auto& param = *param_.get_mutable(); - CHECK_GE(param.input->dims().size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); + auto* input = param.input; + auto* w = param.w; + auto* bias = param.bias; + auto* output = param.output; + int in_num_col_dims = param.in_num_col_dims; + bool with_relu = (param.activation_type == "relu") ? true : false; + + auto w_dims = w->dims(); + bool padding_weights = param.padding_weights; + + std::vector output_dims; + FCOutputSize( + input->dims(), w_dims, output_dims, in_num_col_dims, padding_weights); + output->Resize(output_dims); + output->set_lod(input->lod()); + + auto out_dims = output->dims(); + auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0]; + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + int M = out_dims.production() / w_dims1; + + const T* input_data = input->data(); + const T* w_data = w->data(); + T* output_data = output->mutable_data(); - fc_compute_eigen( - param.input->data(), // x - param.input->dims().Slice(0, param.in_num_col_dims).production(), - param.input->dims() - .Slice(param.in_num_col_dims, param.input->dims().size()) - .production(), - param.w->data(), // w - param.w->dims()[0], // w_h - param.w->dims()[1], // w_w - param.bias->data(), // b - param.output->mutable_data()); + auto& context = ctx_->As(); + FCFunctor fc; + fc(context, + M, + w_dims1, + w_dims0, + input_data, + w_data, + output_data, + bias ? bias->data() : NULL, + with_relu, + padding_weights); } virtual ~FcCompute() = default; diff --git a/lite/kernels/x86/fc_compute_test.cc b/lite/kernels/x86/fc_compute_test.cc index abc0597457..32394798b3 100644 --- a/lite/kernels/x86/fc_compute_test.cc +++ b/lite/kernels/x86/fc_compute_test.cc @@ -11,8 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include "lite/kernels/x86/fc_compute.h" #include +#include +#include #include #include "lite/core/op_registry.h" @@ -43,7 +46,7 @@ TEST(fc_x86, run_test) { w.Resize(lite::DDim(w_shape)); std::vector b_shape{1, 4}; b.Resize(lite::DDim(b_shape)); - std::vector out_shape{1, 4}; + std::vector out_shape{batch_size, 4}; out.Resize(lite::DDim(out_shape)); auto x_data = x.mutable_data(); @@ -55,16 +58,12 @@ TEST(fc_x86, run_test) { x_data[i] = static_cast(i); } for (int64_t i = 0; i < w.dims().production(); i++) { - w_data[i] = static_cast(i); + w_data[i] = static_cast(2); } for (int64_t i = 0; i < b.dims().production(); i++) { - b_data[i] = static_cast(i); + b_data[i] = static_cast(2); } - /* lite::x86::math::fc_compute_eigen(x_data, batch_size, 3, // - w_data, 3, 4, // - b_data, ref_data); */ - // FcCompute fc; FcCompute fc; operators::FcParam param; @@ -75,21 +74,17 @@ TEST(fc_x86, run_test) { param.bias = &b; param.output = &out; param.in_mat_dims = x.dims(); + param.activation_type = "relu"; - // std::unique_ptr ctx(new KernelContext); - // ctx->As(); + std::unique_ptr ctx(new KernelContext); + ctx->As(); fc.SetParam(param); - // fc.SetContext(std::move(ctx)); + fc.SetContext(std::move(ctx)); fc.Run(); - - VLOG(3) << "output vs ref"; + std::vector ref_data({8, 8, 8, 8, 26, 26, 26, 26}); for (int i = 0; i < out.dims().production(); i++) { - VLOG(3) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); } - - /* for (int i = 0; i < out.dims().production(); ++i) { - EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); - }*/ } } // namespace x86 diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 3f2a69dfbc..141436fbf7 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -90,6 +90,9 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { if (op_desc.HasAttr("activation_type")) { param_.activation_type = op_desc.GetAttr("activation_type"); } + if (op_desc.HasAttr("padding_weights")) { + param_.activation_type = op_desc.GetAttr("padding_weights"); + } // For Int8 if (op_desc.HasAttr("enable_int8")) { diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 64e63b653b..3534998663 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -86,6 +86,7 @@ struct FcParam { lite::DDim in_mat_dims; int in_num_col_dims{1}; std::string activation_type{""}; + bool padding_weights{false}; // for int8 WITH_INT8_CONFIG }; -- GitLab