Unverified commit b53ece7a authored by zhupengyang, committed by GitHub

[XPU] add transpose bridge and unit test (#2630)

* [XPU] add transpose bridge and unit test

test=develop
Parent 30ec4fba
......@@ -30,7 +30,12 @@ std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
   // The XPU compiler builds the graph and fills all of the constant params.
   // All outputs are grouped into a tuple node so that multiple outputs are
   // supported.
-  xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0]));
+  xtcl::Array<xtcl::xExpr> all_outs;
+  for (size_t i = 0; i < outputs->size(); i++) {
+    all_outs.push_back(*outputs->at(i));
+  }
+  xtcl::xNetwork network =
+      builder->FinalizeNetwork(xtcl::relay::TupleNode::make(all_outs));
   auto target = xtcl::Target::Create(device_name_);
   auto compiler = xtcl::network::xTensorCompiler(network, target);
   compiler.SetParams(*params);  // Set the data of constant tensors
......
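With the outputs grouped into a tuple, a caller builds the program once and then reads each result by position. A minimal sketch, using only calls that appear elsewhere in this commit (the graph/builder names are illustrative, not a fixed API):

  // Sketch: build a multi-output program, run it, then fetch tuple element i.
  std::vector<xtcl::xExpr*> output_nodes;  // one expr per graph output
  auto device_program = lite::xpu::Device::Global().Build(
      &graph.builder_, &graph.params_, &output_nodes);
  device_program->Run();
  auto out0 = device_program->GetOutput(0);  // corresponds to all_outs[0] above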
......@@ -35,12 +35,12 @@ void TestCase::CreateInstruction() {
     op_desc_.reset(new cpp::OpDesc());
     op_desc_->SetType("subgraph");
     op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx);
-    op_desc_->SetInput("Inputs", op_desc_->input_vars());
-    op_desc_->SetOutput("Outputs", op_desc_->output_vars());
-    op_desc_->SetAttr<std::vector<std::string>>(
-        "input_data_names", sub_block_op_desc->input_vars());
-    op_desc_->SetAttr<std::vector<std::string>>(
-        "output_data_names", sub_block_op_desc->output_vars());
+    auto in_names = sub_block_op_desc->input_vars();
+    auto out_names = sub_block_op_desc->output_vars();
+    op_desc_->SetInput("Inputs", in_names);
+    op_desc_->SetOutput("Outputs", out_names);
+    op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
+    op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
     op = LiteOpRegistry::Global().Create(op_desc().Type());
     static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc);
   } else {
......
......@@ -188,13 +188,17 @@ class Arena {
     tester_->Prepare();
   }
-  bool TestPrecision() {
+  bool TestPrecision(const std::vector<std::string>& exclude_outs = {}) {
     tester_->RunBaseline(tester_->baseline_scope());
     tester_->RunInstruction();
     bool success = true;
     for (auto& out : tester_->op_desc().OutputArgumentNames()) {
       for (auto& var : tester_->op_desc().Output(out)) {
+        if (std::find(exclude_outs.begin(), exclude_outs.end(), var) !=
+            exclude_outs.end()) {
+          continue;
+        }
         success = success && CompareTensor(out, var);
       }
     }
......
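The new exclude_outs parameter lets a test skip outputs that the device program never materializes. Usage, as in the transpose unit test at the end of this commit, where the XPU bridge does not emit the XShape tensor:

  arena::Arena arena(std::move(tester), place, abs_error);
  arena.TestPrecision({"xshape"});  // compare every output except "xshape"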
......@@ -14,6 +14,7 @@ lite_cc_library(subgraph_bridge_pool_op_xpu SRCS pool_op.cc DEPS ${subgraph_brid
 lite_cc_library(subgraph_bridge_softmax_op_xpu SRCS softmax_op.cc DEPS ${subgraph_bridge_deps_xpu})
 lite_cc_library(subgraph_bridge_mul_op_xpu SRCS mul_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_batch_norm_op_xpu SRCS batch_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_transpose_op_xpu SRCS transpose_op.cc DEPS ${xpu_subgraph_bridge_deps})
 set(xpu_subgraph_bridges
     subgraph_bridge_registry
......@@ -26,6 +27,7 @@ set(xpu_subgraph_bridges
     subgraph_bridge_softmax_op_xpu
     subgraph_bridge_mul_op_xpu
     subgraph_bridge_batch_norm_op_xpu
+    subgraph_bridge_transpose_op_xpu
     CACHE INTERNAL "xpu_subgraph_bridges")
 message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
......@@ -22,3 +22,5 @@ USE_SUBGRAPH_BRIDGE(XPU, pool2d);
 USE_SUBGRAPH_BRIDGE(XPU, softmax);
 USE_SUBGRAPH_BRIDGE(XPU, mul);
 USE_SUBGRAPH_BRIDGE(XPU, batch_norm);
+USE_SUBGRAPH_BRIDGE(XPU, transpose);
+USE_SUBGRAPH_BRIDGE(XPU, transpose2);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int TransposeConverter(void* ctx, OpLite* op) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto graph = static_cast<Graph*>(ctx);
  auto op_info = op->op_info();
  auto op_type = op_info->Type();
  VLOG(3) << "[XPU] Converting " + op_type + "...";

  // Create node and set params from op
  auto x_var_name = op_info->Input("X").front();
  auto out_var_name = op_info->Output("Out").front();
  auto axis = op_info->GetAttr<std::vector<int>>("axis");

  CHECK(graph->HasNode(x_var_name));
  graph->AddNode(
      out_var_name,
      graph->builder_.CreateTranspose(
          *graph->GetNode(x_var_name),
          Cvt2ArrayInt(std::vector<int64_t>(axis.begin(), axis.end()))));
  return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(XPU,
                         transpose,
                         paddle::lite::subgraph::xpu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(XPU,
                         transpose2,
                         paddle::lite::subgraph::xpu::TransposeConverter);
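For reference, CreateTranspose should follow the usual transpose semantics, out_dims[i] = x_dims[axis[i]] (the same rule Transpose2Op::InferShape applies below): an input of shape {2, 3, 4, 5} with axis {0, 2, 1, 3} produces an output of shape {2, 4, 3, 5}.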
......@@ -125,6 +125,18 @@ std::shared_ptr<xtcl::xNDArray> CvtTensor(const Tensor& in_tensor,
   return out_tensor;
 }
 
+xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const std::vector<int64_t>& input) {
+  xtcl::Array<xtcl::Integer> output;
+  for (auto i : input) {
+    output.push_back(i);
+  }
+  return output;
+}
+
+xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const DDim& input) {
+  return Cvt2ArrayInt(input.Vectorize());
+}
+
 }  // namespace xpu
 }  // namespace subgraph
 }  // namespace lite
......
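Usage sketch for the new helpers, mirroring the call in transpose_op.cc above (the explicit copy is needed because the "axis" attribute is a std::vector<int>, while the helper takes int64_t):

  std::vector<int> axis{0, 2, 1, 3};
  xtcl::Array<xtcl::Integer> arr =
      Cvt2ArrayInt(std::vector<int64_t>(axis.begin(), axis.end()));
  // or, directly from tensor dims:
  // auto arr2 = Cvt2ArrayInt(x->dims());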
......@@ -47,6 +47,9 @@ std::shared_ptr<xtcl::xNDArray> CvtTensor(
     PrecisionType in_ptype = PRECISION(kFloat),
     DataLayoutType in_ltype = DATALAYOUT(kNCHW));
 
+xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const std::vector<int64_t>& input);
+xtcl::Array<xtcl::Integer> Cvt2ArrayInt(const DDim& input);
+
 }  // namespace xpu
 }  // namespace subgraph
 }  // namespace lite
......
......@@ -60,9 +60,14 @@ int SubgraphEngine::BuildDeviceProgram() {
   // Obtain the output nodes of the XPU IR graph and build the graph to XPU
   // runtime
   std::vector<xtcl::xExpr*> output_nodes;
+  std::vector<std::string> valid_output_names;
   for (auto& output_name : output_names_) {
-    output_nodes.push_back(graph.GetNode(output_name).get());
+    if (graph.HasNode(output_name)) {
+      output_nodes.push_back(graph.GetNode(output_name).get());
+      valid_output_names.push_back(output_name);
+    }
   }
+  CHECK(!valid_output_names.empty()) << "[XPU] no valid output names";
   device_program_ = lite::xpu::Device::Global().Build(
       &graph.builder_, &graph.params_, &output_nodes);
   if (device_program_ == nullptr) {
......@@ -73,16 +78,16 @@ int SubgraphEngine::BuildDeviceProgram() {
   // Query and check the dimensions of input and output tensors
   origin_idims_.resize(input_names_.size());
   origin_itensors_.resize(input_names_.size());
-  origin_odims_.resize(output_names_.size());
-  origin_otensors_.resize(output_names_.size());
+  origin_odims_.resize(valid_output_names.size());
+  origin_otensors_.resize(valid_output_names.size());
   for (int i = 0; i < input_names_.size(); i++) {
     origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
     CHECK(origin_itensors_[i]);
     origin_idims_[i] = origin_itensors_[i]->dims();
     VLOG(3) << "[XPU] Input dims[" << i << "]: " << origin_idims_[i];
   }
-  for (int i = 0; i < output_names_.size(); i++) {
-    origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
+  for (int i = 0; i < valid_output_names.size(); i++) {
+    origin_otensors_[i] = scope_->FindMutableTensor(valid_output_names[i]);
     CHECK(origin_otensors_[i]);
     origin_odims_[i] = origin_otensors_[i]->dims();
     VLOG(3) << "[XPU] Output dims[" << i << "]: " << origin_odims_[i];
......@@ -113,7 +118,7 @@ int SubgraphEngine::LaunchDeviceProgram() {
   device_program_->Run();
   VLOG(3) << "[XPU] Process cost " << GetCurrentUS() - start_time << " us";
   // Copy the data of output XPU tensor to the buffer of origin output tensors
-  for (size_t i = 0; i < output_names_.size(); i++) {
+  for (size_t i = 0; i < origin_otensors_.size(); i++) {
     auto output_ndarray = device_program_->GetOutput(i);
     std::memcpy(origin_otensors_[i]->mutable_data<float>(),
                 static_cast<float*>(output_ndarray.ToDLPack()->dl_tensor.data),
......
......@@ -135,6 +135,15 @@ bool Transpose2Op::InferShape() const {
     out_dims[i] = x_dims[axis[i]];
   }
   param_.output->Resize(out_dims);
+
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    xshape_dims[i + 1] = x_dims[i];
+  }
+  param_.xshape->Resize(xshape_dims);
+  auto xshape_lod = param_.xshape->mutable_lod();
+  *xshape_lod = param_.x->lod();
+
   return true;
 }
......
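The leading zero follows the transpose2 convention for XShape: an input of shape [2, 3, 4, 5] yields xshape_dims of [0, 2, 3, 4, 5]. Together with the copied LoD, this records the original input layout so the backward pass can recover it.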
-if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
     lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -24,6 +24,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE
     #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     if(LITE_BUILD_EXTRA)
         lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
int data_index(std::vector<int> pos, DDimLite dims) {
  int d1 = dims[1];
  int d2 = dims[2];
  int d3 = dims[3];
  return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
}
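// Illustrative check of data_index above: with dims {2, 3, 4, 5} and
// pos {1, 2, 3, 4}, the row-major offset is
// 4 + 3 * 5 + 2 * 5 * 4 + 1 * 5 * 4 * 3 = 119.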
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
  std::vector<int> out_pos(in_pos.size());
  for (size_t i = 0; i < axis.size(); i++) {
    out_pos[axis[i]] = in_pos[i];
  }
  return out_pos;
}
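// Illustrative check of pos_trans above: with axis {0, 2, 1, 3}, an input
// position {n, c, h, w} maps to output position {n, h, c, w}, since
// out_pos[axis[i]] = in_pos[i].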
class TransposeComputeTester : public arena::TestCase {
 protected:
  // common attributes for this op.
  std::string op_type_ = "transpose2";
  std::string input_ = "x";
  std::string output_ = "out";
  std::string xshape_ = "xshape";
  DDim x_dims_;
  std::vector<int> axis_;

 public:
  TransposeComputeTester(const Place& place,
                         const std::string& alias,
                         DDim x_dims,
                         std::vector<int> axis)
      : TestCase(place, alias), x_dims_(x_dims), axis_(axis) {}

  void RunBaseline(Scope* scope) override {
    auto* out = scope->NewTensor(output_);
    CHECK(out);
    auto* x = scope->FindTensor(input_);
    auto x_dims = x->dims();

    std::vector<int64_t> out_shape(x_dims.size(), 0);
    for (size_t i = 0; i < x_dims.size(); i++) {
      out_shape[i] = x_dims[axis_[i]];
    }
    out->Resize(out_shape);
    auto y_dims = out->dims();

    int input_n = x_dims[0];
    int input_c = x_dims[1];
    int input_h = x_dims[2];
    int input_w = x_dims[3];

    auto input_data = x->data<float>();
    auto output_data = out->mutable_data<float>();

    for (int n = 0; n < input_n; ++n) {
      for (int c = 0; c < input_c; ++c) {
        for (int h = 0; h < input_h; ++h) {
          for (int w = 0; w < input_w; ++w) {
            std::vector<int> in_pos{n, c, h, w};
            std::vector<int> out_pos = pos_trans(in_pos, axis_);
            int in_index = data_index(in_pos, x_dims);
            int out_index = data_index(out_pos, y_dims);
            output_data[out_index] = input_data[in_index];
          }
        }
      }
    }

    if (op_type_ == "transpose2") {
      auto* xshape = scope->NewTensor(xshape_);
      auto xshape_dims = x_dims.Vectorize();
      xshape_dims.insert(xshape_dims.begin(), 0);
      xshape->Resize(xshape_dims);
    }
  }

  void PrepareOpDesc(cpp::OpDesc* op_desc) {
    op_desc->SetType(op_type_);
    op_desc->SetInput("X", {input_});
    op_desc->SetOutput("Out", {output_});
    if (op_type_ == "transpose2") {
      op_desc->SetOutput("XShape", {xshape_});
    }
    op_desc->SetAttr("axis", axis_);
  }

  void PrepareData() override {
    std::vector<float> data(x_dims_.production());
    for (int i = 0; i < x_dims_.production(); i++) {
      data[i] = i * 1.1;
    }
    SetCommonTensor(input_, x_dims_, data.data());
  }
};
TEST(Transpose, precision) {
  LOG(INFO) << "test Transpose op";
  float abs_error = 2e-5;
  Place place;
#ifdef LITE_WITH_XPU
  place = TARGET(kXPU);
#else
  return;
#endif

  DDim x_dims{{2, 3, 4, 5}};
  // [XPU]: {3, 1, 0, 2} is unsupported
  std::vector<std::vector<int>> axes{
      {0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}};
  for (auto axis : axes) {
    std::unique_ptr<arena::TestCase> tester(
        new TransposeComputeTester(place, "def", x_dims, axis));
    arena::Arena arena(std::move(tester), place, abs_error);
    arena.TestPrecision({"xshape"});
  }
}
} // namespace lite
} // namespace paddle