From 0348583fe29745117efdfb65e3b3ab7474573929 Mon Sep 17 00:00:00 2001
From: Santa An <49897975+AnBaolei1984@users.noreply.github.com>
Date: Wed, 23 Sep 2020 12:20:04 +0800
Subject: [PATCH] [LITE][BM] fix input shape order changed issue,test=develop (#4407)

* [LITE][BM] support multiclass_nms2 and fix some issues, test=develop

* create

* [LITE][BM] fix input shape order changed issue,test=develop
---
 lite/backends/bm/target_wrapper.cc  |  2 +-
 lite/kernels/bm/subgraph_compute.cc | 16 +++++++++-------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/lite/backends/bm/target_wrapper.cc b/lite/backends/bm/target_wrapper.cc
index 6dab2a574d..83aa4dc8c1 100644
--- a/lite/backends/bm/target_wrapper.cc
+++ b/lite/backends/bm/target_wrapper.cc
@@ -23,7 +23,7 @@ int TargetWrapperBM::device_id_ = 0;
 std::map<int, void*> TargetWrapperBM::bm_hds_;
 
 size_t TargetWrapperBM::num_devices() {
-  int count = 0;
+  int count = 1;
   bm_status_t ret = bm_dev_getcount(&count);
   CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: "
                             << static_cast<int>(ret);
diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc
index efbb848313..eeb81ba9da 100644
--- a/lite/kernels/bm/subgraph_compute.cc
+++ b/lite/kernels/bm/subgraph_compute.cc
@@ -66,9 +66,9 @@ bool SubgraphEngine::BuildDeviceProgram() {
       graph.GetCompilerHandle(), const_cast<char*>(unique_net_name.c_str()), 1);
   void* bmodel_data = nullptr;
   unsigned int data_size = 0;
-  bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
   finish_bmcompiler_data(graph.GetCompilerHandle(), &bmodel_data, &data_size);
   graph.UnlockCompilerMutex();
+  bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
   bmrt_hd_ = bmrt_create(bm_hd_);
   if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) {
     return false;
   }
@@ -79,15 +79,15 @@ bool SubgraphEngine::BuildDeviceProgram() {
   // input
   device_inputs_.resize(input_names_.size());
   for (size_t i = 0; i < input_names_.size(); i++) {
-    origin_itensors_[i] =
+    auto origin_itensor =
         exec_scope_->FindMutableTensor(net_info_->input_names[i]);
-    CHECK(origin_itensors_[i]);
+    CHECK(origin_itensor);
     bm_device_mem_t* p_mem =
         static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
     CHECK(p_mem != nullptr);
-    CHECK_EQ(bm_malloc_device_byte(
-                 bm_hd_, p_mem, origin_itensors_[i]->memory_size()),
-             BM_SUCCESS);
+    CHECK_EQ(
+        bm_malloc_device_byte(bm_hd_, p_mem, origin_itensor->memory_size()),
+        BM_SUCCESS);
     bmrt_tensor_with_device(&device_inputs_[i],
                             *p_mem,
                             net_info_->input_dtypes[i],
@@ -124,9 +124,11 @@ bool SubgraphEngine::BuildDeviceProgram() {
 
 bool SubgraphEngine::LaunchDeviceProgram() {
   for (size_t i = 0; i < device_inputs_.size(); i++) {
+    auto origin_itensor =
+        exec_scope_->FindMutableTensor(net_info_->input_names[i]);
     bm_memcpy_s2d(bm_hd_,
                   device_inputs_[i].device_mem,
-                  const_cast<void*>(origin_itensors_[i]->raw_data()));
+                  const_cast<void*>(origin_itensor->raw_data()));
   }
   bmrt_launch_tensor_ex(bmrt_hd_,
                         net_names_[0],
-- 
GitLab