Unverified commit 7bf64f67 authored by Santa An, committed by GitHub

[LITE][BM] support reduce and other ops, test=develop (#3199)

* support downloading bm_sdk, test=develop

* [LITE][BM] add slice op

* [LITE][BM] fix concat issue

* [LITE][BM] support reduce full ops, test=develop

* [LITE][BM] rename test_resnet50 to test_classify

* [LITE][BM] add cast op

* [LITE][BM] add reduce and other ops, test=develop

* [LITE][BM] add reduce,cast and other ops, test=develop
Parent e04399ba
lite/api/CMakeLists.txt
...
@@ -181,7 +181,7 @@ if(WITH_TESTING)
     add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz)
   endif()
   if(LITE_WITH_BM)
-    lite_cc_test(test_resnet50_lite_bm SRCS test_resnet50_lite_bm.cc
+    lite_cc_test(test_classify_lite_bm SRCS test_classify_lite_bm.cc
       DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
       ${ops} ${host_kernels} ${bm_kernels} ${bm_bridges}
       ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
......
lite/api/test_classify_lite_bm.cc (renamed from test_resnet50_lite_bm.cc)
...
@@ -80,7 +80,7 @@ void TestModel(const std::vector<Place>& valid_places) {
   fclose(fp);
 }
-TEST(ResNet50, test_bm) {
+TEST(Classify, test_bm) {
   std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
                                    Place{TARGET(kX86), PRECISION(kFloat)}});
......
lite/kernels/bm/bridges/CMakeLists.txt
...
@@ -25,6 +25,11 @@ lite_cc_library(subgraph_bridge_box_coder_op_bm SRCS box_coder_op.cc DEPS ${bm_s
 lite_cc_library(subgraph_bridge_multiclass_nms_op_bm SRCS multiclass_nms_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_interpolate_op_bm SRCS interpolate_op.cc DEPS ${bm_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_yolo_box_op_bm SRCS yolo_box_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_slice_op_bm SRCS slice_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_conv_transpose_op_bm SRCS conv_transpose_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_reduce_full_op_bm SRCS reduce_full_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_squeeze_op_bm SRCS squeeze_op.cc DEPS ${bm_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_cast_op_bm SRCS cast_op.cc DEPS ${bm_subgraph_bridge_deps})

 set(bm_subgraph_bridges
     subgraph_bridge_registry
...
@@ -48,4 +53,9 @@ set(bm_subgraph_bridges
     subgraph_bridge_multiclass_nms_op_bm
     subgraph_bridge_interpolate_op_bm
     subgraph_bridge_yolo_box_op_bm
+    subgraph_bridge_slice_op_bm
+    subgraph_bridge_conv_transpose_op_bm
+    subgraph_bridge_reduce_full_op_bm
+    subgraph_bridge_squeeze_op_bm
+    subgraph_bridge_cast_op_bm
     CACHE INTERNAL "bm_subgraph_bridges")
lite/kernels/bm/bridges/cast_op.cc (new file)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
bool CvtDtype(int dtype, int* ptype) {
  // The integer codes follow Paddle's framework proto VarType enum.
  switch (dtype) {
    case 21:  // INT8
      *ptype = DTYPE_INT8;
      break;
    case 1:  // INT16
      *ptype = DTYPE_INT16;
      break;
    case 2:  // INT32, lowered to FP32 on the BM side
      *ptype = DTYPE_FP32;
      break;
    case 5:  // FP32
      *ptype = DTYPE_FP32;
      break;
    default:
      LOG(WARNING) << "[BM] unsupported data type: " << dtype;
      return false;
  }
  return true;
}
int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto output_var_name = op_info->Output("Out").front();
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_dims[i]);
}
int in_dtype = op_info->GetAttr<int>("in_dtype");
int out_dtype = op_info->GetAttr<int>("out_dtype");
if (in_dtype == out_dtype) {
add_identity_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(output_var_name.c_str()));
} else {
int out_bm_dtype = 0;
CHECK_EQ(CvtDtype(out_dtype, &out_bm_dtype), true);
add_shape_cast_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
static_cast<const char*>(output_var_name.c_str()),
out_bm_dtype);
}
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(cast, kBM, paddle::lite::subgraph::bm::CastConverter);
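Aside: the integer dtype codes consumed by CvtDtype follow Paddle's framework proto VarType enum. The sketch below only tabulates the mapping this bridge implements; the enum values are assumed from the framework proto rather than stated anywhere in this diff:

#include <cstdio>

int main() {
  // Assumed Paddle VarType codes: INT8 = 21, INT16 = 1, INT32 = 2, FP32 = 5.
  // Note that the bridge lowers INT32 to the BM compiler's FP32 dtype.
  const struct { int paddle_code; const char* bm_dtype; } kTable[] = {
      {21, "DTYPE_INT8"}, {1, "DTYPE_INT16"}, {2, "DTYPE_FP32"}, {5, "DTYPE_FP32"}};
  for (const auto& row : kTable) {
    std::printf("%2d -> %s\n", row.paddle_code, row.bm_dtype);
  }
  return 0;
}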
lite/kernels/bm/bridges/concat_op.cc
...
@@ -30,8 +30,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto op_type = op_info->Type();
   // input
   auto x_names = op_info->Input("X");
-  auto x_type = kernel->GetInputDeclType("X");
-  CHECK(x_type->layout() == DATALAYOUT(kNCHW));
   // output
   auto output_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
...
@@ -57,7 +55,6 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       shape[i][j] = static_cast<int32_t>(x_shape_data[j]);
     }
   }
-
   auto axis = op_info->GetAttr<int>("axis");
   add_concat_layer(graph->GetCompilerHandle(),
                    input_num,
......
lite/kernels/bm/bridges/conv_op.cc
...
@@ -55,7 +55,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       const_cast<const int64_t*>(&output_dims.data()[0]);
   std::vector<int32_t> i_input_shape_data(input_dims.size());
   std::vector<int32_t> i_output_shape_data(output_dims.size());
-
   for (size_t i = 0; i < input_dims.size(); i++) {
     i_input_shape_data[i] = static_cast<int32_t>(input_shape_data[i]);
   }
......
lite/kernels/bm/bridges/conv_transpose_op.cc (new file)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
auto output_var_name = op_info->Output("Output").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_dims = filter->dims();
CHECK_EQ(input_dims.size(), 4);
CHECK_EQ(output_dims.size(), 4);
CHECK_EQ(filter_dims.size(), 4);
bool has_bias = lite::subgraph::bm::HasInputArg(op_info, scope, "Bias");
float* bias_data = nullptr;
if (has_bias) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
bias_data = static_cast<float*>(bias->mutable_data<float>());
}
const int64_t* input_shape_data =
const_cast<const int64_t*>(&input_dims.data()[0]);
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_input_shape_data(input_dims.size());
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < input_dims.size(); i++) {
i_input_shape_data[i] = static_cast<int32_t>(input_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int32_t>(output_shape_data[i]);
}
const float* filter_data =
const_cast<const float*>(filter->mutable_data<float>());
auto groups = op_info->GetAttr<int>("groups");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
bool fuse_relu = false;
if (op_info->HasAttr("fuse_relu")) {
fuse_relu = op_info->GetAttr<bool>("fuse_relu");
}
CHECK_EQ(fuse_relu, false);
add_deconv_layer(graph->GetCompilerHandle(),
const_cast<const int*>(&i_input_shape_data[0]),
input_dims.size(),
static_cast<const char*>(input_var_name.c_str()),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
filter_data,
bias_data,
filter_dims.data()[2],
filter_dims.data()[3],
groups,
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
static_cast<int>(has_bias));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(conv2d_transpose,
kBM,
paddle::lite::subgraph::bm::ConvTransposeConverter);
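Aside: when checking the geometry handed to add_deconv_layer, the standard transposed-convolution output extent along one axis is out = (in - 1) * stride - 2 * pad + dilation * (k - 1) + 1. A minimal self-contained sketch; the sizes are hypothetical, not taken from this diff:

#include <iostream>

// Output extent of a transposed convolution along one axis (standard formula).
int deconv_out(int in, int k, int stride, int pad, int dilation) {
  return (in - 1) * stride - 2 * pad + dilation * (k - 1) + 1;
}

int main() {
  // e.g. a 4-wide input, 3-wide kernel, stride 2, pad 1, dilation 1 -> 7
  std::cout << deconv_out(4, 3, 2, 1, 1) << std::endl;
  return 0;
}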
lite/kernels/bm/bridges/interpolate_op.cc
...
@@ -54,7 +54,6 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   } else {
     type = 0;
   }
-
   if (type == 2 && is_int) {
     add_upsample_layer(graph->GetCompilerHandle(),
                        const_cast<const int*>(&i_x_shape_data[0]),
......
lite/kernels/bm/bridges/paddle_use_bridges.h
...
@@ -44,3 +44,10 @@ USE_SUBGRAPH_BRIDGE(bilinear_interp, kBM);
 USE_SUBGRAPH_BRIDGE(yolo_box, kBM);
 USE_SUBGRAPH_BRIDGE(sqrt, kBM);
 USE_SUBGRAPH_BRIDGE(square, kBM);
+USE_SUBGRAPH_BRIDGE(slice, kBM);
+USE_SUBGRAPH_BRIDGE(conv2d_transpose, kBM);
+USE_SUBGRAPH_BRIDGE(reduce_sum, kBM);
+USE_SUBGRAPH_BRIDGE(reduce_mean, kBM);
+USE_SUBGRAPH_BRIDGE(squeeze, kBM);
+USE_SUBGRAPH_BRIDGE(squeeze2, kBM);
+USE_SUBGRAPH_BRIDGE(cast, kBM);
lite/kernels/bm/bridges/reduce_full_op.cc (new file)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_op_code.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int ReduceFullConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
auto dim = op_info->GetAttr<std::vector<int32_t>>("dim");
auto keep_dim = op_info->GetAttr<bool>("keep_dim");
int op_code = -1;
if (op_type == "reduce_sum") {
op_code = REDUCE_SUM;
} else if (op_type == "reduce_mean") {
op_code = REDUCE_MEAN;
}
add_reduce_full_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
static_cast<const char*>(output_var_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
const_cast<const int*>(&dim[0]),
dim.size(),
op_code,
static_cast<int>(keep_dim));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(reduce_sum,
kBM,
paddle::lite::subgraph::bm::ReduceFullConverter);
REGISTER_SUBGRAPH_BRIDGE(reduce_mean,
kBM,
paddle::lite::subgraph::bm::ReduceFullConverter);
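Aside: the keep_dim flag forwarded to add_reduce_full_layer only changes the rank of the result, not which elements are reduced. A minimal sketch of the output-shape rule, mirroring Paddle's reduce semantics on hypothetical shapes:

#include <iostream>
#include <set>
#include <vector>

// Shape of the result of reducing `in` over `dims`, with or without keep_dim.
std::vector<int> reduce_shape(const std::vector<int>& in,
                              const std::vector<int>& dims, bool keep_dim) {
  std::set<int> reduced(dims.begin(), dims.end());
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    if (!reduced.count(i)) out.push_back(in[i]);  // axis kept as-is
    else if (keep_dim) out.push_back(1);          // axis collapsed to extent 1
  }
  return out;
}

int main() {
  for (int v : reduce_shape({2, 3, 4}, {1}, false)) std::cout << v << ' ';  // 2 4
  std::cout << '\n';
  for (int v : reduce_shape({2, 3, 4}, {1}, true)) std::cout << v << ' ';   // 2 1 4
  std::cout << '\n';
  return 0;
}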
lite/kernels/bm/bridges/slice_op.cc (new file)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_op_code.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
const int64_t* input_shape_data =
const_cast<const int64_t*>(&input_dims.data()[0]);
std::vector<int32_t> i_input_shape_data(input_dims.size());
for (size_t i = 0; i < input_dims.size(); i++) {
i_input_shape_data[i] = static_cast<int>(input_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
auto axes = op_info->GetAttr<std::vector<int32_t>>("axes");
auto starts = op_info->GetAttr<std::vector<int32_t>>("starts");
auto ends = op_info->GetAttr<std::vector<int32_t>>("ends");
std::vector<int32_t> begin_index(input_dims.size(), 0);
std::vector<int32_t> end_index(input_dims.size(), -1);
std::vector<int32_t> strides(input_dims.size(), 1);
int32_t begin_mask = 0;
int32_t end_mask = 0;
for (size_t i = 0; i < input_dims.size(); i++) {
begin_mask |= (1 << i);
end_mask |= (1 << i);
}
for (size_t i = 0; i < axes.size(); i++) {
begin_index[axes[i]] = starts[i];
// Clamp the end index to the extent of the sliced axis.
end_index[axes[i]] = ends[i] > static_cast<int32_t>(input_dims[axes[i]])
                         ? static_cast<int32_t>(input_dims[axes[i]])
                         : ends[i];
begin_mask &= ~(1 << axes[i]);
end_mask &= ~(1 << axes[i]);
}
add_stride_slice_layer_v2(graph->GetCompilerHandle(),
static_cast<const char*>(input_var_name.c_str()),
const_cast<const int*>(&i_input_shape_data[0]),
input_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
begin_index.data(),
end_index.data(),
strides.data(),
input_dims.size(),
begin_mask,
end_mask,
0,
0,
0);
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(slice,
kBM,
paddle::lite::subgraph::bm::SliceConverter);
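Aside: the begin_mask/end_mask convention above is easy to misread: the converter first sets the bit for every axis (a set bit tells the BM stride-slice to take the full extent on that axis) and then clears the bits of the axes that are actually sliced. A small standalone sketch with hypothetical attributes:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical slice on a rank-4 input: axis 1 only, elements [2, 5).
  std::vector<int32_t> axes{1};
  const size_t rank = 4;
  int32_t begin_mask = 0, end_mask = 0;
  for (size_t i = 0; i < rank; i++) {  // start with every axis masked
    begin_mask |= (1 << i);
    end_mask |= (1 << i);
  }
  for (int32_t a : axes) {             // unmask only the sliced axes
    begin_mask &= ~(1 << a);
    end_mask &= ~(1 << a);
  }
  std::cout << std::hex << begin_mask << ' ' << end_mask << std::endl;  // d d
  return 0;
}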
lite/kernels/bm/bridges/squeeze_op.cc (new file)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_op_code.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int SqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
std::vector<int> axes;
if (op_info->HasAttr("axes")) {
axes = op_info->GetAttr<std::vector<int>>("axes");
}
auto unique_op_scale_name = lite::subgraph::bm::UniqueName(op_type);
add_squeeze_layer(graph->GetCompilerHandle(),
                  static_cast<const char*>(x_var_name.c_str()),
                  const_cast<const int*>(&i_x_shape_data[0]),
                  x_dims.size(),
                  axes.data(),  // not &axes[0], which is undefined when axes is empty
                  axes.size(),
                  static_cast<const char*>(output_var_name.c_str()));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(squeeze,
kBM,
paddle::lite::subgraph::bm::SqueezeConverter);
REGISTER_SUBGRAPH_BRIDGE(squeeze2,
kBM,
paddle::lite::subgraph::bm::SqueezeConverter);
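Aside: squeeze removes dimensions of extent 1, and an empty axes attribute (allowed above, hence axes.data()) means every extent-1 dimension is removed. A minimal sketch of that rule on hypothetical shapes:

#include <iostream>
#include <set>
#include <vector>

// Result shape of squeeze: drop the listed axes when their extent is 1,
// or every extent-1 axis when no axes are given.
std::vector<int> squeeze_shape(const std::vector<int>& in,
                               const std::vector<int>& axes) {
  std::set<int> drop(axes.begin(), axes.end());
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    bool remove = in[i] == 1 && (axes.empty() || drop.count(i));
    if (!remove) out.push_back(in[i]);
  }
  return out;
}

int main() {
  for (int v : squeeze_shape({1, 3, 1, 5}, {0, 2})) std::cout << v << ' ';  // 3 5
  std::cout << '\n';
  for (int v : squeeze_shape({1, 3, 1, 5}, {})) std::cout << v << ' ';      // 3 5
  std::cout << '\n';
  return 0;
}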