From 64a796d16e927506bc96bcb35dc1ff0da6cc722a Mon Sep 17 00:00:00 2001
From: Santa An <49897975+AnBaolei1984@users.noreply.github.com>
Date: Tue, 17 Mar 2020 16:00:44 +0800
Subject: [PATCH] [LITE][BM] support reduce and other ops, test=develop (#3199)

* * support download bm_sdk, test=develop

* [LITE][BM] add slice op

* [LITE][BM] fix concat issue

* [LITE][BM] support reduce full ops, test=develop

* [LITE][BM] change test_resnet50 to change test_classify

* [LITE][BM] add cast op

* [LITE][BM] add reduce and other ops, test=develop

* [LITE][BM] add reduce,cast and other ops, test=develop
---
 lite/api/CMakeLists.txt                       |   2 +-
 ...50_lite_bm.cc => test_classify_lite_bm.cc} |   2 +-
 lite/kernels/bm/bridges/CMakeLists.txt        |  10 ++
 lite/kernels/bm/bridges/cast_op.cc            |  90 ++++++++++++++
 lite/kernels/bm/bridges/concat_op.cc          |   3 -
 lite/kernels/bm/bridges/conv_op.cc            |   1 -
 lite/kernels/bm/bridges/conv_transpose_op.cc  | 110 ++++++++++++++++++
 lite/kernels/bm/bridges/interpolate_op.cc     |   1 -
 lite/kernels/bm/bridges/paddle_use_bridges.h  |   7 ++
 lite/kernels/bm/bridges/reduce_full_op.cc     |  77 ++++++++++++
 lite/kernels/bm/bridges/slice_op.cc           |  93 +++++++++++++++
 lite/kernels/bm/bridges/squeeze_op.cc         |  71 +++++++++++
 12 files changed, 460 insertions(+), 7 deletions(-)
 rename lite/api/{test_resnet50_lite_bm.cc => test_classify_lite_bm.cc} (99%)
 create mode 100644 lite/kernels/bm/bridges/cast_op.cc
 create mode 100644 lite/kernels/bm/bridges/conv_transpose_op.cc
 create mode 100644 lite/kernels/bm/bridges/reduce_full_op.cc
 create mode 100644 lite/kernels/bm/bridges/slice_op.cc
 create mode 100644 lite/kernels/bm/bridges/squeeze_op.cc

diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt
index 9e57bbbf00..e786f346cc 100644
--- a/lite/api/CMakeLists.txt
+++ b/lite/api/CMakeLists.txt
@@ -181,7 +181,7 @@ if(WITH_TESTING)
             add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz)
         endif()
         if(LITE_WITH_BM)
-            lite_cc_test(test_resnet50_lite_bm SRCS test_resnet50_lite_bm.cc
+            lite_cc_test(test_classify_lite_bm SRCS test_classify_lite_bm.cc
               DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
               ${ops} ${host_kernels} ${bm_kernels} ${bm_bridges}
               ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
diff --git a/lite/api/test_resnet50_lite_bm.cc b/lite/api/test_classify_lite_bm.cc
similarity index 99%
rename from lite/api/test_resnet50_lite_bm.cc
rename to lite/api/test_classify_lite_bm.cc
index 73ad405f16..7da7dc0374 100644
--- a/lite/api/test_resnet50_lite_bm.cc
+++ b/lite/api/test_classify_lite_bm.cc
@@ -80,7 +80,7 @@ void TestModel(const std::vector<Place>& valid_places) {
   fclose(fp);
 }
 
-TEST(ResNet50, test_bm) {
+TEST(Classify, test_bm) {
   std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
                                    Place{TARGET(kX86), PRECISION(kFloat)}});
 
diff --git a/lite/kernels/bm/bridges/CMakeLists.txt b/lite/kernels/bm/bridges/CMakeLists.txt
index ffe5018ba9..75375f493f 100644
--- a/lite/kernels/bm/bridges/CMakeLists.txt
+++ b/lite/kernels/bm/bridges/CMakeLists.txt
@@ -25,6 +25,11 @@ lite_cc_library(subgraph_bridge_box_coder_op_bm SRCS box_coder_op.cc DEPS ${bm_s
 lite_cc_library(subgraph_bridge_multiclass_nms_op_bm SRCS multiclass_nms_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_interpolate_op_bm SRCS interpolate_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_yolo_box_op_bm SRCS yolo_box_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_slice_op_bm SRCS slice_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_conv_transpose_op_bm SRCS conv_transpose_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_reduce_full_op_bm SRCS reduce_full_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_squeeze_op_bm SRCS squeeze_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_cast_op_bm SRCS cast_op.cc DEPS ${bm_subgraph_bridge_deps})
 
 set(bm_subgraph_bridges
         subgraph_bridge_registry
@@ -48,4 +53,9 @@ set(bm_subgraph_bridges
         subgraph_bridge_multiclass_nms_op_bm
         subgraph_bridge_interpolate_op_bm
         subgraph_bridge_yolo_box_op_bm
+        subgraph_bridge_slice_op_bm
+        subgraph_bridge_conv_transpose_op_bm
+        subgraph_bridge_reduce_full_op_bm
+        subgraph_bridge_squeeze_op_bm
+        subgraph_bridge_cast_op_bm
         CACHE INTERNAL "bm_subgraph_bridges")
diff --git a/lite/kernels/bm/bridges/cast_op.cc b/lite/kernels/bm/bridges/cast_op.cc
new file mode 100644
index 0000000000..33be20685b
--- /dev/null
+++ b/lite/kernels/bm/bridges/cast_op.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <bmcompiler_defs.h>
+#include <bmcompiler_if.h>
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+bool CvtDtype(int dtype, int* ptype) {
+  switch (dtype) {  // dtype codes follow paddle framework proto VarType
+    case 21:  // INT8
+      *ptype = DTYPE_INT8;
+      break;
+    case 1:  // INT16
+      *ptype = DTYPE_INT16;
+      break;
+    case 2:  // INT32, lowered to FP32 here
+      *ptype = DTYPE_FP32;
+      break;
+    case 5:  // FP32
+      *ptype = DTYPE_FP32;
+      break;
+    default:
+      LOG(WARNING) << "[BM] unsupported data type: " << dtype;
+      return false;
+  }
+  return true;
+}
+
+int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto x_var_name = op_info->Input("X").front();
+  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
+  auto x_dims = x->dims();
+  auto output_var_name = op_info->Output("Out").front();
+  std::vector<int32_t> i_x_shape_data(x_dims.size());
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    i_x_shape_data[i] = static_cast<int32_t>(x_dims[i]);
+  }
+
+  int in_dtype = op_info->GetAttr<int>("in_dtype");
+  int out_dtype = op_info->GetAttr<int>("out_dtype");
+
+  if (in_dtype == out_dtype) {
+    add_identity_layer(graph->GetCompilerHandle(),
+                       static_cast<const char*>(x_var_name.c_str()),
+                       const_cast<const int*>(&i_x_shape_data[0]),
+                       x_dims.size(),
+                       static_cast<const char*>(output_var_name.c_str()));
+  } else {
+    int out_bm_dtype = 0;
+    CHECK_EQ(CvtDtype(out_dtype, &out_bm_dtype), true);
+    add_shape_cast_layer(graph->GetCompilerHandle(),
+                         static_cast<const char*>(x_var_name.c_str()),
+                         static_cast<const char*>(output_var_name.c_str()),
+                         out_bm_dtype);
+  }
+
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(cast,
+                         kBM,
+                         paddle::lite::subgraph::bm::CastConverter);
diff --git a/lite/kernels/bm/bridges/concat_op.cc b/lite/kernels/bm/bridges/concat_op.cc
index 0b568aa4d1..1fa8032885 100644
--- a/lite/kernels/bm/bridges/concat_op.cc
+++ b/lite/kernels/bm/bridges/concat_op.cc
@@ -30,8 +30,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto op_type = op_info->Type();
   // input
   auto x_names = op_info->Input("X");
-  auto x_type = kernel->GetInputDeclType("X");
-  CHECK(x_type->layout() == DATALAYOUT(kNCHW));
   // output
   auto output_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
@@ -57,7 +55,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       shape[i][j] = static_cast<int32_t>(x_shape_data[j]);
     }
   }
-  auto axis = op_info->GetAttr<int>("axis");
   add_concat_layer(graph->GetCompilerHandle(),
                    input_num,
diff --git a/lite/kernels/bm/bridges/conv_op.cc b/lite/kernels/bm/bridges/conv_op.cc
index ffe5a59aca..e4dff10702 100644
--- a/lite/kernels/bm/bridges/conv_op.cc
+++ b/lite/kernels/bm/bridges/conv_op.cc
@@ -55,7 +55,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       const_cast<const int64_t*>(&output_dims.data()[0]);
   std::vector<int32_t> i_input_shape_data(input_dims.size());
   std::vector<int32_t> i_output_shape_data(output_dims.size());
-
   for (size_t i = 0; i < input_dims.size(); i++) {
     i_input_shape_data[i] = static_cast<int32_t>(input_shape_data[i]);
   }
diff --git a/lite/kernels/bm/bridges/conv_transpose_op.cc b/lite/kernels/bm/bridges/conv_transpose_op.cc
new file mode 100644
index 0000000000..b875feaa03
--- /dev/null
+++ b/lite/kernels/bm/bridges/conv_transpose_op.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <bmcompiler_if.h>
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
+  auto input_var_name = op_info->Input("Input").front();
+  auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
+  auto input_dims = input->dims();
+  auto output_var_name = op_info->Output("Output").front();
+  auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
+  auto output_dims = output->dims();
+  auto filter_var_name = op_info->Input("Filter").front();
+  auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
+  auto filter_dims = filter->dims();
+  CHECK_EQ(input_dims.size(), 4);
+  CHECK_EQ(output_dims.size(), 4);
+  CHECK_EQ(filter_dims.size(), 4);
+  bool has_bias = lite::subgraph::bm::HasInputArg(op_info, scope, "Bias");
+  float* bias_data = nullptr;
+  if (has_bias) {
+    auto bias_var_name = op_info->Input("Bias").front();
+    auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
+    bias_data = static_cast<float*>(bias->mutable_data<float>());
+  }
+  const int64_t* input_shape_data =
+      const_cast<const int64_t*>(&input_dims.data()[0]);
+  const int64_t* output_shape_data =
+      const_cast<const int64_t*>(&output_dims.data()[0]);
+  std::vector<int32_t> i_input_shape_data(input_dims.size());
+  std::vector<int32_t> i_output_shape_data(output_dims.size());
+
+  for (size_t i = 0; i < input_dims.size(); i++) {
+    i_input_shape_data[i] = static_cast<int32_t>(input_shape_data[i]);
+  }
+  for (size_t i = 0; i < output_dims.size(); i++) {
+    i_output_shape_data[i] = static_cast<int32_t>(output_shape_data[i]);
+  }
+  const float* filter_data =
+      const_cast<const float*>(filter->mutable_data<float>());
+  auto groups = op_info->GetAttr<int>("groups");
+  auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
+  auto strides = op_info->GetAttr<std::vector<int>>("strides");
+  auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
+
+  bool fuse_relu = false;
+  if (op_info->HasAttr("fuse_relu")) {
+    fuse_relu = op_info->GetAttr<bool>("fuse_relu");
+  }
+  CHECK_EQ(fuse_relu, false);
+  add_deconv_layer(graph->GetCompilerHandle(),
+                   const_cast<const int*>(&i_input_shape_data[0]),
+                   input_dims.size(),
+                   static_cast<const char*>(input_var_name.c_str()),
+                   const_cast<const int*>(&i_output_shape_data[0]),
+                   output_dims.size(),
+                   static_cast<const char*>(output_var_name.c_str()),
+                   static_cast<const char*>(unique_op_name.c_str()),
+                   filter_data,
+                   bias_data,
+                   filter_dims.data()[2],
+                   filter_dims.data()[3],
+                   groups,
+                   paddings[0],
+                   paddings[0],
+                   paddings[1],
+                   paddings[1],
+                   strides[0],
+                   strides[1],
+                   dilations[0],
+                   dilations[1],
+                   static_cast<int>(has_bias));
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(conv2d_transpose,
+                         kBM,
+                         paddle::lite::subgraph::bm::ConvTransposeConverter);
diff --git a/lite/kernels/bm/bridges/interpolate_op.cc b/lite/kernels/bm/bridges/interpolate_op.cc
index 384b8e0daa..8a744d5f2a 100644
--- a/lite/kernels/bm/bridges/interpolate_op.cc
+++ b/lite/kernels/bm/bridges/interpolate_op.cc
@@ -54,7 +54,6 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   } else {
     type = 0;
   }
-
   if (type == 2 && is_int) {
     add_upsample_layer(graph->GetCompilerHandle(),
                        const_cast<const int*>(&i_x_shape_data[0]),
diff --git a/lite/kernels/bm/bridges/paddle_use_bridges.h b/lite/kernels/bm/bridges/paddle_use_bridges.h
index 74303d2dd7..8dbbb53d81 100644
--- a/lite/kernels/bm/bridges/paddle_use_bridges.h
+++ b/lite/kernels/bm/bridges/paddle_use_bridges.h
@@ -44,3 +44,10 @@ USE_SUBGRAPH_BRIDGE(bilinear_interp, kBM);
 USE_SUBGRAPH_BRIDGE(yolo_box, kBM);
 USE_SUBGRAPH_BRIDGE(sqrt, kBM);
 USE_SUBGRAPH_BRIDGE(square, kBM);
+USE_SUBGRAPH_BRIDGE(slice, kBM);
+USE_SUBGRAPH_BRIDGE(conv2d_transpose, kBM);
+USE_SUBGRAPH_BRIDGE(reduce_sum, kBM);
+USE_SUBGRAPH_BRIDGE(reduce_mean, kBM);
+USE_SUBGRAPH_BRIDGE(squeeze, kBM);
+USE_SUBGRAPH_BRIDGE(squeeze2, kBM);
+USE_SUBGRAPH_BRIDGE(cast, kBM);
diff --git a/lite/kernels/bm/bridges/reduce_full_op.cc b/lite/kernels/bm/bridges/reduce_full_op.cc
new file mode 100644
index 0000000000..401de8bfac
--- /dev/null
+++ b/lite/kernels/bm/bridges/reduce_full_op.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <bmcompiler_if.h>
+#include <bmcompiler_op_code.h>
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+int ReduceFullConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  // input
+  auto x_var_name = op_info->Input("X").front();
+  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
+  auto x_dims = x->dims();
+  const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
+  std::vector<int32_t> i_x_shape_data(x_dims.size());
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    i_x_shape_data[i] = static_cast<int32_t>(x_shape_data[i]);
+  }
+  // output
+  auto output_var_name = op_info->Output("Out").front();
+  auto dim = op_info->GetAttr<std::vector<int>>("dim");
+  auto keep_dim = op_info->GetAttr<bool>("keep_dim");
+  int op_code = -1;
+  if (op_type == "reduce_sum") {
+    op_code = REDUCE_SUM;
+  } else if (op_type == "reduce_mean") {
+    op_code = REDUCE_MEAN;
+  }
+
+  add_reduce_full_layer(graph->GetCompilerHandle(),
+                        static_cast<const char*>(x_var_name.c_str()),
+                        static_cast<const char*>(output_var_name.c_str()),
+                        const_cast<const int*>(&i_x_shape_data[0]),
+                        x_dims.size(),
+                        const_cast<const int*>(&dim[0]),
+                        dim.size(),
+                        op_code,
+                        static_cast<int>(keep_dim));
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(reduce_sum,
+                         kBM,
+                         paddle::lite::subgraph::bm::ReduceFullConverter);
+REGISTER_SUBGRAPH_BRIDGE(reduce_mean,
+                         kBM,
+                         paddle::lite::subgraph::bm::ReduceFullConverter);
diff --git a/lite/kernels/bm/bridges/slice_op.cc b/lite/kernels/bm/bridges/slice_op.cc
new file mode 100644
index 0000000000..9e020e6fec
--- /dev/null
+++ b/lite/kernels/bm/bridges/slice_op.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <bmcompiler_if.h>
+#include <vector>
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  // input
+  auto input_var_name = op_info->Input("Input").front();
+  auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
+  auto input_dims = input->dims();
+  const int64_t* input_shape_data =
+      const_cast<const int64_t*>(&input_dims.data()[0]);
+  std::vector<int32_t> i_input_shape_data(input_dims.size());
+  for (size_t i = 0; i < input_dims.size(); i++) {
+    i_input_shape_data[i] = static_cast<int32_t>(input_shape_data[i]);
+  }
+  // output
+  auto output_var_name = op_info->Output("Out").front();
+  auto axes = op_info->GetAttr<std::vector<int>>("axes");
+  auto starts = op_info->GetAttr<std::vector<int>>("starts");
+  auto ends = op_info->GetAttr<std::vector<int>>("ends");
+
+  std::vector<int32_t> begin_index(input_dims.size(), 0);
+  std::vector<int32_t> end_index(input_dims.size(), -1);
+  std::vector<int32_t> strides(input_dims.size(), 1);
+  int32_t begin_mask = 0;  // bit i set => axis i keeps its full range
+  int32_t end_mask = 0;
+  for (size_t i = 0; i < input_dims.size(); i++) {
+    begin_mask |= (1 << i);
+    end_mask |= (1 << i);
+  }
+  for (size_t i = 0; i < axes.size(); i++) {
+    begin_index[axes[i]] = starts[i];
+    end_index[axes[i]] = ends[i] > static_cast<int>(input_dims.size())
+                             ? static_cast<int>(input_dims.size())
+                             : ends[i];
+    begin_mask &= ~(1 << axes[i]);  // clear the bit of every sliced axis
+    end_mask &= ~(1 << axes[i]);
+  }
+
+  add_stride_slice_layer_v2(graph->GetCompilerHandle(),
+                            static_cast<const char*>(input_var_name.c_str()),
+                            const_cast<const int*>(&i_input_shape_data[0]),
+                            input_dims.size(),
+                            static_cast<const char*>(output_var_name.c_str()),
+                            begin_index.data(),
+                            end_index.data(),
+                            strides.data(),
+                            input_dims.size(),
+                            begin_mask,
+                            end_mask,
+                            0,
+                            0,
+                            0);
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(slice,
+                         kBM,
+                         paddle::lite::subgraph::bm::SliceConverter);
diff --git a/lite/kernels/bm/bridges/squeeze_op.cc b/lite/kernels/bm/bridges/squeeze_op.cc
new file mode 100644
index 0000000000..550874e837
--- /dev/null
+++ b/lite/kernels/bm/bridges/squeeze_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <bmcompiler_if.h>
+#include <vector>
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+int SqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  // input
+  auto x_var_name = op_info->Input("X").front();
+  auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
+  auto x_dims = x->dims();
+  const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
+  std::vector<int32_t> i_x_shape_data(x_dims.size());
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    i_x_shape_data[i] = static_cast<int32_t>(x_shape_data[i]);
+  }
+  // output
+  auto output_var_name = op_info->Output("Out").front();
+  std::vector<int> axes;
+  if (op_info->HasAttr("axes")) {
+    axes = op_info->GetAttr<std::vector<int>>("axes");
+  }
+  auto unique_op_scale_name = lite::subgraph::bm::UniqueName(op_type);
+  add_squeeze_layer(graph->GetCompilerHandle(),
+                    static_cast<const char*>(x_var_name.c_str()),
+                    const_cast<const int*>(&i_x_shape_data[0]),
+                    x_dims.size(),
+                    const_cast<const int*>(&axes[0]),
+                    axes.size(),
+                    static_cast<const char*>(output_var_name.c_str()));
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(squeeze,
+                         kBM,
+                         paddle::lite::subgraph::bm::SqueezeConverter);
+REGISTER_SUBGRAPH_BRIDGE(squeeze2,
+                         kBM,
+                         paddle::lite::subgraph::bm::SqueezeConverter);
-- 
GitLab