From 002413ad814b0d5bb099466a346313617ce56366 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Thu, 12 Sep 2019 18:45:34 +0800 Subject: [PATCH] add matmul op kernels for asr test=develop (#2032) --- lite/kernels/x86/CMakeLists.txt | 2 + lite/kernels/x86/matmul_compute.cc | 26 ++++++++ lite/kernels/x86/matmul_compute.h | 76 +++++++++++++++++++++ lite/kernels/x86/matmul_compute_test.cc | 87 +++++++++++++++++++++++++ 4 files changed, 191 insertions(+) create mode 100644 lite/kernels/x86/matmul_compute.cc create mode 100644 lite/kernels/x86/matmul_compute.h create mode 100644 lite/kernels/x86/matmul_compute_test.cc diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 40c0645ece..4b50ec1f0d 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -38,6 +38,7 @@ add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${ if(NOT LITE_WITH_X86) return() endif() +add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas) lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86) @@ -48,3 +49,4 @@ lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc D lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) +lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86) diff --git a/lite/kernels/x86/matmul_compute.cc b/lite/kernels/x86/matmul_compute.cc new file mode 100644 index 0000000000..6949e018cb --- /dev/null +++ b/lite/kernels/x86/matmul_compute.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/matmul_compute.h" + +REGISTER_LITE_KERNEL(matmul, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::MatMulCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/matmul_compute.h b/lite/kernels/x86/matmul_compute.h new file mode 100644 index 0000000000..3d2b3c7482 --- /dev/null +++ b/lite/kernels/x86/matmul_compute.h @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "lite/backends/x86/math/blas.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static lite::DDim RowMatrixFromVector(const lite::DDim &x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return lite::DDim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static lite::DDim ColumnMatrixFromVector(const lite::DDim &y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return lite::DDim({y_dim[0], 1}); +} + +template +class MatMulCompute : public KernelLite { + public: + using param_t = operators::MatMulParam; + + void Run() override { + auto &context = ctx_->As(); + auto ¶m = *param_.get_mutable(); + + auto *x = param.X; + auto *y = param.Y; + auto *out = param.Out; + out->mutable_data(); + + auto blas = lite::x86::math::GetBlas(context); + auto mat_dim_a = lite::x86::math::CreateMatrixDescriptor( + RowMatrixFromVector(x->dims()), 0, param.transpose_X); + auto mat_dim_b = lite::x86::math::CreateMatrixDescriptor( + ColumnMatrixFromVector(y->dims()), 0, param.transpose_Y); + auto scale = static_cast(param.alpha); + blas.MatMul(*x, mat_dim_a, *y, mat_dim_b, scale, out, T(0)); + } + + virtual ~MatMulCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/matmul_compute_test.cc b/lite/kernels/x86/matmul_compute_test.cc new file mode 100644 index 0000000000..53d2d1a47a --- /dev/null +++ b/lite/kernels/x86/matmul_compute_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/matmul_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(matmul_x86, retrive_op) { + auto matmul = + KernelRegistry::Global().Create( + "matmul"); + ASSERT_FALSE(matmul.empty()); + ASSERT_TRUE(matmul.front()); +} + +TEST(matmul_x86, init) { + lite::kernels::x86::MatMulCompute matmul; + ASSERT_EQ(matmul.precision(), PRECISION(kFloat)); + ASSERT_EQ(matmul.target(), TARGET(kX86)); +} + +TEST(matmul_x86, run_test) { + lite::Tensor x, y, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector y_shape{2, 4}; + y.Resize(lite::DDim(y_shape)); + std::vector out_shape{batch_size, 3, 4}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto y_data = y.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < y.dims().production(); i++) { + y_data[i] = static_cast(i); + } + // MatMulCompute matmul; + MatMulCompute matmul; + operators::MatMulParam param; + + param.X = &x; + param.Y = &y; + param.Out = &out; + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + matmul.SetContext(std::move(ctx)); + matmul.SetParam(param); + matmul.Run(); + + std::vector ref_result = {4, 5, 6, 7, 12, 17, 22, 27, 20, 29, 38, 47}; + + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_result[i], 1e-3); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(matmul, kX86, kFloat, kNCHW, def); -- GitLab