Commit 7bd142bd authored by zhupengyang, committed by GitHub

[XPU] add reshape bridge and unit test (#2621)

test=develop
Parent dded92f9
...@@ -15,6 +15,7 @@ lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgrap ...@@ -15,6 +15,7 @@ lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgrap
lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_subgraph_bridge_deps})
set(xpu_subgraph_bridges set(xpu_subgraph_bridges
subgraph_bridge_registry subgraph_bridge_registry
...@@ -28,6 +29,7 @@ set(xpu_subgraph_bridges ...@@ -28,6 +29,7 @@ set(xpu_subgraph_bridges
subgraph_bridge_mul_op_xpu subgraph_bridge_mul_op_xpu
subgraph_bridge_batch_norm_op_xpu subgraph_bridge_batch_norm_op_xpu
subgraph_bridge_transpose_op_xpu subgraph_bridge_transpose_op_xpu
subgraph_bridge_reshape_op_xpu
CACHE INTERNAL "xpu_subgraph_bridges") CACHE INTERNAL "xpu_subgraph_bridges")
message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}") message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
...@@ -24,3 +24,5 @@ USE_SUBGRAPH_BRIDGE(XPU, mul); ...@@ -24,3 +24,5 @@ USE_SUBGRAPH_BRIDGE(XPU, mul);
USE_SUBGRAPH_BRIDGE(XPU, batch_norm); USE_SUBGRAPH_BRIDGE(XPU, batch_norm);
USE_SUBGRAPH_BRIDGE(XPU, transpose); USE_SUBGRAPH_BRIDGE(XPU, transpose);
USE_SUBGRAPH_BRIDGE(XPU, transpose2); USE_SUBGRAPH_BRIDGE(XPU, transpose2);
USE_SUBGRAPH_BRIDGE(XPU, reshape);
USE_SUBGRAPH_BRIDGE(XPU, reshape2);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
// Bridges a reshape/reshape2 op onto the XPU subgraph builder.
// The target shape is resolved with the op's input priority:
//   "ShapeTensor" (a list of 1-element int tensors, one per output dim)
//   > "Shape" (a single 1-D int tensor)
//   > the "shape" attribute.
// Returns SUCCESS once the output node has been added to the graph.
int ReshapeConverter(void* ctx, OpLite* op) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto* graph_ctx = static_cast<Graph*>(ctx);
  auto* op_info = op->op_info();
  auto* scope = op->scope();
  auto op_type = op_info->Type();
  VLOG(3) << "[XPU] Converting " + op_type + "...";

  // Input/output variable names from the op description.
  auto x_var_name = op_info->Input("X").front();
  auto out_var_name = op_info->Output("Out").front();

  // Collect the requested shape from whichever source the op provides.
  std::vector<int> target_shape;
  if (op_info->HasInput("ShapeTensor") &&
      !op_info->Input("ShapeTensor").empty()) {
    // Each tensor in the list carries a single dimension value.
    for (const auto& var_name : op_info->Input("ShapeTensor")) {
      target_shape.push_back(
          scope->FindMutableTensor(var_name)->data<int>()[0]);
    }
    CHECK_GT(target_shape.size(), 0)
        << "ShapeError: When `shape` in ReshapeOp is a list or tuple "
           "which contains Tensor, the shape's size can't be zero. "
           "But received shape's size is "
        << target_shape.size();
  } else if (op_info->HasInput("Shape") && !op_info->Input("Shape").empty()) {
    auto* shape_tensor =
        scope->FindMutableTensor(op_info->Input("Shape").front());
    const auto* shape_data = shape_tensor->data<int>();
    target_shape.assign(shape_data, shape_data + shape_tensor->numel());
  } else if (op_info->HasAttr("shape")) {
    target_shape = op_info->GetAttr<std::vector<int>>("shape");
  } else {
    LOG(FATAL) << "no new shape for reshape op";
  }

  // Resolve 0 / -1 entries against the input dims into concrete output dims.
  auto out_dims = operators::ValidateShape(
      target_shape, scope->FindTensor(x_var_name)->dims());

  // The input node must already exist; wire the reshape output onto it.
  CHECK(graph_ctx->HasNode(x_var_name));
  graph_ctx->AddNode(
      out_var_name,
      graph_ctx->builder_.CreateReshape(*graph_ctx->GetNode(x_var_name),
                                        Cvt2ArrayInt(out_dims)));
  return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
// Register the same converter for both reshape and reshape2; the converter
// only wires X -> Out, which is all the XPU graph needs for either variant.
REGISTER_SUBGRAPH_BRIDGE(XPU,
reshape2,
paddle::lite::subgraph::xpu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU,
reshape,
paddle::lite::subgraph::xpu::ReshapeConverter);
...@@ -25,6 +25,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH ...@@ -25,6 +25,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
// Arena test case for reshape/reshape2. The target shape can be supplied
// three ways, mirroring the op definition's priority order:
//   1. "ShapeTensor": a list of 1-element int tensors (one per output dim),
//   2. "Shape":       a single 1-D int tensor,
//   3. "shape":       an attribute vector<int>.
class ReshapeComputeTester : public arena::TestCase {
 protected:
  // Common attributes for this op.
  std::string op_type_ = "reshape2";
  std::string input_ = "x";
  std::string output_ = "out";
  std::string xshape_ = "xshape";
  // Variable names used when the shape is fed via "ShapeTensor".
  std::vector<std::string> shape_tensor_vct_;
  // Variable name used when the shape is fed via "Shape".
  std::string shape_tensor_;
  DDim x_dims_;
  // Raw target-shape values. Always stored (see constructor) so PrepareData
  // can materialize the shape tensors for the tensor-input variants.
  std::vector<int> shape_;
  bool inplace_ = false;

 public:
  ReshapeComputeTester(const Place& place,
                       const std::string& alias,
                       DDim x_dims,
                       std::vector<int> shape,
                       bool is_shape_tensor_vct = false,
                       bool is_shape_tensor = false,
                       bool is_shape = true)
      : TestCase(place, alias), x_dims_(x_dims) {
    // BUGFIX: always remember the shape values. Previously they were stored
    // only for the attribute variant, so the ShapeTensor/Shape variants left
    // shape_ empty and PrepareData never created the shape tensors that
    // RunBaseline/the op would read. Branch priority in RunBaseline and
    // PrepareOpDesc (tensor inputs first) keeps the attribute path unchanged.
    shape_ = shape;
    if (is_shape_tensor_vct) {
      for (size_t i = 0; i < shape.size(); i++) {
        shape_tensor_vct_.emplace_back(op_type_ + "/shape" + std::to_string(i));
      }
    } else if (is_shape_tensor) {
      shape_tensor_ = op_type_ + "/shape";
    } else if (!is_shape) {
      LOG(FATAL) << "must set new shape!";
    }
  }

  // Computes the reference output: resolve the target shape (0 copies the
  // matching input dim, a single -1 is inferred from the element count),
  // then flat-copy the input data.
  void RunBaseline(Scope* scope) override {
    auto* out = scope->NewTensor(output_);
    CHECK(out);
    auto* x = scope->FindTensor(input_);
    auto x_dims = x->dims();

    // Gather the requested shape, honoring the op's input priority:
    // ShapeTensor > Shape > shape attribute.
    std::vector<int> out_shape;
    if (shape_tensor_vct_.size() > 0) {
      for (const auto& shape_tensor : shape_tensor_vct_) {
        out_shape.push_back(scope->FindTensor(shape_tensor)->data<int>()[0]);
      }
    } else if (!shape_tensor_.empty()) {
      auto shape_tensor = scope->FindTensor(shape_tensor_);
      auto shape_tensor_data = shape_tensor->data<int>();
      out_shape = std::vector<int>(shape_tensor_data,
                                   shape_tensor_data + shape_tensor->numel());
    } else if (!shape_.empty()) {
      out_shape = shape_;
    } else {
      LOG(FATAL) << "must set new shape!";
    }

    std::vector<int64_t> final_out_shape(out_shape.size(), 1);
    int unk_dim_idx = -1;
    // int64_t: the running product of dims can overflow a 32-bit int.
    int64_t cap = 1;
    for (size_t i = 0; i < out_shape.size(); i++) {
      if (out_shape[i] == -1) {
        CHECK_EQ(unk_dim_idx, -1) << "only one dim may be -1";
        unk_dim_idx = static_cast<int>(i);
      } else if (out_shape[i] == 0) {
        // BUGFIX: was CHECK_LE, which allowed i == x_dims.size() and then
        // indexed x_dims[i] out of bounds.
        CHECK_LT(i, x_dims.size());
        final_out_shape[i] = x_dims[i];
      } else if (out_shape[i] > 0) {
        final_out_shape[i] = out_shape[i];
      } else {
        LOG(FATAL) << "invalid shape";
      }
      cap *= final_out_shape[i];
    }
    if (unk_dim_idx > -1) {
      final_out_shape[unk_dim_idx] = x_dims.production() / cap;
    }
    out->Resize(final_out_shape);

    // Reshape never reorders elements, so a flat copy is the reference.
    auto x_data = x->data<float>();
    auto out_data = out->mutable_data<float>();
    memcpy(out_data, x_data, sizeof(float) * x_dims.production());

    if (op_type_ == "reshape2") {
      // reshape2 additionally emits XShape = [0] + input dims.
      auto* xshape = scope->NewTensor(xshape_);
      auto xshape_dims = x_dims.Vectorize();
      xshape_dims.insert(xshape_dims.begin(), 0);
      xshape->Resize(xshape_dims);
    }
  }

  // Describes the op with exactly one of the three shape sources, matching
  // the priority used in RunBaseline.
  void PrepareOpDesc(cpp::OpDesc* op_desc) {
    op_desc->SetType(op_type_);
    op_desc->SetInput("X", {input_});
    if (shape_tensor_vct_.size() > 0) {
      op_desc->SetInput("ShapeTensor", shape_tensor_vct_);
    } else if (!shape_tensor_.empty()) {
      op_desc->SetInput("Shape", {shape_tensor_});
    } else if (shape_.size() > 0) {
      op_desc->SetAttr("shape", shape_);
    } else {
      LOG(FATAL) << "invalid shape";
    }
    op_desc->SetOutput("Out", {output_});
    if (op_type_ == "reshape2") {
      op_desc->SetOutput("XShape", {xshape_});
    }
    op_desc->SetAttr("inplace", inplace_);
  }

  // Fills the input tensor with a deterministic ramp and, for the tensor
  // variants, materializes the shape tensors from shape_.
  void PrepareData() override {
    std::vector<float> data(x_dims_.production());
    for (int i = 0; i < x_dims_.production(); i++) {
      data[i] = i * 1.1;
    }
    SetCommonTensor(input_, x_dims_, data.data());

    if (shape_tensor_vct_.size() > 0) {
      for (size_t i = 0; i < shape_.size(); i++) {
        std::vector<int> shape_data{shape_[i]};
        SetCommonTensor(shape_tensor_vct_[i],
                        DDim(std::vector<int64_t>{1}),
                        shape_data.data());
      }
    }
    if (!shape_tensor_.empty()) {
      SetCommonTensor(
          shape_tensor_,
          DDim(std::vector<int64_t>{static_cast<int64_t>(shape_.size())}),
          shape_.data());
    }
  }
};
// Precision test for the XPU reshape bridge: runs each candidate target
// shape through the arena and compares against the host baseline.
// Covers rank-preserving permutations, rank changes, -1 (inferred dim)
// and 0 (copy the corresponding input dim).
TEST(Reshape, precision) {
  LOG(INFO) << "test Reshape op";
  float abs_error = 2e-5;
  Place place;
#ifdef LITE_WITH_XPU
  place = TARGET(kXPU);
#else
  return;  // No XPU in this build; nothing to test.
#endif

  DDim x_dims{{2, 3, 4, 5}};
  std::vector<std::vector<int>> candidate_shapes{{5, 4, 3, 2},
                                                 {2, 3, 20},
                                                 {2, 60},
                                                 {120},
                                                 {2, 3, -1},
                                                 {0, 0, 20},
                                                 {0, 0, -1}};
  for (const auto& target_shape : candidate_shapes) {
    std::unique_ptr<arena::TestCase> tester(
        new ReshapeComputeTester(place, "def", x_dims, target_shape));
    arena::Arena arena(std::move(tester), place, abs_error);
    // XShape is metadata-only on device; exclude it from the comparison.
    arena.TestPrecision({"xshape"});
  }
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.