Commit e0a5364f authored by J jiaopu

move MLU run-mode state from device_info to backends/mlu

Parent 067111d4
......@@ -36,12 +36,12 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif
#ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init();
lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
config.mlu_first_conv_std(),
config.mlu_input_layout());
lite::TargetWrapperMlu::SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
config.mlu_first_conv_std(),
config.mlu_input_layout());
#endif // LITE_WITH_MLU
auto places = config.valid_places();
std::vector<std::string> passes{};
......
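For orientation: the hunk above retargets the Init() call site from the DeviceInfo singleton to the new static TargetWrapperMlu API, taking its arguments straight from the CxxConfig getters. A minimal sketch of the user-facing flow, with the caveat that the setter names are assumed to mirror the getters used above (they are not shown in this diff):

```cpp
#include "lite/api/paddle_api.h"  // CxxConfig, MLUCoreVersion

int main() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("model_dir");  // hypothetical model path
  // Setter names below are assumed to mirror the getters in the hunk above.
  config.set_mlu_core_version(paddle::lite_api::MLUCoreVersion::MLU_270);
  config.set_mlu_core_number(4);
  config.set_mlu_use_first_conv(true);
  config.set_mlu_first_conv_mean({104.f, 117.f, 123.f});
  config.set_mlu_first_conv_std({57.f, 57.f, 58.f});
  config.set_mlu_input_layout(DATALAYOUT(kNHWC));
  // CxxPaddleApiImpl::Init() forwards all six values to
  // TargetWrapperMlu::SetMLURunMode(), as the hunk above shows.
  auto predictor =
      paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::CxxConfig>(
          config);
  return 0;
}
```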
......@@ -36,6 +36,13 @@ void cnrtMemcpyDtoH(void* dst, const void* src, size_t size) {
} // namespace mlu
thread_local cnmlCoreVersion_t TargetWrapperMlu::mlu_core_version_{CNML_MLU270};
thread_local int TargetWrapperMlu::mlu_core_number_{1};
thread_local bool TargetWrapperMlu::use_first_conv_{false};
thread_local std::vector<float> TargetWrapperMlu::mean_vec_;
thread_local std::vector<float> TargetWrapperMlu::std_vec_;
thread_local DataLayoutType TargetWrapperMlu::input_layout_{DATALAYOUT(kNCHW)};
size_t TargetWrapperMlu::num_devices() {
uint32_t dev_count = 0;
CNRT_CALL(cnrtGetDeviceCount(&dev_count)) << " cnrt get device count failed";
......@@ -77,6 +84,47 @@ void TargetWrapperMlu::MemcpySync(void* dst,
LOG(FATAL) << "Unsupported IoDirection" << static_cast<int>(dir);
}
}
void TargetWrapperMlu::SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec,
DataLayoutType input_layout) {
switch (core_version) {
case (lite_api::MLUCoreVersion::MLU_220):
mlu_core_version_ = CNML_MLU220;
break;
case (lite_api::MLUCoreVersion::MLU_270):
mlu_core_version_ = CNML_MLU270;
break;
default:
mlu_core_version_ = CNML_MLU270;
break;
}
mlu_core_number_ = core_number;
use_first_conv_ = use_first_conv;
mean_vec_ = mean_vec;
std_vec_ = std_vec;
input_layout_ = input_layout;
}
cnmlCoreVersion_t TargetWrapperMlu::MLUCoreVersion() {
return mlu_core_version_;
}
int TargetWrapperMlu::MLUCoreNumber() { return mlu_core_number_; }
bool TargetWrapperMlu::UseFirstConv() { return use_first_conv_; }
// const std::vector<float>& TargetWrapperMlu::MeanVec() const { return
// mean_vec_; }
const std::vector<float>& TargetWrapperMlu::MeanVec() { return mean_vec_; }
// const std::vector<float>& TargetWrapperMlu::StdVec() const { return std_vec_;
// }
const std::vector<float>& TargetWrapperMlu::StdVec() { return std_vec_; }
DataLayoutType TargetWrapperMlu::InputLayout() { return input_layout_; }
// void TargetWrapperMlu::MemcpyAsync(void* dst,
// const void* src,
......
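Note that every piece of state introduced above is thread_local, so SetMLURunMode configures only the calling thread; any other thread sees the defaults (CNML_MLU270, one core, NCHW). A minimal sketch of the consequence, assuming the include path of the header shown in the next hunk:

```cpp
#include <thread>
#include "lite/backends/mlu/target_wrapper.h"

void ConfigureMluForThisThread() {
  using paddle::lite::TargetWrapperMlu;
  // thread_local state: each thread starts from the defaults and must be
  // configured explicitly; nothing is inherited from the spawning thread.
  TargetWrapperMlu::SetMLURunMode(paddle::lite_api::MLUCoreVersion::MLU_270,
                                  /*core_number=*/4,
                                  /*use_first_conv=*/false,
                                  /*mean_vec=*/{},
                                  /*std_vec=*/{},
                                  DATALAYOUT(kNCHW));
}

int main() {
  ConfigureMluForThisThread();  // affects the main thread only
  std::thread worker([] {
    ConfigureMluForThisThread();  // the worker must repeat the setup
  });
  worker.join();
  return 0;
}
```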
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/mlu/mlu_utils.h"
#include "lite/core/target_wrapper.h"
......@@ -43,11 +44,32 @@ class TargetWrapper<TARGET(kMLU)> {
const void* src,
size_t size,
IoDirection dir);
static void SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec,
DataLayoutType input_layout);
static cnmlCoreVersion_t MLUCoreVersion();
static int MLUCoreNumber();
static bool UseFirstConv();
// static const std::vector<float>& MeanVec() const;
// static const std::vector<float>& StdVec() const;
static const std::vector<float>& MeanVec();
static const std::vector<float>& StdVec();
static DataLayoutType InputLayout();
// static void MemcpyAsync(void* dst,
// const void* src,
// size_t size,
// IoDirection dir,
// const queue_t& queue);
private:
static thread_local cnmlCoreVersion_t mlu_core_version_;
static thread_local int mlu_core_number_;
static thread_local bool use_first_conv_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
static thread_local DataLayoutType input_layout_;
};
} // namespace lite
......
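The declarations above sit in the TargetWrapper<TARGET(kMLU)> specialization, while every call site in this commit writes TargetWrapperMlu. An alias presumably exists elsewhere in the tree (it is not part of this diff); a sketch of what it would look like:

```cpp
namespace paddle {
namespace lite {
// Assumed alias so call sites can write TargetWrapperMlu::MLUCoreVersion()
// instead of TargetWrapper<TARGET(kMLU)>::MLUCoreVersion().
using TargetWrapperMlu = TargetWrapper<TARGET(kMLU)>;
}  // namespace lite
}  // namespace paddle
```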
......@@ -227,10 +227,12 @@ class Context<TargetType::kMLU> {
void SetIoQueue(cnrtQueue_t queue) { io_queue_ = queue; }
cnmlCoreVersion_t MLUCoreVersion() {
return DeviceInfo::Global().MLUCoreVersion();
return paddle::lite::TargetWrapperMlu::MLUCoreVersion();
}
int MLUCoreNumber() { return DeviceInfo::Global().MLUCoreNumber(); }
int MLUCoreNumber() {
return paddle::lite::TargetWrapperMlu::MLUCoreNumber();
}
u32_t affinity() { return affinity_; }
......
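After this hunk, Context<TargetType::kMLU> is a thin forwarder: kernels keep querying the context, and the context reads the thread_local state from TargetWrapperMlu. A hedged sketch of a kernel-side read (the MLUContext alias and the ctx->As<>() accessor follow Lite's usual context pattern but are assumptions here):

```cpp
// Hypothetical kernel snippet: the kernel never touches TargetWrapperMlu
// directly; it goes through its context, which forwards as shown above.
void SketchPrepare(paddle::lite::KernelContext* ctx) {
  auto& mlu_ctx = ctx->As<MLUContext>();  // assumed alias for Context<kMLU>
  cnmlCoreVersion_t ver = mlu_ctx.MLUCoreVersion();  // -> TargetWrapperMlu
  int cores = mlu_ctx.MLUCoreNumber();               // -> TargetWrapperMlu
  (void)ver;
  (void)cores;
}
```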
......@@ -66,14 +66,14 @@ thread_local std::vector<int> DeviceInfo::active_ids_;
thread_local TensorLite DeviceInfo::workspace_;
thread_local int64_t DeviceInfo::count_ = 0;
#ifdef LITE_WITH_MLU
thread_local cnmlCoreVersion_t DeviceInfo::mlu_core_version_{CNML_MLU270};
thread_local int DeviceInfo::mlu_core_number_{1};
thread_local bool DeviceInfo::use_first_conv_{false};
thread_local std::vector<float> DeviceInfo::mean_vec_;
thread_local std::vector<float> DeviceInfo::std_vec_;
thread_local DataLayoutType DeviceInfo::input_layout_{DATALAYOUT(kNCHW)};
#endif
// #ifdef LITE_WITH_MLU
// thread_local cnmlCoreVersion_t DeviceInfo::mlu_core_version_{CNML_MLU270};
// thread_local int DeviceInfo::mlu_core_number_{1};
// thread_local bool DeviceInfo::use_first_conv_{false};
// thread_local std::vector<float> DeviceInfo::mean_vec_;
// thread_local std::vector<float> DeviceInfo::std_vec_;
// thread_local DataLayoutType DeviceInfo::input_layout_{DATALAYOUT(kNCHW)};
// #endif
#ifdef TARGET_IOS
const int DEFAULT_L1_CACHE_SIZE = 64 * 1024;
......@@ -1089,44 +1089,44 @@ int DeviceInfo::Setup() {
return 0;
}
#ifdef LITE_WITH_MLU
void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec,
DataLayoutType input_layout) {
switch (core_version) {
case (lite_api::MLUCoreVersion::MLU_220):
mlu_core_version_ = CNML_MLU220;
break;
case (lite_api::MLUCoreVersion::MLU_270):
mlu_core_version_ = CNML_MLU270;
break;
default:
mlu_core_version_ = CNML_MLU270;
break;
}
mlu_core_number_ = core_number;
use_first_conv_ = use_first_conv;
mean_vec_ = mean_vec;
std_vec_ = std_vec;
input_layout_ = input_layout;
}
cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }
int DeviceInfo::MLUCoreNumber() { return mlu_core_number_; }
bool DeviceInfo::UseFirstConv() { return use_first_conv_; }
const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
DataLayoutType DeviceInfo::InputLayout() const { return input_layout_; }
#endif // LITE_WITH_MLU
// #ifdef LITE_WITH_MLU
// void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
// int core_number,
// bool use_first_conv,
// const std::vector<float>& mean_vec,
// const std::vector<float>& std_vec,
// DataLayoutType input_layout) {
// switch (core_version) {
// case (lite_api::MLUCoreVersion::MLU_220):
// mlu_core_version_ = CNML_MLU220;
// break;
// case (lite_api::MLUCoreVersion::MLU_270):
// mlu_core_version_ = CNML_MLU270;
// break;
// default:
// mlu_core_version_ = CNML_MLU270;
// break;
// }
// mlu_core_number_ = core_number;
// use_first_conv_ = use_first_conv;
// mean_vec_ = mean_vec;
// std_vec_ = std_vec;
// input_layout_ = input_layout;
// }
//
// cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }
//
// int DeviceInfo::MLUCoreNumber() { return mlu_core_number_; }
//
// bool DeviceInfo::UseFirstConv() { return use_first_conv_; }
//
// const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
//
// const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
//
// DataLayoutType DeviceInfo::InputLayout() const { return input_layout_; }
//
// #endif // LITE_WITH_MLU
void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) {
#ifdef ARM_WITH_OMP
......
......@@ -55,20 +55,20 @@ class DeviceInfo {
int Setup();
void SetRunMode(lite_api::PowerMode mode, int thread_num);
#ifdef LITE_WITH_MLU
void SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec,
DataLayoutType input_layout);
cnmlCoreVersion_t MLUCoreVersion();
int MLUCoreNumber();
bool UseFirstConv();
const std::vector<float>& MeanVec() const;
const std::vector<float>& StdVec() const;
DataLayoutType InputLayout() const;
#endif
// #ifdef LITE_WITH_MLU
// void SetMLURunMode(lite_api::MLUCoreVersion core_version,
// int core_number,
// bool use_first_conv,
// const std::vector<float>& mean_vec,
// const std::vector<float>& std_vec,
// DataLayoutType input_layout);
// cnmlCoreVersion_t MLUCoreVersion();
// int MLUCoreNumber();
// bool UseFirstConv();
// const std::vector<float>& MeanVec() const;
// const std::vector<float>& StdVec() const;
// DataLayoutType InputLayout() const;
// #endif
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch) { arch_ = arch; }
......@@ -120,14 +120,14 @@ class DeviceInfo {
static thread_local TensorLite workspace_;
static thread_local int64_t count_;
#ifdef LITE_WITH_MLU
static thread_local cnmlCoreVersion_t mlu_core_version_;
static thread_local int mlu_core_number_;
static thread_local bool use_first_conv_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
static thread_local DataLayoutType input_layout_;
#endif
// #ifdef LITE_WITH_MLU
// static thread_local cnmlCoreVersion_t mlu_core_version_;
// static thread_local int mlu_core_number_;
// static thread_local bool use_first_conv_;
// static thread_local std::vector<float> mean_vec_;
// static thread_local std::vector<float> std_vec_;
// static thread_local DataLayoutType input_layout_;
// #endif
void SetDotInfo(int argc, ...);
void SetFP16Info(int argc, ...);
......
......@@ -569,11 +569,11 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
// arg_in and arg_out are assumed to be NHWC which user should be aware of.
// Thus here we change these args' layout to NHWC
if (lite::DeviceInfo::Global().InputLayout() == DATALAYOUT(kNHWC)) {
if (lite::TargetWrapperMlu::InputLayout() == DATALAYOUT(kNHWC)) {
ModifyLayout(graph.get());
}
if (lite::DeviceInfo::Global().UseFirstConv()) {
if (lite::TargetWrapperMlu::UseFirstConv()) {
GatherAndModifyFirstConvNodes(graph.get());
}
......
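The pass reads the same thread_local state: an NHWC input layout triggers ModifyLayout() on the feed/fetch arguments, and first-conv fusion triggers GatherAndModifyFirstConvNodes(). Per the comment above, when NHWC is selected the user is expected to hand data to the feed op already in NHWC order; a sketch under that reading, using the standard predictor tensor handles (shape values are illustrative):

```cpp
// Feeding a nominally 1x3x224x224 model when mlu_input_layout is kNHWC:
auto input = predictor->GetInput(0);
input->Resize({1, 224, 224, 3});             // NHWC shape, not NCHW
float* data = input->mutable_data<float>();  // fill in H-major, C-minor order
```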
......@@ -164,7 +164,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
const auto input_scale = op_info->GetAttr<float>("input_scale");
bool use_first_conv = false;
if (lite::DeviceInfo::Global().UseFirstConv() && input_dims[1] == 3) {
if (lite::TargetWrapperMlu::UseFirstConv() && input_dims[1] == 3) {
use_first_conv = true;
}
......@@ -192,11 +192,11 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
graph->FPType());
graph->BindConstRawData("first_conv_mean_tensor",
lite::DeviceInfo::Global().MeanVec().data(),
lite::TargetWrapperMlu::MeanVec().data(),
3,
false);
graph->BindConstRawData("first_conv_std_tensor",
lite::DeviceInfo::Global().StdVec().data(),
lite::TargetWrapperMlu::StdVec().data(),
3,
false);
......
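Context for the two BindConstRawData calls: each binds exactly three floats, matching the input_dims[1] == 3 guard, because first-conv fusion folds the per-channel normalization (x - mean[c]) / std[c] into the first convolution so raw image data can be fed directly. A small self-contained sketch of that arithmetic (values illustrative):

```cpp
#include <array>
#include <cstdio>

int main() {
  // Per-channel constants, as they would be handed to the first-conv
  // mean/std config options (three entries, one per input channel).
  const std::array<float, 3> mean{104.f, 117.f, 123.f};
  const std::array<float, 3> stdv{57.f, 57.f, 58.f};
  const float raw_pixel = 200.f;  // an unnormalized image value
  for (int c = 0; c < 3; ++c) {
    // What the fused first conv effectively computes before convolving:
    std::printf("channel %d -> %f\n", c, (raw_pixel - mean[c]) / stdv[c]);
  }
  return 0;
}
```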