From b53ece7a07653b22971ce8f0b9431f26ecb5bf69 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 20 Dec 2019 10:21:10 +0800 Subject: [PATCH] [XPU] add transpose bridge and unit test (#2630) * [XPU] add transpose bridge and unit test test=develop --- lite/backends/xpu/device.cc | 7 +- lite/core/arena/framework.cc | 12 +- lite/core/arena/framework.h | 6 +- lite/kernels/xpu/bridges/CMakeLists.txt | 2 + lite/kernels/xpu/bridges/paddle_use_bridges.h | 2 + lite/kernels/xpu/bridges/transpose_op.cc | 58 +++++++ lite/kernels/xpu/bridges/utility.cc | 12 ++ lite/kernels/xpu/bridges/utility.h | 3 + lite/kernels/xpu/subgraph_compute.cc | 17 ++- lite/operators/transpose_op.cc | 9 ++ lite/tests/kernels/CMakeLists.txt | 3 +- lite/tests/kernels/transpose_compute_test.cc | 142 ++++++++++++++++++ 12 files changed, 258 insertions(+), 15 deletions(-) create mode 100644 lite/kernels/xpu/bridges/transpose_op.cc create mode 100644 lite/tests/kernels/transpose_compute_test.cc diff --git a/lite/backends/xpu/device.cc b/lite/backends/xpu/device.cc index 769f14a642..74a5681aa9 100644 --- a/lite/backends/xpu/device.cc +++ b/lite/backends/xpu/device.cc @@ -30,7 +30,12 @@ std::unique_ptr Device::Build( // The XPU compiler build the graph and fill all of the constant params, only // one output is supported now. - xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0])); + xtcl::Array all_outs; + for (size_t i = 0; i < outputs->size(); i++) { + all_outs.push_back(*outputs->at(i)); + } + xtcl::xNetwork network = + builder->FinalizeNetwork(xtcl::relay::TupleNode::make(all_outs)); auto target = xtcl::Target::Create(device_name_); auto compiler = xtcl::network::xTensorCompiler(network, target); compiler.SetParams(*params); // Set the data of constant tensors diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc index 1b2712ce66..fe36f1e1ba 100644 --- a/lite/core/arena/framework.cc +++ b/lite/core/arena/framework.cc @@ -35,12 +35,12 @@ void TestCase::CreateInstruction() { op_desc_.reset(new cpp::OpDesc()); op_desc_->SetType("subgraph"); op_desc_->SetAttr("sub_block", sub_block_idx); - op_desc_->SetInput("Inputs", op_desc_->input_vars()); - op_desc_->SetOutput("Outputs", op_desc_->output_vars()); - op_desc_->SetAttr>( - "input_data_names", sub_block_op_desc->input_vars()); - op_desc_->SetAttr>( - "output_data_names", sub_block_op_desc->output_vars()); + auto in_names = sub_block_op_desc->input_vars(); + auto out_names = sub_block_op_desc->output_vars(); + op_desc_->SetInput("Inputs", in_names); + op_desc_->SetOutput("Outputs", out_names); + op_desc_->SetAttr>("input_data_names", in_names); + op_desc_->SetAttr>("output_data_names", out_names); op = LiteOpRegistry::Global().Create(op_desc().Type()); static_cast(op.get())->SetSubBlock(sub_block_desc); } else { diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h index 671da20bdc..05af21bbdb 100644 --- a/lite/core/arena/framework.h +++ b/lite/core/arena/framework.h @@ -188,13 +188,17 @@ class Arena { tester_->Prepare(); } - bool TestPrecision() { + bool TestPrecision(const std::vector& exclude_outs = {}) { tester_->RunBaseline(tester_->baseline_scope()); tester_->RunInstruction(); bool success = true; for (auto& out : tester_->op_desc().OutputArgumentNames()) { for (auto& var : tester_->op_desc().Output(out)) { + if (std::find(exclude_outs.begin(), exclude_outs.end(), var) != + exclude_outs.end()) { + continue; + } success = success && CompareTensor(out, var); } } diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt index 4742517c1d..19e3dd7ec5 100644 --- a/lite/kernels/xpu/bridges/CMakeLists.txt +++ b/lite/kernels/xpu/bridges/CMakeLists.txt @@ -14,6 +14,7 @@ lite_cc_library(subgraph_bridge_pool_op_xpu SRCS pool_op.cc DEPS ${subgraph_brid lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgraph_bridge_deps_xpu}) lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps}) set(xpu_subgraph_bridges subgraph_bridge_registry @@ -26,6 +27,7 @@ set(xpu_subgraph_bridges subgraph_bridge_softmax_op_xpu subgraph_bridge_mul_op_xpu subgraph_bridge_batch_norm_op_xpu + subgraph_bridge_transpose_op_xpu CACHE INTERNAL "xpu_subgraph_bridges") message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}") diff --git a/lite/kernels/xpu/bridges/paddle_use_bridges.h b/lite/kernels/xpu/bridges/paddle_use_bridges.h index a99cf80a33..211a9a9ab0 100644 --- a/lite/kernels/xpu/bridges/paddle_use_bridges.h +++ b/lite/kernels/xpu/bridges/paddle_use_bridges.h @@ -22,3 +22,5 @@ USE_SUBGRAPH_BRIDGE(XPU, pool2d); USE_SUBGRAPH_BRIDGE(XPU, softmax); USE_SUBGRAPH_BRIDGE(XPU, mul); USE_SUBGRAPH_BRIDGE(XPU, batch_norm); +USE_SUBGRAPH_BRIDGE(XPU, transpose); +USE_SUBGRAPH_BRIDGE(XPU, transpose2); diff --git a/lite/kernels/xpu/bridges/transpose_op.cc b/lite/kernels/xpu/bridges/transpose_op.cc new file mode 100644 index 0000000000..3d0e87836d --- /dev/null +++ b/lite/kernels/xpu/bridges/transpose_op.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/graph.h" +#include "lite/kernels/xpu/bridges/utility.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace xpu { + +int TransposeConverter(void* ctx, OpLite* op) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + VLOG(3) << "[XPU] Converting " + op_type + "..."; + + // Create node and set params from op + auto x_var_name = op_info->Input("X").front(); + auto out_var_name = op_info->Output("Out").front(); + + auto axis = op_info->GetAttr>("axis"); + + CHECK(graph->HasNode(x_var_name)); + graph->AddNode( + out_var_name, + graph->builder_.CreateTranspose( + *graph->GetNode(x_var_name), + Cvt2ArrayInt(std::vector(axis.begin(), axis.end())))); + + return SUCCESS; +} + +} // namespace xpu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(XPU, + transpose, + paddle::lite::subgraph::xpu::TransposeConverter); +REGISTER_SUBGRAPH_BRIDGE(XPU, + transpose2, + paddle::lite::subgraph::xpu::TransposeConverter); diff --git a/lite/kernels/xpu/bridges/utility.cc b/lite/kernels/xpu/bridges/utility.cc index 82cce3eaf8..cf8d09a53a 100644 --- a/lite/kernels/xpu/bridges/utility.cc +++ b/lite/kernels/xpu/bridges/utility.cc @@ -125,6 +125,18 @@ std::shared_ptr CvtTensor(const Tensor& in_tensor, return out_tensor; } +xtcl::Array Cvt2ArrayInt(const std::vector& input) { + xtcl::Array output; + for (auto i : input) { + output.push_back(i); + } + return output; +} + +xtcl::Array Cvt2ArrayInt(const DDim& input) { + return Cvt2ArrayInt(input.Vectorize()); +} + } // namespace xpu } // namespace subgraph } // namespace lite diff --git a/lite/kernels/xpu/bridges/utility.h b/lite/kernels/xpu/bridges/utility.h index db2eef1f4f..f04488d2c3 100644 --- a/lite/kernels/xpu/bridges/utility.h +++ b/lite/kernels/xpu/bridges/utility.h @@ -47,6 +47,9 @@ std::shared_ptr CvtTensor( PrecisionType in_ptype = PRECISION(kFloat), DataLayoutType in_ltype = DATALAYOUT(kNCHW)); +xtcl::Array Cvt2ArrayInt(const std::vector& input); +xtcl::Array Cvt2ArrayInt(const DDim& input); + } // namespace xpu } // namespace subgraph } // namespace lite diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 8b4dc6f41d..899fb074b3 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -60,9 +60,14 @@ int SubgraphEngine::BuildDeviceProgram() { // Obtain the output nodes of the XPU IR graph and build the graph to XPU // runtime std::vector output_nodes; + std::vector valid_output_names; for (auto& output_name : output_names_) { - output_nodes.push_back(graph.GetNode(output_name).get()); + if (graph.HasNode(output_name)) { + output_nodes.push_back(graph.GetNode(output_name).get()); + valid_output_names.push_back(output_name); + } } + CHECK(!valid_output_names.empty()) << "[XPU] no valid output names"; device_program_ = lite::xpu::Device::Global().Build( &graph.builder_, &graph.params_, &output_nodes); if (device_program_ == nullptr) { @@ -73,16 +78,16 @@ int SubgraphEngine::BuildDeviceProgram() { // Query and check the dimensions of input and output tensors origin_idims_.resize(input_names_.size()); origin_itensors_.resize(input_names_.size()); - origin_odims_.resize(output_names_.size()); - origin_otensors_.resize(output_names_.size()); + origin_odims_.resize(valid_output_names.size()); + origin_otensors_.resize(valid_output_names.size()); for (int i = 0; i < input_names_.size(); i++) { origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]); CHECK(origin_itensors_[i]); origin_idims_[i] = origin_itensors_[i]->dims(); VLOG(3) << "[XPU] Input dims[" << i << "]: " << origin_idims_[i]; } - for (int i = 0; i < output_names_.size(); i++) { - origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]); + for (int i = 0; i < valid_output_names.size(); i++) { + origin_otensors_[i] = scope_->FindMutableTensor(valid_output_names[i]); CHECK(origin_otensors_[i]); origin_odims_[i] = origin_otensors_[i]->dims(); VLOG(3) << "[XPU] Output dims[" << i << "]: " << origin_odims_[i]; @@ -113,7 +118,7 @@ int SubgraphEngine::LaunchDeviceProgram() { device_program_->Run(); VLOG(3) << "[XPU] Process cost " << GetCurrentUS() - start_time << " us"; // Copy the data of output XPU tensor to the buffer of origin output tensors - for (size_t i = 0; i < output_names_.size(); i++) { + for (size_t i = 0; i < origin_otensors_.size(); i++) { auto output_ndarray = device_program_->GetOutput(i); std::memcpy(origin_otensors_[i]->mutable_data(), static_cast(output_ndarray.ToDLPack()->dl_tensor.data), diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index ce850be533..71086b492b 100644 --- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -135,6 +135,15 @@ bool Transpose2Op::InferShape() const { out_dims[i] = x_dims[axis[i]]; } param_.output->Resize(out_dims); + + std::vector xshape_dims(x_dims.size() + 1, 0); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.xshape->Resize(xshape_dims); + auto xshape_lod = param_.xshape->mutable_lod(); + *xshape_lod = param_.x->lod(); + return true; } diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index a8065b619f..1671397ecf 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -1,4 +1,4 @@ -if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) +if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) @@ -24,6 +24,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) if(LITE_BUILD_EXTRA) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/transpose_compute_test.cc b/lite/tests/kernels/transpose_compute_test.cc new file mode 100644 index 0000000000..62e0fc8e41 --- /dev/null +++ b/lite/tests/kernels/transpose_compute_test.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +int data_index(std::vector pos, DDimLite dims) { + int d1 = dims[1]; + int d2 = dims[2]; + int d3 = dims[3]; + return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1; +} + +std::vector pos_trans(std::vector in_pos, std::vector axis) { + std::vector out_pos(in_pos.size()); + for (int i = 0; i < axis.size(); i++) { + out_pos[axis[i]] = in_pos[i]; + } + return out_pos; +} + +class TransposeComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string op_type_ = "transpose2"; + std::string input_ = "x"; + std::string output_ = "out"; + std::string xshape_ = "xshape"; + DDim x_dims_; + std::vector axis_; + + public: + TransposeComputeTester(const Place& place, + const std::string& alias, + DDim x_dims, + std::vector axis) + : TestCase(place, alias), x_dims_(x_dims), axis_(axis) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + + auto* x = scope->FindTensor(input_); + auto x_dims = x->dims(); + + std::vector out_shape(x_dims.size(), 0); + for (size_t i = 0; i < x_dims.size(); i++) { + out_shape[i] = x_dims[axis_[i]]; + } + out->Resize(out_shape); + + auto y_dims = out->dims(); + + int input_n = x_dims[0]; + int input_c = x_dims[1]; + int input_h = x_dims[2]; + int input_w = x_dims[3]; + + auto input_data = x->data(); + auto output_data = out->mutable_data(); + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + std::vector in_pos{n, c, h, w}; + std::vector out_pos = pos_trans(in_pos, axis_); + int in_index = data_index(in_pos, x_dims); + int out_index = data_index(out_pos, y_dims); + output_data[out_index] = input_data[in_index]; + } + } + } + } + + if (op_type_ == "transpose2") { + auto* xshape = scope->NewTensor(xshape_); + auto xshape_dims = x_dims.Vectorize(); + xshape_dims.insert(xshape_dims.begin(), 0); + xshape->Resize(xshape_dims); + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType(op_type_); + op_desc->SetInput("X", {input_}); + op_desc->SetOutput("Out", {output_}); + if (op_type_ == "transpose2") { + op_desc->SetOutput("XShape", {xshape_}); + } + op_desc->SetAttr("axis", axis_); + } + + void PrepareData() override { + std::vector data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + data[i] = i * 1.1; + } + SetCommonTensor(input_, x_dims_, data.data()); + } +}; + +TEST(Transpose, precision) { + LOG(INFO) << "test Transpose op"; + float abs_error = 2e-5; + Place place; +#ifdef LITE_WITH_XPU + place = TARGET(kXPU); +#else + return; +#endif + + DDim x_dims{{2, 3, 4, 5}}; + // [XPU]: {3, 1, 0, 2} is unsupported + std::vector> axes{ + {0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}}; + for (auto axis : axes) { + std::unique_ptr tester( + new TransposeComputeTester(place, "def", x_dims, axis)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision({"xshape"}); + } +} + +} // namespace lite +} // namespace paddle -- GitLab