From 11bac0c9e5e8f4b7c5e50047154ab063cd9e911d Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Tue, 24 Dec 2019 22:11:39 +0800
Subject: [PATCH] [XPU] matmul bridge and unit test (#2666)

---
 lite/kernels/xpu/bridges/CMakeLists.txt       |  2 +
 lite/kernels/xpu/bridges/matmul_op.cc         | 88 +++++++++++++++++++
 lite/kernels/xpu/bridges/paddle_use_bridges.h |  1 +
 lite/tests/kernels/matmul_compute_test.cc     | 27 +++---
 4 files changed, 108 insertions(+), 10 deletions(-)
 create mode 100644 lite/kernels/xpu/bridges/matmul_op.cc

diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt
index c0388e8a2c..4e166a1c12 100644
--- a/lite/kernels/xpu/bridges/CMakeLists.txt
+++ b/lite/kernels/xpu/bridges/CMakeLists.txt
@@ -23,6 +23,7 @@ lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu
 lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
 
 set(xpu_subgraph_bridges
     subgraph_bridge_registry
@@ -44,6 +45,7 @@ set(xpu_subgraph_bridges
     subgraph_bridge_reshape_op_xpu
     subgraph_bridge_layer_norm_op_xpu
     subgraph_bridge_dropout_op_xpu
+    subgraph_bridge_matmul_op_xpu
     CACHE INTERNAL "xpu_subgraph_bridges")
 
 message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
diff --git a/lite/kernels/xpu/bridges/matmul_op.cc b/lite/kernels/xpu/bridges/matmul_op.cc
new file mode 100644
index 0000000000..eaf2370ada
--- /dev/null
+++ b/lite/kernels/xpu/bridges/matmul_op.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/npu/bridges/registry.h"
+#include "lite/kernels/xpu/bridges/graph.h"
+#include "lite/kernels/xpu/bridges/utility.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace xpu {
+
+int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[XPU] Converting " + op_type + "...";
+
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x_type = kernel->GetInputDeclType("X");
+  CHECK(x_type->precision() == PRECISION(kFloat));
+  CHECK(x_type->layout() == DATALAYOUT(kNCHW));
+  auto x = scope->FindMutableTensor(x_name);
+  auto x_dims = x->dims();
+
+  auto y_name = op_info->Input("Y").front();
+  auto y_type = kernel->GetInputDeclType("Y");
+  CHECK(y_type->precision() == PRECISION(kFloat));
+  CHECK(y_type->layout() == DATALAYOUT(kNCHW));
+  auto y = scope->FindMutableTensor(y_name);
+  auto y_dims = y->dims();
+
+  auto out_name = op_info->Output("Out").front();
+  auto out_type = kernel->GetOutputDeclType("Out");
+  CHECK(out_type->precision() == PRECISION(kFloat));
+  CHECK(out_type->layout() == DATALAYOUT(kNCHW));
+
+  auto transpose_x = op_info->GetAttr<bool>("transpose_X");
+  CHECK(!transpose_x) << "XPU only supports transpose_x == false now";
+  auto transpose_y = op_info->GetAttr<bool>("transpose_Y");
+  auto alpha = op_info->GetAttr<float>("alpha");
+
+  // X node
+  std::shared_ptr<xtcl::xExpr> x_node = nullptr;
+  if (graph->HasNode(x_name)) {
+    x_node = graph->GetNode(x_name);
+  } else {
+    x_node = graph->AddNode(x_name, x_dims);
+  }
+
+  // Y node
+  std::shared_ptr<xtcl::xExpr> y_node = nullptr;
+  if (graph->HasNode(y_name)) {
+    y_node = graph->GetNode(y_name);
+  } else {
+    y_node = graph->AddNode(y_name, y_dims);
+  }
+
+  auto matmul_node =
+      graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y);
+  graph->AddNode(out_name, graph->builder_.CreateScale(matmul_node, alpha));
+
+  return SUCCESS;
+}
+
+}  // namespace xpu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(XPU,
+                         matmul,
+                         paddle::lite::subgraph::xpu::MatmulConverter);
diff --git a/lite/kernels/xpu/bridges/paddle_use_bridges.h b/lite/kernels/xpu/bridges/paddle_use_bridges.h
index 9f8cb0a61c..588fcdd6e4 100644
--- a/lite/kernels/xpu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/xpu/bridges/paddle_use_bridges.h
@@ -35,3 +35,4 @@ USE_SUBGRAPH_BRIDGE(XPU, reshape2);
 USE_SUBGRAPH_BRIDGE(XPU, layer_norm);
 USE_SUBGRAPH_BRIDGE(XPU, gelu);
 USE_SUBGRAPH_BRIDGE(XPU, dropout);
+USE_SUBGRAPH_BRIDGE(XPU, matmul);
diff --git a/lite/tests/kernels/matmul_compute_test.cc b/lite/tests/kernels/matmul_compute_test.cc
index 4915614b34..5d19e7fe3c 100644
--- a/lite/tests/kernels/matmul_compute_test.cc
+++ b/lite/tests/kernels/matmul_compute_test.cc
@@ -502,13 +502,16 @@ void test_matmulnxn_transpose(Place place) {
 }
 
 TEST(Matmul2x2, precision) {
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
+  Place place;
+#if defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#else
+  return;
 #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
+
   test_matmul2x2_no_transform(place);
-#endif
 }
 
 TEST(Matmul2x2_x_transpose, precision) {
@@ -520,14 +523,18 @@ TEST(Matmul2x2_x_transpose, precision) {
   test_matmul2x2_x_transpose(place);
 #endif
 }
+
 TEST(Matmul2x2_y_transpose, precision) {
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
+  Place place;
+#if defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#else
+  return;
 #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
+
   test_matmul2x2_y_transpose(place);
-#endif
 }
 
 TEST(Matmul2x2_transpose, precision) {
-- 
GitLab