Unverified commit 71b2ed61, authored by 石晓伟, committed by GitHub

support MLU nums, test=develop (#19372)

Parent: e2c6bada
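In brief: this change threads an explicit batch size from PaddlePredictor::Run through Predict() and RunImpl() so the MLU backend can pass it to fusion_prediction(), moves the executor deletion into InitNet(), skips the host-side reshape checks under ANAKIN_MLU_PLACE, and enables batch-changeable and channel-duplicate modes on the MLU context. The sketch below shows how a caller might exercise the updated interface; it is a hypothetical example, and the header path, the CreatePaddlePredictor factory, the paddle::contrib namespace placement, and the PaddleBuf::Reset call are assumptions for illustration rather than part of this diff.

// Hypothetical caller of the interface touched by this patch (a sketch,
// not the patch itself). Assumed: header name, CreatePaddlePredictor
// factory, contrib namespace, PaddleBuf::Reset. Taken from the diff: the
// batch_size argument to Run(), config_.device_id, config_.re_allocable,
// the FLOAT32-only input check, and the requirement that output tensors
// are named before Run() is called.
#include <vector>
#include "paddle_inference_api.h"  // assumed public header name

int main() {
  paddle::contrib::AnakinConfig config;  // AnakinConfig appears in the header diff
  config.device_id = 0;                  // read as config_.device_id in RunImpl
  config.re_allocable = false;           // selects the batched path in Run()

  // Assumed factory; the diff only shows the predictor class itself.
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::contrib::AnakinConfig>(config);

  std::vector<float> buf(4 * 3 * 224 * 224, 0.f);
  paddle::PaddleTensor input;
  input.name = "data";
  input.shape = {4, 3, 224, 224};
  input.dtype = paddle::PaddleDType::FLOAT32;  // RunImpl rejects other dtypes
  input.data.Reset(buf.data(), buf.size() * sizeof(float));  // assumed PaddleBuf API

  std::vector<paddle::PaddleTensor> outputs(1);
  outputs[0].name = "prob";  // hypothetical output name; must match a graph output

  // After this patch, batch_size is forwarded down to Predict() and, on MLU,
  // to fusion_prediction().
  predictor->Run({input}, &outputs, /*batch_size=*/4);
  return 0;
}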
@@ -70,9 +70,9 @@ cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_
if(ANAKIN_FOUND)
# Do not turn warnings into errors.
set_source_files_properties(api.cc api_anakin_engine.cc PROPERTIES COMPILE_FLAGS "-Wno-error")
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS boost xxhash)
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS boost xxhash framework_proto eigen3)
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS boost xxhash)
cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS boost xxhash framework_proto eigen3)
target_link_libraries(inference_anakin_api_shared anakin anakin_saber_common)
function(anakin_target target_name)
target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
@@ -42,6 +42,7 @@ void PaddleInferenceAnakinPredictor<T, P, R>::InitEnv() {
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitNet() {
std::unique_lock<std::mutex> lock(this->mutex_);
delete this->executor_p_;
this->executor_p_ = new anakin::Net<T, P, R>(*this->graph_p_, true);
}
template <typename T, Precision P, OpRunType R>
@@ -89,7 +90,7 @@ void PaddleInferenceAnakinPredictor<T, P, R>::InitPredictor() {
this->InitNet();
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::Predict() {
void PaddleInferenceAnakinPredictor<T, P, R>::Predict(int batch_size) {
anakin::TargetWrapper<T>::device_sync();
this->executor_p_->prediction();
anakin::TargetWrapper<T>::device_sync();
@@ -99,7 +100,7 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::Run(
const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, int batch_size) {
if (this->config_.re_allocable) {
return this->RunImpl(inputs, output_data);
return this->RunImpl(inputs, output_data, batch_size);
} else {
// Run inputs data that exceeds batch size in batches.
// 1. Reassign the batch size.
@@ -194,7 +195,7 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::Run(
template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) {
std::vector<PaddleTensor> *output_data, int batch_size) {
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
for (const auto &input : inputs) {
if (input.dtype != PaddleDType::FLOAT32) {
@@ -207,12 +208,12 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
LOG(FATAL) << " input " << input.name
<< "'s shape size should be equal to that of net";
}
#ifndef ANAKIN_MLU_PLACE
int sum = 1;
for_each(input.shape.begin(), input.shape.end(), [&](int n) { sum *= n; });
if (sum > net_shape.count()) {
if (this->config_.re_allocable) {
this->graph_p_->Reshape(input.name, input.shape);
delete this->executor_p_;
this->InitNet();
d_tensor_p = this->executor_p_->get_in(input.name);
} else {
@@ -221,6 +222,7 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
"memory.";
}
}
#endif
std::vector<int> tmp_shape;
for (auto s : input.shape) {
tmp_shape.push_back(s);
@@ -229,8 +231,9 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
tmp_shape);
#ifndef ANAKIN_MLU_PLACE
d_tensor_p->reshape(tmp_shape);
#endif
if (input.lod.size() > 0) {
if (input.lod.size() > 1) {
LOG(FATAL) << " input lod first dim should <=1, but you set "
@@ -246,9 +249,9 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
}
d_tensor_p->copy_from(h_tensor);
}
this->Predict();
this->Predict(batch_size);
if (output_data->empty()) {
LOG(FATAL) << "At least one output should be set with tensors' names.";
LOG(FATAL) << "The output param in the Run function is incorrect.";
}
for (auto &output : *output_data) {
if (std::find(this->output_names_.begin(), this->output_names_.end(),
@@ -256,14 +259,18 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
LOG(FATAL) << output.name << " is not in the outputs of the graph.";
}
auto *d_tensor_p = this->executor_p_->get_out(output.name);
output.shape = d_tensor_p->valid_shape();
if (output.data.length() < d_tensor_p->valid_size() * sizeof(float)) {
output.data.Resize(d_tensor_p->valid_size() * sizeof(float));
auto tmp_shape = d_tensor_p->valid_shape();
#ifdef ANAKIN_MLU_PLACE
tmp_shape.set_num(batch_size);
#endif
output.shape = tmp_shape;
if (output.data.length() < tmp_shape.count() * sizeof(float)) {
output.data.Resize(tmp_shape.count() * sizeof(float));
}
auto *data = static_cast<float *>(output.data.data());
anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
d_tensor_p->valid_shape());
tmp_shape);
h_tensor.copy_from(*d_tensor_p);
}
return true;
@@ -317,6 +324,8 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::SetContext() {
this->config_.compute_stream_id);
this->ctx_p_->set_model_parallel(this->config_.model_parallel);
this->ctx_p_->set_fusion(this->config_.op_fuse);
this->ctx_p_->enable_batch_changable();
this->ctx_p_->enable_channel_duplicate();
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::OptimizeGraph() {
@@ -327,14 +336,13 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::OptimizeGraph() {
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::InitNet() {
std::unique_lock<std::mutex> lock(this->mutex_);
delete this->executor_p_;
this->executor_p_ = new anakin::Net<anakin::MLU, P, R>();
this->executor_p_->fusion_init(*this->graph_p_, this->ctx_p_, true);
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
anakin::TargetWrapper<anakin::MLU>::device_sync();
this->executor_p_->fusion_prediction();
anakin::TargetWrapper<anakin::MLU>::device_sync();
void PaddleInferenceAnakinMLUPredictor<P, R>::Predict(int batch_size) {
this->executor_p_->fusion_prediction(batch_size);
}
#endif
@@ -353,14 +361,13 @@ void PaddleInferenceAnakinBMPredictor<P, R>::OptimizeGraph() {
template <Precision P, OpRunType R>
void PaddleInferenceAnakinBMPredictor<P, R>::InitNet() {
std::unique_lock<std::mutex> lock(this->mutex_);
delete this->executor_p_;
this->executor_p_ = new anakin::Net<anakin::BM, P, R>();
this->executor_p_->fusion_init(*this->graph_p_, this->ctx_p_, true);
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinBMPredictor<P, R>::Predict() {
anakin::TargetWrapper<anakin::BM>::device_sync();
void PaddleInferenceAnakinBMPredictor<P, R>::Predict(int batch_size) {
this->executor_p_->fusion_prediction();
anakin::TargetWrapper<anakin::BM>::device_sync();
}
#endif
@@ -73,7 +73,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
virtual void OptimizeGraph();
virtual void InitNet();
virtual void SetContext();
virtual void Predict();
virtual void Predict(int batch_size);
virtual std::unique_ptr<PaddlePredictor> New();
static std::mutex mutex_;
AnakinConfig config_;
@@ -85,7 +85,7 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
private:
bool RunImpl(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data);
std::vector<PaddleTensor>* output_data, int batch_size = -1);
static std::once_flag init_anakin_;
};
@@ -103,7 +103,7 @@ class PaddleInferenceAnakinMLUPredictor final
void SetContext() override;
void OptimizeGraph() override;
void InitNet() override;
void Predict() override;
void Predict(int batch_size) override;
};
#endif
@@ -120,7 +120,7 @@ class PaddleInferenceAnakinBMPredictor final
std::unique_ptr<PaddlePredictor> New() override;
void OptimizeGraph() override;
void InitNet() override;
void Predict() override;
void Predict(int batch_size) override;
};
#endif
} // namespace paddle
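The Run() hunk above notes that, when device buffers are not re-allocable, input data exceeding the configured batch size is executed in batches. The standalone sketch below illustrates that chunking strategy in isolation; RunInBatches and its parameters are illustrative names, not code from this patch.

#include <algorithm>

// Illustration of the "run in batches" strategy mentioned in Run(): process
// `total` samples in chunks of at most `max_batch`, calling
// run_chunk(offset, count) once per chunk (for example, copy that slice to
// the device inputs and then call Predict(count)).
template <typename RunChunk>
void RunInBatches(int total, int max_batch, RunChunk run_chunk) {
  for (int offset = 0; offset < total; offset += max_batch) {
    const int count = std::min(max_batch, total - offset);
    run_chunk(offset, count);
  }
}

On the MLU path the patch additionally forwards batch_size to fusion_prediction() and overwrites the output shape's num with it, so the copied-out results match the batch that was actually requested.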