Unverified commit 15291548, authored by 石晓伟, committed by GitHub

Support Bitmain Anakin (#18542)

* update anakin-engine interfaces for content-dnn

test=develop

* support only-gpu mode of Anakin

modify eltwise parse

test=develop

* modifications for thread safety

test=develop

* integrate template instances

test=develop

* increase template parameters

test=develop

* support MLU predictor

test=develop

* update anakin cmake files

test=develop

* update TargetWrapper::set_device

* update the initialization of anakin subgraph

test=develop

* use the default constructor of base class

test=develop

* load model from buffer with length

test=develop

* modify the access level of class

test=develop

* support anakin for bitmain arch

test=develop

* remove files

* checkout cmakelists

test=develop
Parent 9b3d3b83
......@@ -26,6 +26,19 @@ if(ANAKIN_FOUND)
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()
if(ANAKIN_FOUND)
if (ANAKIN_MLU AND NOT WITH_GPU AND NOT ANAKIN_X86)
message(STATUS "Compile with anakin mlu place.")
add_definitions(-DANAKIN_MLU_PLACE)
elseif(ANAKIN_BM AND NOT WITH_GPU AND NOT ANAKIN_X86)
message(STATUS "Compile with anakin bm place.")
add_definitions(-DANAKIN_BM_PLACE)
elseif(ANAKIN_X86)
message(STATUS "Compile with anakin x86 place.")
add_definitions(-DANAKIN_X86_PLACE)
endif()
endif()
if(ANAKIN_FOUND AND WITH_GPU AND WITH_DSO)
message(STATUS "Compile with anakin subgraph.")
set(ANAKIN_SUBGRAPH ON)
......
......@@ -68,13 +68,8 @@ cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_
ARGS --dirname=${WORD2VEC_MODEL_DIR})
if(ANAKIN_FOUND)
if (ANAKIN_MLU AND NOT WITH_GPU AND NOT ANAKIN_X86)
message(STATUS "Compile with anakin mlu place.")
add_definitions(-DANAKIN_MLU_PLACE)
elseif(ANAKIN_X86)
message(STATUS "Compile with anakin x86 place.")
add_definitions(-DANAKIN_X86_PLACE)
endif()
# Do not turn warnings into errors.
set_source_files_properties(api.cc api_anakin_engine.cc PROPERTIES COMPILE_FLAGS "-Wno-error")
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc)
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc)
......
......@@ -34,10 +34,10 @@ extern std::once_flag PaddleInferenceAnakinPredictor<T, P, R>::init_anakin_;
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitEnv() {
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
std::call_once(this->init_anakin_, [this]() {
anakin::Env<T>::env_init(this->config_.max_stream);
});
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
}
template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitNet() {
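The InitEnv hunk above reorders anakin::TargetWrapper<T>::set_device relative to the std::call_once block, so each calling thread binds to the configured device before, and regardless of, the once-only env_init. A minimal sketch of that pattern, with hypothetical stand-ins for the Anakin calls:

#include <iostream>
#include <mutex>

// Hypothetical stand-ins for anakin::TargetWrapper<T>::set_device and
// anakin::Env<T>::env_init.
void set_device(int id) { std::cout << "bound to device " << id << "\n"; }
void env_init(int max_stream) { std::cout << "streams=" << max_stream << "\n"; }

void InitEnvOnce(int device_id, int max_stream) {
  // Bind the calling thread to the device on every call ...
  set_device(device_id);
  // ... then run the process-wide environment setup exactly once.
  static std::once_flag init_flag;
  std::call_once(init_flag, [=]() { env_init(max_stream); });
}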
......@@ -54,14 +54,19 @@ template <typename T, Precision P, OpRunType R>
void PaddleInferenceAnakinPredictor<T, P, R>::InitGraph() {
this->graph_p_ =
std::make_shared<anakin::graph::Graph<T, anakin::Precision::FP32>>();
if (!(this->graph_p_->load(this->config_.model_file))) {
LOG(FATAL) << "fail to load graph from " << this->config_.model_file;
if (!this->config_.model_file.empty()) {
this->graph_p_->load(this->config_.model_file);
} else if (this->config_.model_buf_p) {
this->graph_p_->load(this->config_.model_buf_p,
this->config_.model_buf_len);
} else {
LOG(FATAL) << "Model load error.";
}
auto inputs = this->graph_p_->get_ins();
for (auto &input_str : inputs) {
if (this->config_.init_inputs_shape.find(input_str) ==
this->config_.init_inputs_shape.end()) {
LOG(FATAL) << input_str << " is not implemented.";
LOG(FATAL) << input_str << " should be set in init_inputs_shape.";
}
std::vector<int> shape =
this->config_.init_inputs_shape.find(input_str)->second;
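With this hunk, InitGraph loads from a file path when model_file is set, otherwise from the in-memory buffer described by model_buf_p/model_buf_len, and fails fatally when neither is provided. A hedged sketch of filling the buffer fields from disk; ReadModel is a hypothetical helper and the path a placeholder:

#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical helper: slurp a serialized Anakin model into memory so it can
// be passed via AnakinConfig::model_buf_p / model_buf_len instead of a path.
std::vector<char> ReadModel(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  return std::vector<char>(std::istreambuf_iterator<char>(in),
                           std::istreambuf_iterator<char>());
}

// Usage sketch against the AnakinConfig fields added in this commit:
//   std::vector<char> buf = ReadModel("model.anakin.bin");
//   config.model_buf_p = buf.data();
//   config.model_buf_len = buf.size();
// The buffer must outlive InitGraph, since the config stores a raw pointer.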
......@@ -189,6 +194,7 @@ template <typename T, Precision P, OpRunType R>
bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) {
anakin::TargetWrapper<T>::set_device(this->config_.device_id);
for (const auto &input : inputs) {
if (input.dtype != PaddleDType::FLOAT32) {
LOG(FATAL) << "Only support float type inputs. " << input.name
......@@ -321,6 +327,27 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::Predict() {
}
#endif
#ifdef ANAKIN_BM_PLACE
template <Precision P, OpRunType R>
void PaddleInferenceAnakinBMPredictor<P, R>::OptimizeGraph() {
if (!this->graph_p_->fusion_optimize()) {
LOG(FATAL) << "Graph optimization error.";
}
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinBMPredictor<P, R>::InitNet() {
std::unique_lock<std::mutex> lock(this->mutex_);
this->executor_p_ = new anakin::Net<anakin::BM, P, R>();
this->executor_p_->fusion_init(*this->graph_p_, this->ctx_p_, true);
}
template <Precision P, OpRunType R>
void PaddleInferenceAnakinBMPredictor<P, R>::Predict() {
anakin::TargetWrapper<anakin::BM>::device_sync();
this->executor_p_->fusion_prediction();
anakin::TargetWrapper<anakin::BM>::device_sync();
}
#endif
#ifdef PADDLE_WITH_CUDA
template class PaddleInferenceAnakinPredictor<
anakin::NV, anakin::Precision::FP32, ::anakin::OpRunType::ASYNC>;
......@@ -333,6 +360,10 @@ template class PaddleInferenceAnakinPredictor<
template class PaddleInferenceAnakinMLUPredictor<anakin::Precision::FP32,
::anakin::OpRunType::SYNC>;
#endif
#ifdef ANAKIN_BM_PLACE
template class PaddleInferenceAnakinBMPredictor<anakin::Precision::FP32,
::anakin::OpRunType::ASYNC>;
#endif
// A factory to help create different predictors.
template <>
......@@ -361,7 +392,16 @@ CreatePaddlePredictor<contrib::AnakinConfig, PaddleEngineKind::kAnakin>(
config));
}
#endif
LOG(FATAL) << "Anakin Predictor create on unknown platform.";
#ifdef ANAKIN_BM_PLACE
if (config.target_type == contrib::AnakinConfig::BM) {
return std::unique_ptr<PaddlePredictor>(
new PaddleInferenceAnakinBMPredictor<anakin::Precision::FP32,
::anakin::OpRunType::ASYNC>(
config));
}
#endif
LOG(FATAL) << "Anakin Predictor create on unknown platform: "
<< config.target_type;
return nullptr;
}
template <typename T, Precision P, OpRunType R>
......
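With the factory branch above, constructing a Bitmain predictor mirrors the existing NVGPU/X86/MLU targets. A hedged usage sketch; the include path, model path, and input shape are assumptions, not taken from this commit:

// Assumed public header; the exact path may differ by Paddle version.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

#include <memory>

std::unique_ptr<paddle::PaddlePredictor> MakeBMPredictor() {
  paddle::contrib::AnakinConfig config;
  config.target_type = paddle::contrib::AnakinConfig::BM;
  config.device_id = 0;
  config.model_file = "model.anakin.bin";                  // placeholder path
  config.init_inputs_shape["input_0"] = {1, 3, 224, 224};  // placeholder shape
  return paddle::CreatePaddlePredictor<paddle::contrib::AnakinConfig,
                                       paddle::PaddleEngineKind::kAnakin>(
      config);
}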
......@@ -92,4 +92,19 @@ class PaddleInferenceAnakinMLUPredictor final
void Predict() override;
};
#endif
#ifdef ANAKIN_BM_PLACE
template <Precision P, OpRunType R>
class PaddleInferenceAnakinBMPredictor final
: public PaddleInferenceAnakinPredictor<anakin::BM, P, R> {
public:
explicit PaddleInferenceAnakinBMPredictor(const AnakinConfig& config) {
this->ResetConfig(config);
this->InitPredictor();
}
void OptimizeGraph() override;
void InitNet() override;
void Predict() override;
};
#endif
} // namespace paddle
......@@ -25,7 +25,7 @@ namespace paddle {
namespace contrib {
// Configurations for Anakin engine.
struct AnakinConfig : public PaddlePredictor::Config {
enum TargetType { NVGPU = 0, X86, MLU };
enum TargetType { NVGPU = 0, X86, MLU, BM };
int device_id{0};
std::string model_file;
std::map<std::string, std::vector<int>> init_inputs_shape;
......@@ -34,6 +34,8 @@ struct AnakinConfig : public PaddlePredictor::Config {
int max_stream{4};
int data_stream_id{0};
int compute_stream_id{0};
char* model_buf_p{nullptr};
size_t model_buf_len{0};
TargetType target_type;
#ifdef ANAKIN_MLU_PLACE
int model_parallel{8};
......
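Since RunImpl rejects anything but FLOAT32 inputs, a caller prepares PaddleTensor buffers accordingly before invoking the predictor. A minimal sketch; the tensor name and shape are placeholders:

#include <vector>

// Assumed public header; the exact path may differ by Paddle version.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void RunOnce(paddle::PaddlePredictor* predictor) {
  std::vector<float> data(1 * 3 * 224 * 224, 0.f);   // placeholder input
  paddle::PaddleTensor input;
  input.name = "input_0";                      // placeholder tensor name
  input.shape = {1, 3, 224, 224};
  input.dtype = paddle::PaddleDType::FLOAT32;  // RunImpl accepts only FLOAT32
  input.data = paddle::PaddleBuf(data.data(), data.size() * sizeof(float));

  std::vector<paddle::PaddleTensor> outputs;
  predictor->Run({input}, &outputs);
}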