From fc84cb067be8d54af2429d57eaf1438a2ecc95f1 Mon Sep 17 00:00:00 2001
From: jiaopu
Date: Mon, 11 May 2020 15:47:33 +0800
Subject: [PATCH] Add gather op

---
 .github/workflows/github-CI.yml               |   6 +-
 lite/kernels/mlu/bridges/CMakeLists.txt       |   3 +
 lite/kernels/mlu/bridges/gather_op.cc         |  66 +++++++++
 lite/kernels/mlu/bridges/gather_op_test.cc    | 133 ++++++++++++++++++
 lite/kernels/mlu/bridges/paddle_use_bridges.h |   1 +
 lite/kernels/mlu/bridges/test_helper.cc       |  57 ++++++--
 lite/kernels/mlu/subgraph_compute.h           |  20 ++-
 7 files changed, 270 insertions(+), 16 deletions(-)
 create mode 100644 lite/kernels/mlu/bridges/gather_op.cc
 create mode 100644 lite/kernels/mlu/bridges/gather_op_test.cc

diff --git a/.github/workflows/github-CI.yml b/.github/workflows/github-CI.yml
index a78a79ef93..320942e8ae 100644
--- a/.github/workflows/github-CI.yml
+++ b/.github/workflows/github-CI.yml
@@ -14,7 +14,7 @@ jobs:
     steps:
     - uses: actions/checkout@v2
     - name: modity build.sh
-      run: sed -i 's/DLITE_WITH_PYTHON=ON/DLITE_WITH_PYTHON=OFF/' lite/tools/build_mlu.sh && sed -i 's/WITH_TESTING=OFF/WITH_TESTING=ON/' lite/tools/build_mlu.sh && sed -i 's/PRINT_HW_TIME false/PRINT_HW_TIME true/' lite/kernels/mlu/bridges/graph.h
+      run: sed -i 's/DLITE_WITH_PYTHON=ON/DLITE_WITH_PYTHON=OFF/' lite/tools/build_mlu.sh && sed -i 's/WITH_TESTING=OFF/WITH_TESTING=ON/' lite/tools/build_mlu.sh && sed -i 's/PRINT_HW_TIME false/PRINT_HW_TIME true/' lite/kernels/mlu/bridges/graph.h && sed -i 's/BUILD_EXTRA=OFF/BUILD_EXTRA=ON/' lite/tools/build_mlu.sh
     - name: build
       run: ./lite/tools/build_mlu.sh build
     - name: test_act_converter_mlu
@@ -47,6 +47,10 @@ jobs:
       run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_argmax_converter_mlu
     - name: test_split_converter_mlu
       run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_split_converter_mlu
+    - name: test_lrn_converter_mlu
+      run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_lrn_converter_mlu
+    - name: test_gather_converter_mlu
+      run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_gather_converter_mlu
     - name: test_classification
       run: |
         cd ..
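For context before the per-file diffs: Paddle's gather op takes a tensor X and a 1-D (or Nx1) integer tensor Index, and copies the selected axis-0 slices of X into Out, so Out has the same dims as X except that dim 0 becomes the number of indices. A minimal standalone sketch of those semantics (illustrative names; it mirrors the gather_ref reference implementation added in the test below):

    #include <cassert>
    #include <cstring>
    #include <vector>

    // Gather along axis 0: out[i] = x[index[i]], where each x[i] is a
    // contiguous slice of numel(x) / x_dims[0] elements.
    std::vector<float> gather_axis0(const std::vector<float>& x,
                                    const std::vector<int>& x_dims,
                                    const std::vector<int>& index) {
      int slice_size = 1;
      for (size_t i = 1; i < x_dims.size(); ++i) slice_size *= x_dims[i];
      std::vector<float> out(index.size() * slice_size);
      for (size_t i = 0; i < index.size(); ++i) {
        assert(index[i] >= 0 && index[i] < x_dims[0]);  // index bounds check
        std::memcpy(out.data() + i * slice_size,
                    x.data() + index[i] * slice_size,
                    slice_size * sizeof(float));
      }
      return out;
    }

    int main() {
      // x has shape {4, 2}; gathering rows {2, 0} yields shape {2, 2}.
      std::vector<float> x = {0, 1, 10, 11, 20, 21, 30, 31};
      std::vector<float> out = gather_axis0(x, {4, 2}, {2, 0});
      assert((out == std::vector<float>{20, 21, 0, 1}));
      return 0;
    }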
diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index bc43465c09..2f0f64c42c 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -50,6 +50,8 @@ set(mlu_subgraph_bridges
 if (LITE_BUILD_EXTRA)
   lite_cc_library(subgraph_bridge_lrn_op_mlu SRCS lrn_op.cc DEPS ${subgraph_bridge_deps_mlu})
   list(APPEND mlu_subgraph_bridges subgraph_bridge_lrn_op_mlu)
+  lite_cc_library(subgraph_bridge_gather_op_mlu SRCS gather_op.cc DEPS ${subgraph_bridge_deps_mlu})
+  list(APPEND mlu_subgraph_bridges subgraph_bridge_gather_op_mlu)
 endif()
 
 lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
@@ -71,5 +73,6 @@ lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimiz
 lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 if (LITE_BUILD_EXTRA)
   lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+  lite_cc_test(test_gather_converter_mlu SRCS gather_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 endif()
 message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
diff --git a/lite/kernels/mlu/bridges/gather_op.cc b/lite/kernels/mlu/bridges/gather_op.cc
new file mode 100644
index 0000000000..acf8732f86
--- /dev/null
+++ b/lite/kernels/mlu/bridges/gather_op.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  auto x_var_name = op_info->Input("X").front();
+  auto index_var_name = op_info->Input("Index").front();
+  auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+
+  CHECK(graph->HasNode(x_var_name));
+  auto x_tensor = graph->GetNode(x_var_name)->mlu_tensor();
+  auto index_tensor = graph->GetNode(index_var_name)->mlu_tensor();
+
+  auto output_tensor = graph
+                           ->AddNode(out_var_name,
+                                     output_dims,
+                                     CNML_TENSOR,
+                                     CNML_NCHW,
+                                     graph->FPType())
+                           ->mlu_tensor();
+
+  cnmlBaseOp_t gather_op;
+  CNML_CALL(cnmlCreateGatherV2Op(
+      &gather_op, x_tensor, index_tensor, output_tensor, CNML_DIM_N));
+  graph->FuseOp(gather_op);
+  CNML_CALL(cnmlDestroyBaseOp(&gather_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(gather,
+                         kMLU,
+                         paddle::lite::subgraph::mlu::GatherConverter);
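The unit test that follows checks this converter against the reference semantics, with a layout twist: the MLU runtime consumes NHWC-ordered buffers, so the test transposes X from NCHW with permutation {0, 2, 3, 1} before LaunchOp and transposes the MLU output back with {0, 3, 1, 2} before comparing. A sketch of such a permuting 4-D transpose (illustrative only, not the actual transpose helper from utility.h):

    #include <vector>

    // Generic 4-D transpose with an explicit axis permutation. Output dim k
    // corresponds to input dim perm[k].
    void transpose4d(const float* in, float* out,
                     const std::vector<int>& dims,    // input dims, e.g. {5,4,3,2}
                     const std::vector<int>& perm) {  // e.g. {0,2,3,1}
      std::vector<int> od = {dims[perm[0]], dims[perm[1]],
                             dims[perm[2]], dims[perm[3]]};
      int in_strides[4] = {dims[1] * dims[2] * dims[3],
                           dims[2] * dims[3], dims[3], 1};
      int idx[4];
      for (idx[0] = 0; idx[0] < od[0]; ++idx[0])
        for (idx[1] = 0; idx[1] < od[1]; ++idx[1])
          for (idx[2] = 0; idx[2] < od[2]; ++idx[2])
            for (idx[3] = 0; idx[3] < od[3]; ++idx[3]) {
              int dst = ((idx[0] * od[1] + idx[1]) * od[2] + idx[2]) * od[3] +
                        idx[3];
              // idx[k] walks output dim k, i.e. input dim perm[k].
              int src = idx[0] * in_strides[perm[0]] +
                        idx[1] * in_strides[perm[1]] +
                        idx[2] * in_strides[perm[2]] +
                        idx[3] * in_strides[perm[3]];
              out[dst] = in[src];
            }
    }

With dims {5, 4, 3, 2} and perm {0, 2, 3, 1}, element (n, c, h, w) of the input lands at position (n, h, w, c) of the output, which is exactly the NCHW-to-NHWC move the test needs.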
diff --git a/lite/kernels/mlu/bridges/gather_op_test.cc b/lite/kernels/mlu/bridges/gather_op_test.cc
new file mode 100644
index 0000000000..f9b2153ca5
--- /dev/null
+++ b/lite/kernels/mlu/bridges/gather_op_test.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/gather_op.h"
+#include <gtest/gtest.h>
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+template <typename dtype>
+void gather_ref(const std::shared_ptr<operators::GatherOp> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto index =
+      scope->FindVar(op_info->Input("Index").front())->GetMutable<Tensor>();
+  auto out =
+      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
+
+  auto x_dims = x->dims();
+  auto index_dims = index->dims();
+  CHECK(index_dims.size() == 1 ||
+        (index_dims.size() == 2 && index_dims[1] == 1));
+
+  int batch_size = index_dims[0];
+  DDim out_dims = x_dims;
+  out_dims[0] = batch_size;
+  out->Resize(out_dims);
+
+  auto x_data = x->data<dtype>();
+  auto index_data = index->data<int>();
+  auto out_data = out->mutable_data<dtype>();
+
+  auto slice_num = x_dims[0];
+  auto slice_size = x_dims.Slice(1, x_dims.size()).production();
+  for (int i = 0; i < batch_size; i++) {
+    auto index = index_data[i];
+    CHECK_LT(index, slice_num) << "index must be less than slice_num";
+    CHECK_GE(index, 0) << "index must be non-negative";
+    memcpy(out_data + i * slice_size,
+           x_data + index * slice_size,
+           slice_size * sizeof(dtype));
+  }
+}
+
+void test_gather() {
+  // prepare input&output variables
+  std::string x_var_name = "x";
+  std::string out_var_name = "out";
+  std::string out_ref_var_name = "out_ref";
+  std::string index_var_name = "index";
+
+  Scope scope;
+  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
+  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
+  auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
+  auto* index = scope.Var(index_var_name)->GetMutable<Tensor>();
+
+  x->Resize({5, 4, 3, 2});
+  index->Resize({2});
+  // initialize input&output data
+  FillTensor<float>(x);
+  FillTensor<int>(index, 1, 3);
+
+  // initialize op desc
+  cpp::OpDesc opdesc;
+  opdesc.SetType("gather");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetInput("Index", {index_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+
+  auto op = CreateOp<operators::GatherOp>(opdesc, &scope);
+  gather_ref<float>(op);
+  out_ref->CopyDataFrom(*out);
+
+  Tensor input;
+  input.Resize({5, 4, 3, 2});
+  transpose(x->mutable_data<float>(),
+            input.mutable_data<float>(),
+            {static_cast<int>(5),
+             static_cast<int>(4),
+             static_cast<int>(3),
+             static_cast<int>(2)},
+            {0, 2, 3, 1});
+  x->CopyDataFrom(input);
+  LaunchOp(op, {x_var_name, index_var_name}, {out_var_name});
+
+  // compare results
+  auto* out_data = out->mutable_data<float>();
+  auto* out_ref_data = out_ref->mutable_data<float>();
+
+  Tensor output;
+  output.Resize(out->dims());
+  transpose(out_data,
+            output.mutable_data<float>(),
+            {static_cast<int>(out->dims()[0]),
+             static_cast<int>(out->dims()[2]),
+             static_cast<int>(out->dims()[3]),
+             static_cast<int>(out->dims()[1])},
+            {0, 3, 1, 2});
+  out_data = output.mutable_data<float>();
+  for (int i = 0; i < out->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
+  }
+}
+
+TEST(MLUBridges, gather) { test_gather(); }
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(gather, kMLU);
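The single new line in paddle_use_bridges.h below is what actually pulls the bridge into linked binaries: the bridge is registered from gather_op.cc's own translation unit, and USE_SUBGRAPH_BRIDGE keeps the linker from discarding that object file. Roughly, such REGISTER_/USE_ macro pairs expand to a "touch" symbol, sketched here with made-up names (assumption: this is the common idiom, not PaddleLite's literal macro bodies):

    // In the translation unit that registers the bridge (gather_op.cc):
    int touch_subgraph_bridge_gather_kMLU = 0;

    // What USE_SUBGRAPH_BRIDGE(gather, kMLU) roughly becomes in a consumer:
    extern int touch_subgraph_bridge_gather_kMLU;
    static int use_subgraph_bridge_gather_kMLU __attribute__((unused)) =
        touch_subgraph_bridge_gather_kMLU;  // forces the object file to link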
diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h
index 2cd7024097..6f325ce2f4 100644
--- a/lite/kernels/mlu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h
@@ -38,5 +38,6 @@ USE_SUBGRAPH_BRIDGE(slice, kMLU);
 USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
 USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
 #ifdef LITE_BUILD_EXTRA
+USE_SUBGRAPH_BRIDGE(gather, kMLU);
 USE_SUBGRAPH_BRIDGE(lrn, kMLU)
 #endif
diff --git a/lite/kernels/mlu/bridges/test_helper.cc b/lite/kernels/mlu/bridges/test_helper.cc
index 7dca67fc30..92c6ddd18a 100644
--- a/lite/kernels/mlu/bridges/test_helper.cc
+++ b/lite/kernels/mlu/bridges/test_helper.cc
@@ -50,23 +50,54 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
   // Convert input data var and add it into the MLU IR graph
   for (auto& input_name : input_var_names) {
     auto input_tensor = scope->FindMutableTensor(input_name);
+    auto data_type = input_tensor->precision();
+    cnmlDataType_t fp_type;
+    switch (data_type) {
+      case paddle::lite_api::PrecisionType::kFP16:
+        fp_type = CNML_DATA_FLOAT16;
+        break;
+      case paddle::lite_api::PrecisionType::kFloat:
+        fp_type = CNML_DATA_FLOAT32;
+        break;
+      case paddle::lite_api::PrecisionType::kInt32:
+        fp_type = CNML_DATA_INT32;
+        break;
+      default:
+        CHECK(0);
+    }
     CHECK(input_tensor);
     Tensor temp_input;
     temp_input.Resize(input_tensor->dims().Vectorize());
     temp_input.CopyDataFrom(*input_tensor);
-    auto input_node =
-        graph.AddNode(input_name,
-                      input_tensor->dims().Vectorize(),
-                      CNML_TENSOR,
-                      CNML_NCHW,
-                      graph.FPType(),
-                      reinterpret_cast<void*>(
-                          input_tensor->mutable_data<float>(TARGET(kMLU))));
-    CHECK(input_node);
-    CNRT_CHECK(cnrtMemcpy(input_tensor->mutable_data<float>(),
-                          temp_input.mutable_data<float>(),
-                          sizeof(float) * input_tensor->dims().production(),
-                          CNRT_MEM_TRANS_DIR_HOST2DEV));
+    if (fp_type == CNML_DATA_INT32) {
+      auto input_node =
+          graph.AddNode(input_name,
+                        input_tensor->dims().Vectorize(),
+                        CNML_TENSOR,
+                        CNML_NCHW,
+                        fp_type,
+                        reinterpret_cast<void*>(
+                            input_tensor->mutable_data<int>(TARGET(kMLU))));
+      CHECK(input_node);
+      CNRT_CHECK(cnrtMemcpy(input_tensor->mutable_data<int>(),
+                            temp_input.mutable_data<int>(),
+                            sizeof(int) * input_tensor->dims().production(),
+                            CNRT_MEM_TRANS_DIR_HOST2DEV));
+    } else {
+      auto input_node =
+          graph.AddNode(input_name,
+                        input_tensor->dims().Vectorize(),
+                        CNML_TENSOR,
+                        CNML_NCHW,
+                        fp_type,
+                        reinterpret_cast<void*>(
+                            input_tensor->mutable_data<float>(TARGET(kMLU))));
+      CHECK(input_node);
+      CNRT_CHECK(cnrtMemcpy(input_tensor->mutable_data<float>(),
+                            temp_input.mutable_data<float>(),
+                            sizeof(float) * input_tensor->dims().production(),
+                            CNRT_MEM_TRANS_DIR_HOST2DEV));
+    }
   }
   op->CheckShape();
   op->InferShape();
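One observation on the diffs above and below: the switch added to test_helper.cc encodes the same precision-to-CNML mapping as the PrecisionToDatatype() method added to subgraph_compute.h. A follow-up could hoist it into a shared helper along these lines (hypothetical; not part of this patch, and it assumes the same cnml and logging headers the surrounding code already uses):

    // Hypothetical shared mapping, usable from both call sites:
    inline cnmlDataType_t PrecisionToCnmlDataType(
        paddle::lite_api::PrecisionType precision) {
      switch (precision) {
        case paddle::lite_api::PrecisionType::kFP16:
          return CNML_DATA_FLOAT16;
        case paddle::lite_api::PrecisionType::kFloat:
          return CNML_DATA_FLOAT32;
        case paddle::lite_api::PrecisionType::kInt32:
          return CNML_DATA_INT32;
        case paddle::lite_api::PrecisionType::kInt8:
          return CNML_DATA_INT8;
        default:
          CHECK(0) << "unsupported precision";  // assumes PaddleLite's CHECK
          return CNML_DATA_FLOAT32;             // unreachable; silences compiler
      }
    }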
diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h
index 4b1ffad2b0..8339d5b59f 100644
--- a/lite/kernels/mlu/subgraph_compute.h
+++ b/lite/kernels/mlu/subgraph_compute.h
@@ -86,6 +86,21 @@ class SubgraphEngine : public subgraph::Engine {
     return true;
   }
 
+  inline cnmlDataType_t PrecisionToDatatype(PrecisionType data_type) {
+    switch (data_type) {
+      case paddle::lite_api::PrecisionType::kFP16:
+        return CNML_DATA_FLOAT16;
+      case paddle::lite_api::PrecisionType::kFloat:
+        return CNML_DATA_FLOAT32;
+      case paddle::lite_api::PrecisionType::kInt32:
+        return CNML_DATA_INT32;
+      case paddle::lite_api::PrecisionType::kInt8:
+        return CNML_DATA_INT8;
+      default:
+        return PrecisionToDatatype(fp_type_);
+    }
+  }
+
 protected:
   int BuildDeviceProgram() override {
     int status = 0;
@@ -99,7 +114,8 @@ class SubgraphEngine : public subgraph::Engine {
     status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
     for (auto& input_name : input_names_) {
       auto input_tensor = scope_->FindMutableTensor(input_name);
-
+      auto data_type = input_tensor->precision();
+      cnmlDataType_t fp_type = PrecisionToDatatype(data_type);
       origin_itensors_.push_back(input_tensor);
       new_shape.push_back(input_tensor->dims().Vectorize());
 
@@ -108,7 +124,7 @@ class SubgraphEngine : public subgraph::Engine {
       auto input_node = graph->AddNode(input_name,
                                        input_tensor->dims().Vectorize(),
                                        CNML_TENSOR,
                                        CNML_NCHW,
-                                       graph->FPType());
+                                       fp_type);
       CHECK(input_node);
       // MLU doesn't support dynamic dimensions/shapes, so need to rebuild
       // the program when the shape of any input tensor is changed.
--
GitLab