提交 11bac0c9 编写于 作者: Z zhupengyang 提交者: hong19860320

[XPU] matmul bridge and unit test (#2666)

上级 a386ed11
......@@ -23,6 +23,7 @@ lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu
lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
set(xpu_subgraph_bridges
subgraph_bridge_registry
......@@ -44,6 +45,7 @@ set(xpu_subgraph_bridges
subgraph_bridge_reshape_op_xpu
subgraph_bridge_layer_norm_op_xpu
subgraph_bridge_dropout_op_xpu
subgraph_bridge_matmul_op_xpu
CACHE INTERNAL "xpu_subgraph_bridges")
message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto transpose_x = op_info->GetAttr<bool>("transpose_X");
CHECK(!transpose_x) << "XPU only support transpose_x == true now";
auto transpose_y = op_info->GetAttr<bool>("transpose_Y");
auto alpha = op_info->GetAttr<float>("alpha");
// X node
std::shared_ptr<xtcl::xExpr> x_node = nullptr;
if (graph->HasNode(x_name)) {
x_node = graph->GetNode(x_name);
} else {
x_node = graph->AddNode(x_name, x_dims);
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
} else {
y_node = graph->AddNode(y_name, y_dims);
}
auto matmul_node =
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y);
graph->AddNode(out_name, graph->builder_.CreateScale(matmul_node, alpha));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
matmul,
paddle::lite::subgraph::xpu::MatmulConverter);
......@@ -35,3 +35,4 @@ USE_SUBGRAPH_BRIDGE(XPU, reshape2);
USE_SUBGRAPH_BRIDGE(XPU, layer_norm);
USE_SUBGRAPH_BRIDGE(XPU, gelu);
USE_SUBGRAPH_BRIDGE(XPU, dropout);
USE_SUBGRAPH_BRIDGE(XPU, matmul);
......@@ -502,13 +502,16 @@ void test_matmulnxn_transpose(Place place) {
}
TEST(Matmul2x2, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_matmul2x2_no_transform(place);
#endif
}
TEST(Matmul2x2_x_transpose, precision) {
......@@ -520,14 +523,18 @@ TEST(Matmul2x2_x_transpose, precision) {
test_matmul2x2_x_transpose(place);
#endif
}
TEST(Matmul2x2_y_transpose, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_matmul2x2_y_transpose(place);
#endif
}
TEST(Matmul2x2_transpose, precision) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册