未验证 提交 53544680 编写于 作者: J jackzhang235 提交者: GitHub

modify mlu's TargetWrapper, create and run predictor in different thread (#112)

Modify the MLU TargetWrapper to support the case where a predictor is created and run in different threads.
上级 a2547758
...@@ -52,7 +52,8 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -52,7 +52,8 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif #endif
#ifdef LITE_WITH_MLU #ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init(); Env<TARGET(kMLU)>::Init();
lite::TargetWrapperMlu::SetMLURunMode(config.mlu_core_version(), lite::TargetWrapperMlu::SetMLURunMode(reinterpret_cast<int64_t>(this),
config.mlu_core_version(),
config.mlu_core_number(), config.mlu_core_number(),
config.mlu_use_first_conv(), config.mlu_use_first_conv(),
config.mlu_first_conv_mean(), config.mlu_first_conv_mean(),
...@@ -102,6 +103,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { ...@@ -102,6 +103,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
} }
void CxxPaddleApiImpl::Run() { void CxxPaddleApiImpl::Run() {
#ifdef LITE_WITH_MLU
lite::TargetWrapperMlu::RegisterMLURunningPredictor(
reinterpret_cast<int64_t>(this));
#endif
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_); lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif #endif
......
...@@ -36,12 +36,10 @@ void cnrtMemcpyDtoH(void* dst, const void* src, size_t size) { ...@@ -36,12 +36,10 @@ void cnrtMemcpyDtoH(void* dst, const void* src, size_t size) {
} // namespace mlu } // namespace mlu
// Definitions of the static bookkeeping members that replace the former
// thread_local configuration: run-mode info is now keyed per predictor,
// and each thread is mapped to the predictor it is currently running.
std::map<int64_t, TargetWrapperMlu::ThreadLocalInfo>
    TargetWrapperMlu::predictor_info_map_;
std::map<std::thread::id, int64_t> TargetWrapperMlu::thread_predictor_map_;
std::mutex TargetWrapperMlu::info_map_mutex_;
size_t TargetWrapperMlu::num_devices() { size_t TargetWrapperMlu::num_devices() {
uint32_t dev_count = 0; uint32_t dev_count = 0;
...@@ -84,43 +82,65 @@ void TargetWrapperMlu::MemcpySync(void* dst, ...@@ -84,43 +82,65 @@ void TargetWrapperMlu::MemcpySync(void* dst,
LOG(FATAL) << "Unsupported IoDirection" << static_cast<int>(dir); LOG(FATAL) << "Unsupported IoDirection" << static_cast<int>(dir);
} }
} }
void TargetWrapperMlu::SetMLURunMode(lite_api::MLUCoreVersion core_version,
void TargetWrapperMlu::SetMLURunMode(int64_t predictor_addr,
lite_api::MLUCoreVersion core_version,
int core_number, int core_number,
bool use_first_conv, bool use_first_conv,
const std::vector<float>& mean_vec, const std::vector<float>& mean_vec,
const std::vector<float>& std_vec, const std::vector<float>& std_vec,
DataLayoutType input_layout) { DataLayoutType input_layout) {
switch (core_version) { ThreadLocalInfo info = ThreadLocalInfo(core_version,
case (lite_api::MLUCoreVersion::MLU_220): core_number,
mlu_core_version_ = CNML_MLU220; use_first_conv,
break; mean_vec,
case (lite_api::MLUCoreVersion::MLU_270): std_vec,
mlu_core_version_ = CNML_MLU270; input_layout);
break; std::lock_guard<std::mutex> lock(info_map_mutex_);
default: predictor_info_map_[predictor_addr] = info;
mlu_core_version_ = CNML_MLU270; VLOG(6) << "predictor_info_map_ add key: " << predictor_addr;
break; thread_predictor_map_[std::this_thread::get_id()] = predictor_addr;
} VLOG(6) << "thread_predictor_map_ add key: " << std::this_thread::get_id()
mlu_core_number_ = core_number; << ", add value: " << predictor_addr;
use_first_conv_ = use_first_conv;
mean_vec_ = mean_vec;
std_vec_ = std_vec;
input_layout_ = input_layout;
} }
// Binds the calling thread to `predictor_addr` so subsequent
// MLUCoreVersion()/MLUCoreNumber()/... queries on this thread resolve to
// that predictor's configuration (see RETURN_MLU_INFO).
void TargetWrapperMlu::RegisterMLURunningPredictor(int64_t predictor_addr) {
  const std::thread::id tid = std::this_thread::get_id();
  std::lock_guard<std::mutex> guard(info_map_mutex_);
  thread_predictor_map_[tid] = predictor_addr;
  VLOG(6) << "thread_predictor_map_ add key: " << std::this_thread::get_id()
          << ", add value: " << predictor_addr;
}
// Returns member `x` of the ThreadLocalInfo registered for the predictor
// that the calling thread is bound to (via SetMLURunMode or
// RegisterMLURunningPredictor). The lookup is done under info_map_mutex_.
// NOTE(review): std::map::operator[] default-inserts — an unregistered
// thread silently gets a default-constructed ThreadLocalInfo; also, a
// reference returned here is no longer protected by the lock after the
// macro returns — confirm callers tolerate both.
#define RETURN_MLU_INFO(x)                                        \
  do {                                                            \
    std::lock_guard<std::mutex> lock(info_map_mutex_);            \
    VLOG(6) << "call from thread: " << std::this_thread::get_id() \
            << ", predictor key: "                                \
            << thread_predictor_map_[std::this_thread::get_id()]; \
    return predictor_info_map_                                    \
        [thread_predictor_map_[std::this_thread::get_id()]]       \
            .x;                                                   \
  } while (0)
cnmlCoreVersion_t TargetWrapperMlu::MLUCoreVersion() { cnmlCoreVersion_t TargetWrapperMlu::MLUCoreVersion() {
return mlu_core_version_; RETURN_MLU_INFO(mlu_core_version_);
} }
int TargetWrapperMlu::MLUCoreNumber() { return mlu_core_number_; } int TargetWrapperMlu::MLUCoreNumber() { RETURN_MLU_INFO(mlu_core_number_); }
bool TargetWrapperMlu::UseFirstConv() { return use_first_conv_; } bool TargetWrapperMlu::UseFirstConv() { RETURN_MLU_INFO(use_first_conv_); }
const std::vector<float>& TargetWrapperMlu::MeanVec() { return mean_vec_; } const std::vector<float>& TargetWrapperMlu::MeanVec() {
RETURN_MLU_INFO(mean_vec_);
}
const std::vector<float>& TargetWrapperMlu::StdVec() { return std_vec_; } const std::vector<float>& TargetWrapperMlu::StdVec() {
RETURN_MLU_INFO(std_vec_);
}
DataLayoutType TargetWrapperMlu::InputLayout() { return input_layout_; } DataLayoutType TargetWrapperMlu::InputLayout() {
RETURN_MLU_INFO(input_layout_);
}
// void TargetWrapperMlu::MemcpyAsync(void* dst, // void TargetWrapperMlu::MemcpyAsync(void* dst,
// const void* src, // const void* src,
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <map>
#include <mutex> // NOLINT
#include <thread> // NOLINT
#include <vector> #include <vector>
#include "lite/backends/mlu/mlu_utils.h" #include "lite/backends/mlu/mlu_utils.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
...@@ -25,6 +28,41 @@ using TargetWrapperMlu = TargetWrapper<TARGET(kMLU)>; ...@@ -25,6 +28,41 @@ using TargetWrapperMlu = TargetWrapper<TARGET(kMLU)>;
template <> template <>
class TargetWrapper<TARGET(kMLU)> { class TargetWrapper<TARGET(kMLU)> {
public: public:
// Per-predictor MLU run configuration, stored in predictor_info_map_.
struct ThreadLocalInfo {
  cnmlCoreVersion_t mlu_core_version_{CNML_MLU270};
  int mlu_core_number_{1};
  bool use_first_conv_{false};
  std::vector<float> mean_vec_;
  std::vector<float> std_vec_;
  DataLayoutType input_layout_{DATALAYOUT(kNCHW)};
  // Default member initializers above fix reads of indeterminate values:
  // the old `ThreadLocalInfo() {}` left the scalar members uninitialized,
  // and RETURN_MLU_INFO can default-construct an entry through
  // std::map::operator[] when a thread queries before registering. The
  // defaults mirror the previous thread_local initial values
  // (CNML_MLU270, 1 core, no first-conv, NCHW layout).
  ThreadLocalInfo() {}
  // Translates the public lite_api core version into the cnml enum and
  // captures the remaining run parameters verbatim.
  ThreadLocalInfo(lite_api::MLUCoreVersion core_version,
                  int core_number,
                  bool use_first_conv,
                  const std::vector<float>& mean_vec,
                  const std::vector<float>& std_vec,
                  DataLayoutType input_layout)
      : mlu_core_number_(core_number),
        use_first_conv_(use_first_conv),
        mean_vec_(mean_vec),
        std_vec_(std_vec),
        input_layout_(input_layout) {
    switch (core_version) {
      case (lite_api::MLUCoreVersion::MLU_220):
        mlu_core_version_ = CNML_MLU220;
        break;
      case (lite_api::MLUCoreVersion::MLU_270):
        mlu_core_version_ = CNML_MLU270;
        break;
      default:
        // Unknown versions fall back to MLU270, matching prior behavior.
        mlu_core_version_ = CNML_MLU270;
        break;
    }
  }
};
using queue_t = cnrtQueue_t; using queue_t = cnrtQueue_t;
static size_t num_devices(); static size_t num_devices();
...@@ -44,12 +82,14 @@ class TargetWrapper<TARGET(kMLU)> { ...@@ -44,12 +82,14 @@ class TargetWrapper<TARGET(kMLU)> {
const void* src, const void* src,
size_t size, size_t size,
IoDirection dir); IoDirection dir);
static void SetMLURunMode(lite_api::MLUCoreVersion core_version, static void SetMLURunMode(int64_t predictor_addr,
lite_api::MLUCoreVersion core_version,
int core_number, int core_number,
bool use_first_conv, bool use_first_conv,
const std::vector<float>& mean_vec, const std::vector<float>& mean_vec,
const std::vector<float>& std_vec, const std::vector<float>& std_vec,
DataLayoutType input_layout); DataLayoutType input_layout);
static void RegisterMLURunningPredictor(int64_t);
static cnmlCoreVersion_t MLUCoreVersion(); static cnmlCoreVersion_t MLUCoreVersion();
static int MLUCoreNumber(); static int MLUCoreNumber();
static bool UseFirstConv(); static bool UseFirstConv();
...@@ -61,13 +101,10 @@ class TargetWrapper<TARGET(kMLU)> { ...@@ -61,13 +101,10 @@ class TargetWrapper<TARGET(kMLU)> {
// size_t size, // size_t size,
// IoDirection dir, // IoDirection dir,
// const queue_t& queue); // const queue_t& queue);
private: public:
static thread_local cnmlCoreVersion_t mlu_core_version_; static std::mutex info_map_mutex_;
static thread_local int mlu_core_number_; static std::map<int64_t, ThreadLocalInfo> predictor_info_map_;
static thread_local bool use_first_conv_; static std::map<std::thread::id, int64_t> thread_predictor_map_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
static thread_local DataLayoutType input_layout_;
}; };
} // namespace lite } // namespace lite
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/kernels/mlu/bridges/test_helper.h" #include "lite/kernels/mlu/bridges/test_helper.h"
#include <thread> // NOLINT
#include <utility> #include <utility>
#include "lite/core/device_info.h" #include "lite/core/device_info.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -53,6 +54,17 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op, ...@@ -53,6 +54,17 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names, const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names, const std::vector<std::string>& output_var_names,
cnmlDataOrder_t order) { cnmlDataOrder_t order) {
// add initialization of some global variables which maybe used in converter
// such as conv converter
lite::TargetWrapperMlu::SetMLURunMode(0, // just a fake address of predictor
lite_api::MLUCoreVersion::MLU_270,
1,
false, // used in conv converter
{},
{},
DATALAYOUT(kNCHW));
lite::TargetWrapperMlu::RegisterMLURunningPredictor(0);
CNRT_CALL(cnrtInit(0)); CNRT_CALL(cnrtInit(0));
lite::SetMluDevice(0); lite::SetMluDevice(0);
cnrtQueue_t queue_; cnrtQueue_t queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册