diff --git a/lite/api/test_resnet50_lite_bm.cc b/lite/api/test_resnet50_lite_bm.cc
index e315295250be0e69ae58aeac95789af069308eff..532d6d0bf206b029d022b3337b76fa0469903290 100644
--- a/lite/api/test_resnet50_lite_bm.cc
+++ b/lite/api/test_resnet50_lite_bm.cc
@@ -31,8 +31,6 @@ namespace paddle {
 namespace lite {
 
 void TestModel(const std::vector<Place>& valid_places) {
-  //DeviceInfo::Init();
-  //DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads);
   lite::Predictor predictor;
   std::vector<std::string> passes;
   passes.push_back("bm_subgraph_pass");
@@ -70,39 +68,17 @@ void TestModel(const std::vector<Place>& valid_places) {
             << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
             << " ms in average.";
 
-  std::vector<std::vector<float>> results;
-  // i = 1
-  // ground truth result from fluid
-  results.emplace_back(std::vector<float>(
-      {0.0002451055, 0.0002585023, 0.0002659616, 0.0002823}));
   auto* out = predictor.GetOutput(0);
   ASSERT_EQ(out->dims().size(), 2);
   ASSERT_EQ(out->dims()[0], 1);
   ASSERT_EQ(out->dims()[1], 1000);
-  int step = 50;
-  for (int i = 0; i < results.size(); ++i) {
-    for (int j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
-                  results[i][j],
-                  1e-6);
-    }
-  }
   auto* out_data = out->data<float>();
-  LOG(INFO) << "output data:";
-  for (int i = 0; i < out->numel(); i += step) {
-    LOG(INFO) << out_data[i];
-  }
-  float max_val = out_data[0];
-  int max_val_arg = 0;
-  for (int i = 1; i < out->numel(); i++) {
-    if (max_val < out_data[i]) {
-      max_val = out_data[i];
-      max_val_arg = i;
-    }
+  FILE* fp = fopen("result.txt", "wb");
+  for (int i = 0; i < out->numel(); i++) {
+    fprintf(fp, "%f\n", out_data[i]);
   }
-  LOG(INFO) << "max val:" << max_val << ", max_val_arg:" << max_val_arg;
+  fclose(fp);
 }
 
 TEST(ResNet50, test_bm) {
diff --git a/lite/backends/bm/target_wrapper.cc b/lite/backends/bm/target_wrapper.cc
index 489bbbd7b0d81de9ea20a3c109759be11212109f..b8196d6d25cb3e51885d7949d357db8a51833f96 100644
--- a/lite/backends/bm/target_wrapper.cc
+++ b/lite/backends/bm/target_wrapper.cc
@@ -11,7 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <map>
 #include "lite/backends/bm/target_wrapper.h"
 #include "bmlib_runtime.h"
 #include "bmcompiler_if.h"
@@ -19,8 +18,8 @@
 namespace paddle {
 namespace lite {
 
-static int g_current_device_id = 0;
-static std::map<int, bm_handle_t> g_bm_handles;
+int TargetWrapperBM::device_id_ = 0;
+std::map<int, void*> TargetWrapperBM::bm_hds_;
 
 size_t TargetWrapperBM::num_devices() {
   int count = 0;
@@ -29,25 +28,36 @@ size_t TargetWrapperBM::num_devices() {
 }
 
 void TargetWrapperBM::SetDevice(int id) {
-  g_current_device_id = id;
-
-  if (g_bm_handles.find(id) == g_bm_handles.end()) {
+/*
+  if (id < 0 || (size_t)id >= num_devices()) {
+    LOG(FATAL) << "Failed with invalid device id " << id;
+  }
+*/
+  device_id_ = id;
+  if (bm_hds_.find(id) == bm_hds_.end()) {
     bm_handle_t bm_handle;
     bm_status_t ret = bm_dev_request(&bm_handle, id);
     CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: " << (int)ret;
-    g_bm_handles.insert(std::pair<int, bm_handle_t>(id, bm_handle));
+    bm_hds_.insert(std::pair<int, void*>(id, bm_handle));
   }
   return;
 }
+
+void* TargetWrapperBM::GetHandle() {
+  if (bm_hds_.find(device_id_) == bm_hds_.end()) {
+    LOG(FATAL) << "device not initialized " << device_id_;
+  }
+  return bm_hds_.at(device_id_);
+}
 
 void* TargetWrapperBM::Malloc(size_t size) {
   void* ptr{};
-  if (g_bm_handles.find(g_current_device_id) == g_bm_handles.end()) {
-    SetDevice(g_current_device_id);
+  if (bm_hds_.find(device_id_) == bm_hds_.end()) {
+    SetDevice(device_id_);
   }
-  bm_handle_t bm_handle = g_bm_handles.at(g_current_device_id);
+  bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_));
   bm_device_mem_t* p_mem = (bm_device_mem_t*)malloc(sizeof(bm_device_mem_t));
   bm_malloc_device_byte(bm_handle, p_mem, size);
   ptr = (void*)p_mem;
@@ -56,7 +66,7 @@ void* TargetWrapperBM::Malloc(size_t size) {
 
 void TargetWrapperBM::Free(void* ptr) {
   if (ptr != NULL) {
-    bm_handle_t bm_handle = g_bm_handles.at(g_current_device_id);
+    bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_));
     bm_device_mem_t* mem = static_cast<bm_device_mem_t*>(ptr);
     bm_free_device(bm_handle, *mem);
     free(ptr);
@@ -68,11 +78,11 @@ void TargetWrapperBM::MemcpySync(void* dst,
                                  const void* src,
                                  size_t size,
                                  IoDirection dir) {
-  if (g_bm_handles.find(g_current_device_id) == g_bm_handles.end()) {
+  if (bm_hds_.find(device_id_) == bm_hds_.end()) {
     return;
   }
 
-  bm_handle_t bm_handle = g_bm_handles.at(g_current_device_id);
+  bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_));
   bm_device_mem_t* pmem{};
   const bm_device_mem_t* pcst_mem{};
diff --git a/lite/backends/bm/target_wrapper.h b/lite/backends/bm/target_wrapper.h
index 04cb39c2ce8faa4625bf0cc3b4256f3e6e2fe401..99f69f976b302411f91cbcaac62a3dd2462fd801 100644
--- a/lite/backends/bm/target_wrapper.h
+++ b/lite/backends/bm/target_wrapper.h
@@ -14,6 +14,7 @@
 #pragma once
 
 #include "lite/core/target_wrapper.h"
+#include <map>
 
 namespace paddle {
 namespace lite {
@@ -43,6 +44,8 @@ class TargetWrapper<TARGET(kBM)> {
 
   static void* Malloc(size_t size);
   static void Free(void* ptr);
+
+  static void* GetHandle();
 
   static void MemcpySync(void* dst,
                          const void* src,
@@ -61,6 +64,11 @@ class TargetWrapper<TARGET(kBM)> {
                         int value,
                         size_t count,
                         const stream_t& stream) {}
+
+ private:
+  static int device_id_;
+  static std::map<int, void*> bm_hds_;
+
 };
 
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/core/context.h b/lite/core/context.h
index b2b0906808026bf7a30785f816800c5fc1a3b524..37da208309985bd3e91ee9c244bf95231d5ad7ad 100644
--- a/lite/core/context.h
+++ b/lite/core/context.h
@@ -90,8 +90,15 @@ class Context<TargetType::kBM> {
   Context() {}
   explicit Context(const BMContext& ctx);
   // NOTE: InitOnce should only be used by ContextScheduler
-  void InitOnce() {}
+  void InitOnce() { Init(0); }
+
+  void Init(int dev_id) {
+    TargetWrapperBM::SetDevice(dev_id);
+  }
+
   void CopySharedTo(BMContext* ctx) {}
+  void* GetHandle() {
+    return TargetWrapperBM::GetHandle();
+  }
   std::string name() const { return "BMContext"; }
 };
diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc
index 8451e148ff3ada4788c9cdee4038fdba7fae738e..1652e0cbf75cb445ca573464d7db78d30c73fc83 100644
--- a/lite/kernels/bm/subgraph_compute.cc
+++ b/lite/kernels/bm/subgraph_compute.cc
@@ -22,7 +22,6 @@
 #include "lite/kernels/bm/bridges/graph.h"
 #include "lite/kernels/bm/bridges/paddle_use_bridges.h"
 #include "lite/kernels/bm/bridges/utility.h"
-#include "bmcompiler_if.h"
 
 namespace paddle {
 namespace lite {
@@ -30,11 +29,11 @@ namespace kernels {
 namespace bm {
 
 int SubgraphEngine::BuildDeviceProgram() {
-
   int status = 0;
   subgraph::bm::Graph graph;
   const auto& bridges = subgraph::Registry::Instance();
   graph.CreateCompilerHandle();
+  auto& ctx = this->ctx_->template As<BMContext>();
 
   for (auto& inst : origin_program_) {
     auto op = inst.op();
@@ -56,11 +55,73 @@ int SubgraphEngine::BuildDeviceProgram() {
   std::string net_name = "paddle_bitmain";
   __bmcompile_opt(
       graph.GetCompilerHandle(), const_cast<char*>(net_name.c_str()), 2);
-  finish_bmcompiler(graph.GetCompilerHandle());
+
+  void* bmodel_data = nullptr;
+  unsigned int data_size = 0;
+  bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
+  finish_bmcompiler_data(graph.GetCompilerHandle(), &bmodel_data, &data_size);
+  bmrt_hd_ = bmrt_create(bm_hd_);
+  if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) {
+    return subgraph::FAILED;
+  }
+
+  bmrt_get_network_names(bmrt_hd_, &net_names_);
+  net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]);
+  auto& stage = net_info_->stages[0];
+
+  // input
+  origin_idims_.resize(input_names_.size());
+  origin_itensors_.resize(input_names_.size());
+  device_inputs_.resize(input_names_.size());
+  for (size_t i = 0; i < input_names_.size(); i++) {
+    origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
+    CHECK(origin_itensors_[i]);
+    origin_idims_[i] = origin_itensors_[i]->dims();
+    bm_device_mem_t* p_mem =
+        static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
+    CHECK(p_mem != nullptr);
+    CHECK(bm_malloc_device_byte(
+              bm_hd_, p_mem, origin_itensors_[i]->memory_size()) == BM_SUCCESS);
+    bmrt_tensor_with_device(&device_inputs_[i],
+                            *p_mem,
+                            net_info_->input_dtypes[i],
+                            stage.input_shapes[i]);
+  }
+
+  // output
+  origin_odims_.resize(output_names_.size());
+  origin_otensors_.resize(output_names_.size());
+  device_outputs_.resize(output_names_.size());
+
+  for (size_t i = 0; i < output_names_.size(); i++) {
+    origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
+    CHECK(origin_otensors_[i]);
+    origin_odims_[i] = origin_otensors_[i]->dims();
+    output_map_.insert(std::pair<std::string, int>(output_names_[i], i));
+    origin_otensors_[i]->mutable_data<float>();
+  }
+
+  for (size_t i = 0; i < output_names_.size(); i++) {
+    int mapping_index = output_map_.at(net_info_->output_names[i]);
+    bm_device_mem_t* p_mem =
+        static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
+    CHECK(p_mem != nullptr);
+    CHECK(bm_malloc_device_byte(
+              bm_hd_, p_mem, origin_otensors_[mapping_index]->memory_size()) ==
+          BM_SUCCESS);
+    bmrt_tensor_with_device(&device_outputs_[i],
+                            *p_mem,
+                            net_info_->output_dtypes[i],
+                            stage.output_shapes[i]);
+  }
+
   return status;
 }
 
 int SubgraphEngine::LaunchDeviceProgram() {
+  for (size_t i = 0; i < device_inputs_.size(); i++) {
+    bm_memcpy_s2d(bm_hd_,
+                  device_inputs_[i].device_mem,
+                  const_cast<void*>(origin_itensors_[i]->raw_data()));
+  }
+
+  bmrt_launch_tensor_ex(bmrt_hd_,
+                        net_names_[0],
+                        static_cast<const bm_tensor_t*>(&device_inputs_[0]),
+                        net_info_->input_num,
+                        static_cast<bm_tensor_t*>(&device_outputs_[0]),
+                        net_info_->output_num,
+                        true,
+                        false);
+  bm_thread_sync(bm_hd_);
+
+  for (size_t i = 0; i < device_outputs_.size(); i++) {
+    bm_memcpy_d2s(bm_hd_,
+                  const_cast<void*>(origin_otensors_[i]->raw_data()),
+                  device_outputs_[i].device_mem);
+  }
   return 0;
 }
diff --git a/lite/kernels/bm/subgraph_compute.h b/lite/kernels/bm/subgraph_compute.h
index f2a2de6b902adc0b221fa148e1269fbb4c151e50..03f5dd393e9250bd6cb7cbf26da93a2c89cfeb31 100644
--- a/lite/kernels/bm/subgraph_compute.h
+++ b/lite/kernels/bm/subgraph_compute.h
@@ -24,6 +24,9 @@
 #include "lite/kernels/npu/bridges/engine.h"
 #include "lite/kernels/npu/bridges/registry.h"
 
+#include "bmcompiler_if.h"
+#include "bmruntime_interface.h"
+
 namespace paddle {
 namespace lite {
 namespace kernels {
@@ -43,20 +46,25 @@ class SubgraphEngine : public subgraph::Engine {
  protected:
   int BuildDeviceProgram() override;
   int LaunchDeviceProgram() override;
+
+ private:
+  void* bmrt_hd_;
+  std::vector<bm_tensor_t> device_inputs_;
+  std::vector<bm_tensor_t> device_outputs_;
+  std::map<std::string, int> output_map_;
+  const char** net_names_;
+  const bm_net_info_t* net_info_;
+  bm_handle_t bm_hd_;
 };
 
 class SubgraphCompute : public KernelLite<TARGET(kBM), PRECISION(kFloat)> {
  public:
   using param_t = operators::SubgraphParam;
-
   void PrepareForRun() override;
-
   void Run() override;
-
   virtual ~SubgraphCompute() = default;
 
  private:
-  void* bm_compiler_ht_;
   std::unique_ptr<SubgraphEngine> engine_;
 };
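
Editor's note (not part of the patch): the BuildDeviceProgram/LaunchDeviceProgram changes above follow the usual Sophon BMRuntime flow: create a runtime on a bmlib device handle, load a bmodel, bind device memory to bm_tensor_t descriptors, launch, sync, and copy results back. A minimal standalone sketch of that flow, assuming a prebuilt "net.bmodel" on disk and an FP32 single-input/single-output network (the file name and dtype are illustrative, not taken from the patch):

#include <cstdio>
#include <vector>
#include "bmlib_runtime.h"
#include "bmruntime_interface.h"

int main() {
  bm_handle_t bm_handle;
  if (bm_dev_request(&bm_handle, 0) != BM_SUCCESS) return 1;  // device 0
  void* p_bmrt = bmrt_create(bm_handle);
  // The patch builds the model in memory (finish_bmcompiler_data +
  // bmrt_load_bmodel_data); here we load a precompiled file instead.
  if (!bmrt_load_bmodel(p_bmrt, "net.bmodel")) return 1;

  const char** net_names = nullptr;
  bmrt_get_network_names(p_bmrt, &net_names);
  const bm_net_info_t* info = bmrt_get_network_info(p_bmrt, net_names[0]);
  const bm_stage_info_t& stage = info->stages[0];

  // bmrt_tensor() allocates device memory for the stage's static shapes;
  // the patch allocates explicitly and calls bmrt_tensor_with_device().
  bm_tensor_t in, out;
  bmrt_tensor(&in, p_bmrt, info->input_dtypes[0], stage.input_shapes[0]);
  bmrt_tensor(&out, p_bmrt, info->output_dtypes[0], stage.output_shapes[0]);

  // Host -> device, launch, sync, device -> host (same order as the patch).
  std::vector<float> src(bmrt_shape_count(&stage.input_shapes[0]), 0.f);
  bm_memcpy_s2d(bm_handle, in.device_mem, src.data());
  bmrt_launch_tensor_ex(p_bmrt, net_names[0], &in, info->input_num,
                        &out, info->output_num, true, false);
  bm_thread_sync(bm_handle);
  std::vector<float> dst(bmrt_shape_count(&stage.output_shapes[0]));
  bm_memcpy_d2s(bm_handle, dst.data(), out.device_mem);
  printf("out[0] = %f\n", dst[0]);

  bm_free_device(bm_handle, in.device_mem);
  bm_free_device(bm_handle, out.device_mem);
  bmrt_destroy(p_bmrt);
  bm_dev_free(bm_handle);
  return 0;
}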