提交 12a1a095 编写于 作者: J jackzhang235

add set_input_layout

上级 97c2d205
......@@ -36,16 +36,12 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif
#ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init();
mlu_core_version_ = config.mlu_core_version();
mlu_core_number_ = config.mlu_core_number();
use_first_conv_ = config.use_first_conv();
mean_vec_ = config.mean();
std_vec_ = config.std();
lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_,
mlu_core_number_,
use_first_conv_,
mean_vec_,
std_vec_);
lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
config.mlu_first_conv_std(),
config.mlu_input_layout());
#endif // LITE_WITH_MLU
auto places = config.valid_places();
std::vector<std::string> passes{};
......
......@@ -203,6 +203,38 @@ void ConfigBase::set_threads(int threads) {
#endif
}
void CxxConfig::mlu_set_mlu_core_version(
lite_api::MLUCoreVersion core_version) {
mlu_core_version_ = core_version;
}
void CxxConfig::mlu_set_mlu_core_number(int core_number) {
mlu_core_number_ = core_number;
}
void CxxConfig::mlu_set_input_layout()(DataLayoutType layout) {
mlu_input_layout_ = layout;
}
void CxxConfig::mlu_set_use_first_conv(bool use_first_conv) {
mlu_use_first_conv_ = use_first_conv;
}
void CxxConfig::mlu_set_first_conv_mean(const std::vector<float> &mean) {
mlu_first_conv_mean_ = mean;
}
void CxxConfig::mlu_set_first_conv_std(const std::vector<float> &std) {
mlu_first_conv_std_ = std;
}
lite_api::MLUCoreVersion CxxConfig::mlu_core_version() const {
return mlu_core_version_;
}
int CxxConfig::mlu_core_number() const { return mlu_core_number_; }
DataLayoutType CxxConfig::mlu_input_layout() const { return mlu_input_layout_; }
bool CxxConfig::mlu_use_first_conv() const { return mlu_use_first_conv_; }
std::vector<float> CxxConfig::mlu_first_conv_mean() const {
return mlu_first_conv_mean_;
}
std::vector<float> CxxConfig::mlu_first_conv_std() const {
return mlu_first_conv_std_;
}
// set model data in combined format, `set_model_from_file` refers to loading
// model from file, set_model_from_buffer refers to loading model from memory
// buffer
......
......@@ -106,11 +106,6 @@ class LITE_API PaddlePredictor {
protected:
int threads_{1};
lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND};
lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLU_270};
int mlu_core_number_{1};
bool use_first_conv_{false};
std::vector<float> mean_vec_;
std::vector<float> std_vec_;
};
/// Base class for all the configs.
......@@ -141,11 +136,13 @@ class LITE_API CxxConfig : public ConfigBase {
#ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1;
#endif
bool use_firstconv_{false};
std::vector<float> mean_ = {0.0f};
std::vector<float> std_ = {1.0f};
lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLUCoreVersion::MLU_270};
int mlu_core_number_{1};
DataLayoutType mlu_input_layout_{DATALAYOUT(kNCHW)};
bool mlu_use_first_conv_{false};
std::vector<float> mlu_first_conv_mean_;
std::vector<float> mlu_first_conv_std_;
public:
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
......@@ -173,20 +170,20 @@ class LITE_API CxxConfig : public ConfigBase {
return x86_math_library_math_threads_;
}
#endif
void set_use_firstconv(const bool firstconv) { use_firstconv_ = firstconv; }
void set_mean(const std::vector<float> mean) { mean_ = mean; }
void set_std(const std::vector<float> std) { std_ = std; }
void set_mlu_core_version(lite_api::MLUCoreVersion core_version) {
mlu_core_version_ = core_version;
}
void set_mlu_core_number(int core_number) { mlu_core_number_ = core_number; }
bool use_first_conv() const { return use_firstconv_; }
std::vector<float> mean() const { return mean_; }
std::vector<float> std() const { return std_; }
lite_api::MLUCoreVersion mlu_core_version() const {
return mlu_core_version_;
}
int mlu_core_number() const { return mlu_core_number_; }
void mlu_set_mlu_core_version(lite_api::MLUCoreVersion core_version);
void mlu_set_mlu_core_number(int core_number);
void mlu_set_input_layout()(DataLayoutType layout);
void mlu_set_use_first_conv(bool use_first_conv);
void mlu_set_first_conv_mean(const std::vector<float>& mean);
void mlu_set_first_conv_std(const std::vector<float>& std);
lite_api::MLUCoreVersion mlu_core_version() const;
int mlu_core_number() const;
DataLayoutType mlu_input_layout() const;
bool mlu_use_first_conv() const;
std::vector<float> mlu_first_conv_mean() const;
std::vector<float> mlu_first_conv_std() const;
};
/// MobileConfig is the config for the light weight predictor, it will skip
......
......@@ -1093,7 +1093,8 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec) {
const std::vector<float>& std_vec,
DataLayoutType input_layout) {
switch (core_version) {
case (lite_api::MLUCoreVersion::MLU_220):
mlu_core_version_ = CNML_MLU220;
......@@ -1109,6 +1110,7 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
use_first_conv_ = use_first_conv;
mean_vec_ = mean_vec;
std_vec_ = std_vec;
input_layout_ = input_layout;
}
cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }
......@@ -1121,6 +1123,8 @@ const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
const DataLayoutType InputLayout() const { return input_layout_; }
#endif // LITE_WITH_MLU
void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) {
......
......@@ -60,12 +60,14 @@ class DeviceInfo {
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec);
const std::vector<float>& std_vec,
DataLayoutType input_layout);
cnmlCoreVersion_t MLUCoreVersion();
int MLUCoreNumber();
bool UseFirstConv();
const std::vector<float>& MeanVec() const;
const std::vector<float>& StdVec() const;
const DataLayoutType InputLayout() const;
#endif
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch) { arch_ = arch; }
......@@ -124,6 +126,7 @@ class DeviceInfo {
static thread_local bool use_first_conv_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
static thread_local DataLayoutType input_layout_;
#endif
void SetDotInfo(int argc, ...);
......
......@@ -539,7 +539,9 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
// arg_in and arg_out are assumed to be NHWC which user should be aware of.
// Thus here we change these args' layout to NHWC
ModifyLayout(graph.get());
if (lite::DeviceInfo::Global().InputLayout() == DATALAYOUT(kNHWC) {
ModifyLayout(graph.get());
}
if (lite::DeviceInfo::Global().UseFirstConv()) {
GatherFirstConvNodes(graph.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册