未验证 提交 53544680 编写于 作者: J jackzhang235 提交者: GitHub

Modify MLU's TargetWrapper to support creating and running a predictor in different threads (#112)

Modify MLU's TargetWrapper to handle the situation where a predictor is created and run in different threads.
上级 a2547758
......@@ -52,7 +52,8 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif
#ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init();
lite::TargetWrapperMlu::SetMLURunMode(config.mlu_core_version(),
lite::TargetWrapperMlu::SetMLURunMode(reinterpret_cast<int64_t>(this),
config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
......@@ -102,6 +103,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
}
void CxxPaddleApiImpl::Run() {
#ifdef LITE_WITH_MLU
lite::TargetWrapperMlu::RegisterMLURunningPredictor(
reinterpret_cast<int64_t>(this));
#endif
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif
......
......@@ -36,12 +36,10 @@ void cnrtMemcpyDtoH(void* dst, const void* src, size_t size) {
} // namespace mlu
// NOTE(review): diff residue -- both sides of the change are shown here.
// Removed by this commit: per-thread configuration held in thread_local
// statics (one independent copy per thread, so a predictor configured in
// one thread was invisible to another).
thread_local cnmlCoreVersion_t TargetWrapperMlu::mlu_core_version_{CNML_MLU270};
thread_local int TargetWrapperMlu::mlu_core_number_{1};
thread_local bool TargetWrapperMlu::use_first_conv_{false};
thread_local std::vector<float> TargetWrapperMlu::mean_vec_;
thread_local std::vector<float> TargetWrapperMlu::std_vec_;
thread_local DataLayoutType TargetWrapperMlu::input_layout_{DATALAYOUT(kNCHW)};
// Added by this commit: process-wide maps -- per-predictor configuration
// keyed by the predictor's address, a thread -> predictor binding, and the
// mutex guarding both.
std::map<int64_t, TargetWrapperMlu::ThreadLocalInfo>
    TargetWrapperMlu::predictor_info_map_;
std::map<std::thread::id, int64_t> TargetWrapperMlu::thread_predictor_map_;
std::mutex TargetWrapperMlu::info_map_mutex_;
size_t TargetWrapperMlu::num_devices() {
uint32_t dev_count = 0;
......@@ -84,43 +82,65 @@ void TargetWrapperMlu::MemcpySync(void* dst,
LOG(FATAL) << "Unsupported IoDirection" << static_cast<int>(dir);
}
}
// NOTE(review): this span is unified-diff residue -- it contains BOTH the
// pre-change and post-change versions of SetMLURunMode interleaved; the
// lines are annotated below rather than rewritten.
// Removed signature (configuration applied to the calling thread only):
void TargetWrapperMlu::SetMLURunMode(lite_api::MLUCoreVersion core_version,
// Added signature: the predictor's address keys the stored configuration.
void TargetWrapperMlu::SetMLURunMode(int64_t predictor_addr,
                                     lite_api::MLUCoreVersion core_version,
                                     int core_number,
                                     bool use_first_conv,
                                     const std::vector<float>& mean_vec,
                                     const std::vector<float>& std_vec,
                                     DataLayoutType input_layout) {
  // Removed body: translated the public API enum to the CNML enum
  // (unknown values fell back to MLU270) and wrote thread_local statics.
  switch (core_version) {
    case (lite_api::MLUCoreVersion::MLU_220):
      mlu_core_version_ = CNML_MLU220;
      break;
    case (lite_api::MLUCoreVersion::MLU_270):
      mlu_core_version_ = CNML_MLU270;
      break;
    default:
      mlu_core_version_ = CNML_MLU270;
      break;
  }
  mlu_core_number_ = core_number;
  use_first_conv_ = use_first_conv;
  mean_vec_ = mean_vec;
  std_vec_ = std_vec;
  input_layout_ = input_layout;
  // Added body: bundle the settings into a ThreadLocalInfo, store it in
  // predictor_info_map_ under the predictor's address, and bind the current
  // thread to that predictor -- all under info_map_mutex_.
  ThreadLocalInfo info = ThreadLocalInfo(core_version,
                                         core_number,
                                         use_first_conv,
                                         mean_vec,
                                         std_vec,
                                         input_layout);
  std::lock_guard<std::mutex> lock(info_map_mutex_);
  predictor_info_map_[predictor_addr] = info;
  VLOG(6) << "predictor_info_map_ add key: " << predictor_addr;
  thread_predictor_map_[std::this_thread::get_id()] = predictor_addr;
  VLOG(6) << "thread_predictor_map_ add key: " << std::this_thread::get_id()
          << ", add value: " << predictor_addr;
}
// Binds the calling thread to the predictor at `predictor_addr`, so later
// per-predictor config lookups (RETURN_MLU_INFO) resolve for this thread.
// Thread-safe: the map update happens under info_map_mutex_.
void TargetWrapperMlu::RegisterMLURunningPredictor(int64_t predictor_addr) {
  const auto tid = std::this_thread::get_id();
  std::lock_guard<std::mutex> guard(info_map_mutex_);
  thread_predictor_map_[tid] = predictor_addr;
  VLOG(6) << "thread_predictor_map_ add key: " << tid
          << ", add value: " << predictor_addr;
}
// Returns field `x` of the ThreadLocalInfo registered for the calling
// thread's predictor, taking info_map_mutex_ for the duration of the lookup.
// NOTE(review): std::map::operator[] default-constructs missing entries, so
// a thread that never called SetMLURunMode/RegisterMLURunningPredictor
// silently maps to key 0 and reads a default-constructed ThreadLocalInfo --
// verify that every running thread registers first.
// NOTE(review): for accessors that return references (MeanVec/StdVec) the
// reference escapes the lock_guard's critical section; confirm no concurrent
// map mutation can invalidate it.
// (Comments are kept outside the macro: a `//` comment on a
// backslash-continued line would swallow the continuation.)
#define RETURN_MLU_INFO(x)                                        \
  do {                                                            \
    std::lock_guard<std::mutex> lock(info_map_mutex_);            \
    VLOG(6) << "call from thread: " << std::this_thread::get_id() \
            << ", predictor key: "                                \
            << thread_predictor_map_[std::this_thread::get_id()]; \
    return predictor_info_map_                                    \
        [thread_predictor_map_[std::this_thread::get_id()]]       \
            .x;                                                   \
  } while (0)
// NOTE(review): diff residue -- each accessor below appears twice: the
// removed thread_local-returning version and the added RETURN_MLU_INFO
// (per-predictor map lookup) version.
cnmlCoreVersion_t TargetWrapperMlu::MLUCoreVersion() {
  return mlu_core_version_;            // removed: thread_local read
  RETURN_MLU_INFO(mlu_core_version_);  // added: per-predictor lookup
}
// removed:
int TargetWrapperMlu::MLUCoreNumber() { return mlu_core_number_; }
// added:
int TargetWrapperMlu::MLUCoreNumber() { RETURN_MLU_INFO(mlu_core_number_); }
// removed:
bool TargetWrapperMlu::UseFirstConv() { return use_first_conv_; }
// added:
bool TargetWrapperMlu::UseFirstConv() { RETURN_MLU_INFO(use_first_conv_); }
// removed:
const std::vector<float>& TargetWrapperMlu::MeanVec() { return mean_vec_; }
// added: NOTE(review) the returned reference outlives the lock taken inside
// RETURN_MLU_INFO -- confirm registration/queries never race (see macro).
const std::vector<float>& TargetWrapperMlu::MeanVec() {
  RETURN_MLU_INFO(mean_vec_);
}
// removed:
const std::vector<float>& TargetWrapperMlu::StdVec() { return std_vec_; }
// added: same escaping-reference caveat as MeanVec.
const std::vector<float>& TargetWrapperMlu::StdVec() {
  RETURN_MLU_INFO(std_vec_);
}
// removed:
DataLayoutType TargetWrapperMlu::InputLayout() { return input_layout_; }
// added:
DataLayoutType TargetWrapperMlu::InputLayout() {
  RETURN_MLU_INFO(input_layout_);
}
// void TargetWrapperMlu::MemcpyAsync(void* dst,
// const void* src,
......
......@@ -13,6 +13,9 @@
// limitations under the License.
#pragma once
#include <map>
#include <mutex> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include "lite/backends/mlu/mlu_utils.h"
#include "lite/core/target_wrapper.h"
......@@ -25,6 +28,41 @@ using TargetWrapperMlu = TargetWrapper<TARGET(kMLU)>;
template <>
class TargetWrapper<TARGET(kMLU)> {
public:
// Per-predictor MLU runtime configuration, stored in
// TargetWrapperMlu::predictor_info_map_ and looked up for the calling
// thread via thread_predictor_map_.
struct ThreadLocalInfo {
  // Defaults mirror the previous thread_local globals (MLU270, 1 core, no
  // first-conv fusion, NCHW layout). Without them, the user-declared default
  // constructor left the scalar members uninitialized, and a map lookup via
  // operator[] on an unregistered key (see RETURN_MLU_INFO) would read
  // indeterminate values.
  cnmlCoreVersion_t mlu_core_version_{CNML_MLU270};
  int mlu_core_number_{1};
  bool use_first_conv_{false};
  std::vector<float> mean_vec_;
  std::vector<float> std_vec_;
  DataLayoutType input_layout_{DATALAYOUT(kNCHW)};

  ThreadLocalInfo() {}
  // Builds the info from the public-API core-version enum, translating it to
  // the CNML enum; unknown values fall back to MLU270.
  ThreadLocalInfo(lite_api::MLUCoreVersion core_version,
                  int core_number,
                  bool use_first_conv,
                  const std::vector<float>& mean_vec,
                  const std::vector<float>& std_vec,
                  DataLayoutType input_layout)
      : mlu_core_number_(core_number),
        use_first_conv_(use_first_conv),
        mean_vec_(mean_vec),
        std_vec_(std_vec),
        input_layout_(input_layout) {
    switch (core_version) {
      case (lite_api::MLUCoreVersion::MLU_220):
        mlu_core_version_ = CNML_MLU220;
        break;
      case (lite_api::MLUCoreVersion::MLU_270):
        mlu_core_version_ = CNML_MLU270;
        break;
      default:
        mlu_core_version_ = CNML_MLU270;
        break;
    }
  }
};
using queue_t = cnrtQueue_t;
static size_t num_devices();
......@@ -44,12 +82,14 @@ class TargetWrapper<TARGET(kMLU)> {
const void* src,
size_t size,
IoDirection dir);
static void SetMLURunMode(lite_api::MLUCoreVersion core_version,
static void SetMLURunMode(int64_t predictor_addr,
lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec,
DataLayoutType input_layout);
static void RegisterMLURunningPredictor(int64_t);
static cnmlCoreVersion_t MLUCoreVersion();
static int MLUCoreNumber();
static bool UseFirstConv();
......@@ -61,13 +101,10 @@ class TargetWrapper<TARGET(kMLU)> {
// size_t size,
// IoDirection dir,
// const queue_t& queue);
private:
static thread_local cnmlCoreVersion_t mlu_core_version_;
static thread_local int mlu_core_number_;
static thread_local bool use_first_conv_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
static thread_local DataLayoutType input_layout_;
public:
static std::mutex info_map_mutex_;
static std::map<int64_t, ThreadLocalInfo> predictor_info_map_;
static std::map<std::thread::id, int64_t> thread_predictor_map_;
};
} // namespace lite
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/mlu/bridges/test_helper.h"
#include <thread> // NOLINT
#include <utility>
#include "lite/core/device_info.h"
#include "lite/core/op_registry.h"
......@@ -53,6 +54,17 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names,
cnmlDataOrder_t order) {
// add initialization of some global variables which maybe used in converter
// such as conv converter
lite::TargetWrapperMlu::SetMLURunMode(0, // just a fake address of predictor
lite_api::MLUCoreVersion::MLU_270,
1,
false, // used in conv converter
{},
{},
DATALAYOUT(kNCHW));
lite::TargetWrapperMlu::RegisterMLURunningPredictor(0);
CNRT_CALL(cnrtInit(0));
lite::SetMluDevice(0);
cnrtQueue_t queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册