diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt
index 19e3dd7ec53f87454563987597071973b4dc3123..339eb5976f30ca5dfced09e19815b0f7a014b5c1 100644
--- a/lite/kernels/xpu/bridges/CMakeLists.txt
+++ b/lite/kernels/xpu/bridges/CMakeLists.txt
@@ -15,6 +15,7 @@ lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgrap
 lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_subgraph_bridge_deps})
 
 set(xpu_subgraph_bridges
         subgraph_bridge_registry
@@ -28,6 +29,7 @@ set(xpu_subgraph_bridges
         subgraph_bridge_mul_op_xpu
         subgraph_bridge_batch_norm_op_xpu
         subgraph_bridge_transpose_op_xpu
+        subgraph_bridge_reshape_op_xpu
         CACHE INTERNAL "xpu_subgraph_bridges")
 
 message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
diff --git a/lite/kernels/xpu/bridges/paddle_use_bridges.h b/lite/kernels/xpu/bridges/paddle_use_bridges.h
index 211a9a9ab0d4027fb2cb6ec31447c481e6221d94..f899f0d21cb9e040cfb5f8c18b2a1af6d4f8fc51 100644
--- a/lite/kernels/xpu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/xpu/bridges/paddle_use_bridges.h
@@ -24,3 +24,5 @@ USE_SUBGRAPH_BRIDGE(XPU, mul);
 USE_SUBGRAPH_BRIDGE(XPU, batch_norm);
 USE_SUBGRAPH_BRIDGE(XPU, transpose);
 USE_SUBGRAPH_BRIDGE(XPU, transpose2);
+USE_SUBGRAPH_BRIDGE(XPU, reshape);
+USE_SUBGRAPH_BRIDGE(XPU, reshape2);
diff --git a/lite/kernels/xpu/bridges/reshape_op.cc b/lite/kernels/xpu/bridges/reshape_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..37d7bf58b0756ea3a83b3d396a75fa542ca03442
--- /dev/null
+++ b/lite/kernels/xpu/bridges/reshape_op.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/operators/reshape_op.h" +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/graph.h" +#include "lite/kernels/xpu/bridges/utility.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace xpu { + +int ReshapeConverter(void* ctx, OpLite* op) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto scope = op->scope(); + auto op_type = op_info->Type(); + VLOG(3) << "[XPU] Converting " + op_type + "..."; + + // Create node and set params from op + auto x_var_name = op_info->Input("X").front(); + auto out_var_name = op_info->Output("Out").front(); + + std::vector shape; + if (op_info->HasInput("ShapeTensor") && + !op_info->Input("ShapeTensor").empty()) { + for (auto var_name : op_info->Input("ShapeTensor")) { + shape.emplace_back(scope->FindMutableTensor(var_name)->data()[0]); + } + CHECK_GT(shape.size(), 0) + << "ShapeError: When `shape` in ReshapeOp is a list or tuple " + "which contains Tensor, the shape's size can't be zero. " + "But received shape's size is " + << shape.size(); + } else if (op_info->HasInput("Shape") && !op_info->Input("Shape").empty()) { + auto shape_tensor = + scope->FindMutableTensor(op_info->Input("Shape").front()); + auto shape_data = shape_tensor->data(); + shape = std::vector(shape_data, shape_data + shape_tensor->numel()); + } else if (op_info->HasAttr("shape")) { + shape = op_info->GetAttr>("shape"); + } else { + LOG(FATAL) << "no new shape for reshape op"; + } + auto out_dims = + operators::ValidateShape(shape, scope->FindTensor(x_var_name)->dims()); + + CHECK(graph->HasNode(x_var_name)); + graph->AddNode(out_var_name, + graph->builder_.CreateReshape(*graph->GetNode(x_var_name), + Cvt2ArrayInt(out_dims))); + + return SUCCESS; +} + +} // namespace xpu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(XPU, + reshape2, + paddle::lite::subgraph::xpu::ReshapeConverter); +REGISTER_SUBGRAPH_BRIDGE(XPU, + reshape, + paddle::lite::subgraph::xpu::ReshapeConverter); diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 1671397ecfeef73e1088958189cc00eecc916d02..f8fbb732674a37256b959b182df4a3a859b3999f 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -25,6 +25,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) if(LITE_BUILD_EXTRA) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/reshape_compute_test.cc b/lite/tests/kernels/reshape_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..85cd724148290a06a9303515004e8d003c32c053 --- /dev/null +++ b/lite/tests/kernels/reshape_compute_test.cc @@ -0,0 +1,187 @@ +// Copyright 
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+
+namespace paddle {
+namespace lite {
+
+class ReshapeComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string op_type_ = "reshape2";
+  std::string input_ = "x";
+  std::string output_ = "out";
+  std::string xshape_ = "xshape";
+  std::vector<std::string> shape_tensor_vct_;
+  std::string shape_tensor_;
+  DDim x_dims_;
+  std::vector<int> shape_;
+  bool inplace_ = false;
+
+ public:
+  ReshapeComputeTester(const Place& place,
+                       const std::string& alias,
+                       DDim x_dims,
+                       std::vector<int> shape,
+                       bool is_shape_tensor_vct = false,
+                       bool is_shape_tensor = false,
+                       bool is_shape = true)
+      : TestCase(place, alias), x_dims_(x_dims) {
+    if (is_shape_tensor_vct) {
+      for (size_t i = 0; i < shape.size(); i++) {
+        shape_tensor_vct_.emplace_back(op_type_ + "/shape" + std::to_string(i));
+      }
+    } else if (is_shape_tensor) {
+      shape_tensor_ = op_type_ + "/shape";
+    } else if (is_shape) {
+      shape_ = shape;
+    } else {
+      LOG(FATAL) << "must set new shape!";
+    }
+  }
+
+  void RunBaseline(Scope* scope) override {
+    auto* out = scope->NewTensor(output_);
+    CHECK(out);
+
+    auto* x = scope->FindTensor(input_);
+    auto x_dims = x->dims();
+
+    std::vector<int> out_shape;
+    if (shape_tensor_vct_.size() > 0) {
+      for (auto shape_tensor : shape_tensor_vct_) {
+        out_shape.push_back(scope->FindTensor(shape_tensor)->data<int>()[0]);
+      }
+    } else if (!shape_tensor_.empty()) {
+      auto shape_tensor = scope->FindTensor(shape_tensor_);
+      auto shape_tensor_data = shape_tensor->data<int>();
+      out_shape = std::vector<int>(shape_tensor_data,
+                                   shape_tensor_data + shape_tensor->numel());
+    } else if (!shape_.empty()) {
+      out_shape = shape_;
+    } else {
+      LOG(FATAL) << "must set new shape!";
+    }
+
+    std::vector<int64_t> final_out_shape(out_shape.size(), 1);
+    int unk_dim_idx = -1;
+    int cap = 1;
+    for (size_t i = 0; i < out_shape.size(); i++) {
+      if (out_shape[i] == -1) {
+        CHECK_EQ(unk_dim_idx, -1);
+        unk_dim_idx = i;
+      } else if (out_shape[i] == 0) {
+        CHECK_LT(i, x_dims.size());
+        final_out_shape[i] = x_dims[i];
+      } else if (out_shape[i] > 0) {
+        final_out_shape[i] = out_shape[i];
+      } else {
+        LOG(FATAL) << "invalid shape";
+      }
+      cap *= final_out_shape[i];
+    }
+
+    if (unk_dim_idx > -1) {
+      final_out_shape[unk_dim_idx] = x_dims.production() / cap;
+    }
+
+    out->Resize(final_out_shape);
+
+    auto x_data = x->data<float>();
+    auto out_data = out->mutable_data<float>();
+    memcpy(out_data, x_data, sizeof(float) * x_dims.production());
+
+    if (op_type_ == "reshape2") {
+      auto* xshape = scope->NewTensor(xshape_);
+      auto xshape_dims = x_dims.Vectorize();
+      xshape_dims.insert(xshape_dims.begin(), 0);
+      xshape->Resize(xshape_dims);
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType(op_type_);
+    op_desc->SetInput("X", {input_});
+    if (shape_tensor_vct_.size() > 0) {
+      op_desc->SetInput("ShapeTensor", shape_tensor_vct_);
+    } else if (!shape_tensor_.empty()) {
+      op_desc->SetInput("Shape", {shape_tensor_});
+    } else if (shape_.size() > 0) {
+      op_desc->SetAttr("shape", shape_);
+    } else {
+      LOG(FATAL) << "invalid shape";
+    }
+    op_desc->SetOutput("Out", {output_});
+    if (op_type_ == "reshape2") {
+      op_desc->SetOutput("XShape", {xshape_});
+    }
+    op_desc->SetAttr("inplace", inplace_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> data(x_dims_.production());
+    for (int i = 0; i < x_dims_.production(); i++) {
+      data[i] = i * 1.1;
+    }
+    SetCommonTensor(input_, x_dims_, data.data());
+
+    if (shape_tensor_vct_.size() > 0) {
+      for (size_t i = 0; i < shape_.size(); i++) {
+        std::vector<int> shape_data{shape_[i]};
+        SetCommonTensor(shape_tensor_vct_[i],
+                        DDim(std::vector<int64_t>{1}),
+                        shape_data.data());
+      }
+    }
+    if (!shape_tensor_.empty()) {
+      SetCommonTensor(
+          shape_tensor_,
+          DDim(std::vector<int64_t>{static_cast<int64_t>(shape_.size())}),
+          shape_.data());
+    }
+  }
+};
+
+TEST(Reshape, precision) {
+  LOG(INFO) << "test Reshape op";
+  float abs_error = 2e-5;
+  Place place;
+#ifdef LITE_WITH_XPU
+  place = TARGET(kXPU);
+#else
+  return;
+#endif
+
+  DDim x_dims{{2, 3, 4, 5}};
+  std::vector<std::vector<int>> shapes{{5, 4, 3, 2},
+                                       {2, 3, 20},
+                                       {2, 60},
+                                       {120},
+                                       {2, 3, -1},
+                                       {0, 0, 20},
+                                       {0, 0, -1}};
+  for (auto shape : shapes) {
+    std::unique_ptr<arena::TestCase> tester(
+        new ReshapeComputeTester(place, "def", x_dims, shape));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision({"xshape"});
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle