Commit 046de2ac, authored by cuichaowen, committed by Yan Chunwei

Improve anakin feature (#11961)

Parent: baff71d5
[Diff 1/6: CMake config for the external Anakin library]

@@ -8,6 +8,7 @@
 set(ANAKIN_LIBRARY "${ANAKIN_INSTALL_DIR}" CACHE STRING "path of Anakin library")
 set(ANAKIN_COMPILE_EXTRA_FLAGS
+    -Wno-error=unused-but-set-variable -Wno-unused-but-set-variable
     -Wno-error=unused-variable -Wno-unused-variable
     -Wno-error=format-extra-args -Wno-format-extra-args
     -Wno-error=comment -Wno-comment
@@ -19,7 +20,7 @@
     -Wno-reorder
     -Wno-error=cpp)
-set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/3.0/anakin_release_simple.tar.gz")
+set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/Version0.1.0/anakin.tar.gz")
 # A helper function used in Anakin, currently, to use it, one need to recursively include
 # nearly all the header files.
@@ -41,9 +42,9 @@
   message(STATUS "Download Anakin library from ${ANAKIN_LIBRARY_URL}")
   execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}")
   execute_process(COMMAND bash -c "rm -rf ${ANAKIN_INSTALL_DIR}/*")
-  execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; wget -q ${ANAKIN_LIBRARY_URL}")
+  execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; wget --no-check-certificate -q ${ANAKIN_LIBRARY_URL}")
   execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}")
-  execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; tar xzf anakin_release_simple.tar.gz")
+  execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; tar xzf anakin.tar.gz")
 endif()
 if (WITH_ANAKIN)
[Diff 2/6: inference module CMakeLists]

@@ -19,6 +19,7 @@
 set(inference_deps paddle_inference_api paddle_fluid_api)
 if(WITH_GPU AND TENSORRT_FOUND)
   set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine)
 endif()
@@ -63,6 +64,8 @@
 if (WITH_ANAKIN) # only needed in CI
   # Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's,
   # so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to
-  # compile the libinference_anakin_api.a and compile with anakin.so.
+  # compile the libinference_anakin_api.a and anakin.so.
+  fetch_include_recursively(${ANAKIN_INCLUDE})
   nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc)
   nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc)
@@ -73,7 +76,7 @@
   if (WITH_TESTING)
     cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc
             ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin
-            DEPS inference_anakin_api)
+            DEPS inference_anakin_api_shared)
     target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
   endif(WITH_TESTING)
 endif()
[Diff 3/6: Anakin predictor implementation (api_anakin_engine.cc)]

@@ -18,26 +18,36 @@
 namespace paddle {

-PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor(
+template <typename Target>
+PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
     const AnakinConfig &config) {
   CHECK(Init(config));
 }

-bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) {
+template <typename Target>
+bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
   if (!(graph_.load(config.model_file))) {
+    LOG(FATAL) << "fail to load graph from " << config.model_file;
     return false;
   }
-  graph_.ResetBatchSize("input_0", config.max_batch_size);
+  auto inputs = graph_.get_ins();
+  for (auto &input_str : inputs) {
+    graph_.ResetBatchSize(input_str, config.max_batch_size);
+  }
   // optimization for graph
   if (!(graph_.Optimize())) {
     return false;
   }
   // construct executer
-  executor_.init(graph_);
+  if (executor_p_ == nullptr) {
+    executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
+                                  anakin::Precision::FP32>(graph_, true);
+  }
   return true;
 }

-bool PaddleInferenceAnakinPredictor::Run(
+template <typename Target>
+bool PaddleInferenceAnakinPredictor<Target>::Run(
     const std::vector<PaddleTensor> &inputs,
     std::vector<PaddleTensor> *output_data, int batch_size) {
   for (const auto &input : inputs) {
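One observation on the Init() hunk above: under glog, LOG(FATAL) aborts the process, so the `return false` that follows it is unreachable and `CHECK(Init(config))` in the constructor can never observe a failed load. A non-aborting sketch of the same error path (assuming glog severities, not part of the patch):

// Sketch: report the failure without aborting, so Init() can actually
// return false and CHECK(Init(config)) can fire meaningfully.
if (!graph_.load(config.model_file)) {
  LOG(ERROR) << "fail to load graph from " << config.model_file;
  return false;
}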
@@ -46,7 +56,29 @@
                  << "'s type is not float";
       return false;
     }
-    auto d_tensor_in_p = executor_.get_in(input.name);
+    auto d_tensor_in_p = executor_p_->get_in(input.name);
+    auto net_shape = d_tensor_in_p->valid_shape();
+    if (net_shape.size() != input.shape.size()) {
+      LOG(ERROR) << " input " << input.name
+                 << "'s shape size should be equal to that of net";
+      return false;
+    }
+    int sum = 1;
+    for_each(input.shape.begin(), input.shape.end(), [&](int n) { sum *= n; });
+    if (sum > net_shape.count()) {
+      graph_.Reshape(input.name, input.shape);
+      delete executor_p_;
+      executor_p_ = new anakin::Net<Target, anakin::saber::AK_FLOAT,
+                                    anakin::Precision::FP32>(graph_, true);
+      d_tensor_in_p = executor_p_->get_in(input.name);
+    }
+
+    anakin::saber::Shape tmp_shape;
+    for (auto s : input.shape) {
+      tmp_shape.push_back(s);
+    }
+    d_tensor_in_p->reshape(tmp_shape);
     float *d_data_p = d_tensor_in_p->mutable_data();
     if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
                    d_tensor_in_p->valid_size() * sizeof(float),
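A note on the reshape logic above: the accumulator named `sum` actually holds the product of the dimensions, i.e. the element count of the tensor. A sketch of the same computation with the standard library (`ElementCount` is a hypothetical helper name, not in the patch):

#include <functional>
#include <numeric>
#include <vector>

// Element count of a tensor shape: the product of its dimensions.
inline int ElementCount(const std::vector<int> &shape) {
  return std::accumulate(shape.begin(), shape.end(), 1,
                         std::multiplies<int>());
}

// In the hunk above this would read:
//   if (ElementCount(input.shape) > net_shape.count()) { ... }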
@@ -56,16 +88,17 @@
     }
     cudaStreamSynchronize(NULL);
   }
-  executor_.prediction();
+  cudaDeviceSynchronize();
+  executor_p_->prediction();
+  cudaDeviceSynchronize();
   if (output_data->empty()) {
     LOG(ERROR) << "At least one output should be set with tensors' names.";
     return false;
   }
   for (auto &output : *output_data) {
-    auto *tensor = executor_.get_out(output.name);
-    output.shape = tensor->shape();
+    auto *tensor = executor_p_->get_out(output.name);
+    output.shape = tensor->valid_shape();
     if (output.data.length() < tensor->valid_size() * sizeof(float)) {
       output.data.Resize(tensor->valid_size() * sizeof(float));
     }
@@ -81,19 +114,23 @@
   return true;
 }

-anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
-    &PaddleInferenceAnakinPredictor::get_executer() {
-  return executor_;
+template <typename Target>
+anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
+    &PaddleInferenceAnakinPredictor<Target>::get_executer() {
+  return *executor_p_;
 }

 // the cloned new Predictor of anakin share the same net weights from original
 // Predictor
-std::unique_ptr<PaddlePredictor> PaddleInferenceAnakinPredictor::Clone() {
+template <typename Target>
+std::unique_ptr<PaddlePredictor>
+PaddleInferenceAnakinPredictor<Target>::Clone() {
   VLOG(3) << "Anakin Predictor::clone";
-  std::unique_ptr<PaddlePredictor> cls(new PaddleInferenceAnakinPredictor());
+  std::unique_ptr<PaddlePredictor> cls(
+      new PaddleInferenceAnakinPredictor<Target>());
   // construct executer from other graph
   auto anakin_predictor_p =
-      dynamic_cast<PaddleInferenceAnakinPredictor *>(cls.get());
+      dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
   if (!anakin_predictor_p) {
     LOG(ERROR) << "fail to call Init";
     return nullptr;
@@ -103,14 +140,28 @@
   return std::move(cls);
 }

+template class PaddleInferenceAnakinPredictor<anakin::NV>;
+template class PaddleInferenceAnakinPredictor<anakin::X86>;
+
 // A factory to help create difference predictor.
 template <>
 std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
     AnakinConfig, PaddleEngineKind::kAnakin>(const AnakinConfig &config) {
-  VLOG(3) << "Anakin Predictor create.";
-  std::unique_ptr<PaddlePredictor> x(
-      new PaddleInferenceAnakinPredictor(config));
-  return x;
-}
+  VLOG(3) << "Anakin Predictor create.";
+  if (config.target_type == AnakinConfig::NVGPU) {
+    VLOG(3) << "Anakin Predictor create on [ NVIDIA GPU ].";
+    std::unique_ptr<PaddlePredictor> x(
+        new PaddleInferenceAnakinPredictor<anakin::NV>(config));
+    return x;
+  } else if (config.target_type == AnakinConfig::X86) {
+    VLOG(3) << "Anakin Predictor create on [ Intel X86 ].";
+    std::unique_ptr<PaddlePredictor> x(
+        new PaddleInferenceAnakinPredictor<anakin::X86>(config));
+    return x;
+  } else {
+    VLOG(3) << "Anakin Predictor create on unknown platform.";
+    return nullptr;
+  }
+};

 }  // namespace paddle
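The factory above now dispatches on config.target_type, and the two explicit instantiations keep the template member definitions in this .cc while limiting the supported targets to anakin::NV and anakin::X86. A minimal usage sketch of the new API (the wrapper name MakeAnakinPredictor and the model path are illustrative, not part of the patch; the fields mirror the tester below):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

std::unique_ptr<paddle::PaddlePredictor> MakeAnakinPredictor() {
  paddle::AnakinConfig config;
  config.target_type = paddle::AnakinConfig::NVGPU;  // or AnakinConfig::X86
  config.model_file = "./mobilenet_v2.anakin.bin";   // illustrative path
  config.device = 0;
  config.max_batch_size = 1;
  return paddle::CreatePaddlePredictor<paddle::AnakinConfig,
                                       paddle::PaddleEngineKind::kAnakin>(
      config);
}

Since Clone() shares the net weights with the original predictor (per the comment above), one clone per worker thread is a cheap way to run inference concurrently.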
[Diff 4/6: Anakin predictor header]

@@ -20,14 +20,16 @@
 #pragma once

 #include <vector>
-#include "paddle/fluid/inference/api/paddle_inference_api.h"

-// from anakin
 #include "framework/core/net/net.h"
+#include "framework/graph/graph.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "saber/core/shape.h"
 #include "saber/saber_types.h"

 namespace paddle {

+template <typename Target>
 class PaddleInferenceAnakinPredictor : public PaddlePredictor {
  public:
   PaddleInferenceAnakinPredictor() {}
@@ -42,19 +44,21 @@
   std::unique_ptr<PaddlePredictor> Clone() override;

-  anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>&
+  anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>&
   get_executer();

-  ~PaddleInferenceAnakinPredictor() override{};
+  ~PaddleInferenceAnakinPredictor() override {
+    delete executor_p_;
+    executor_p_ = nullptr;
+  };

  private:
   bool Init(const AnakinConfig& config);

-  anakin::graph::Graph<anakin::NV, anakin::saber::AK_FLOAT,
-                       anakin::Precision::FP32>
+  anakin::graph::Graph<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
       graph_;
-  anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
-      executor_;
+  anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>*
+      executor_p_{nullptr};
   AnakinConfig config_;
 };
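A note on the destructor above: executor_p_ is an owning raw pointer, deleted here and deleted/re-created again in the Run() reshape path. A sketch of the smart-pointer alternative (not part of this patch):

#include <memory>

// Owning smart pointer (sketch): the destructor body above and the
// delete/new pair in Run() collapse into reset() calls and automatic
// cleanup.
std::unique_ptr<
    anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>>
    executor_p_;
// e.g. executor_p_.reset(new anakin::Net<Target, anakin::saber::AK_FLOAT,
//                                        anakin::Precision::FP32>(graph_,
//                                                                 true));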
[Diff 5/6: api_anakin_engine_tester.cc]

@@ -12,18 +12,20 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
+#include "gflags/gflags.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"

-DEFINE_string(model, "", "Directory of the inference model.");
+DEFINE_string(model, "", "Directory of the inference model(mobile_v2).");

 namespace paddle {

 AnakinConfig GetConfig() {
   AnakinConfig config;
+  // using AnakinConfig::X86 if you need to use cpu to do inference
+  config.target_type = AnakinConfig::NVGPU;
   config.model_file = FLAGS_model;
   config.device = 0;
   config.max_batch_size = 1;
@@ -36,7 +38,6 @@
       CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(config);
   float data[1 * 3 * 224 * 224] = {1.0f};
   PaddleTensor tensor;
   tensor.name = "input_0";
   tensor.shape = std::vector<int>({1, 3, 224, 224});
@@ -44,22 +45,20 @@
   tensor.dtype = PaddleDType::FLOAT32;
   // For simplicity, we set all the slots with the same data.
-  std::vector<PaddleTensor> paddle_tensor_feeds;
-  paddle_tensor_feeds.emplace_back(std::move(tensor));
+  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);

   PaddleTensor tensor_out;
   tensor_out.name = "prob_out";
-  tensor_out.shape = std::vector<int>({1000, 1});
+  tensor_out.shape = std::vector<int>({});
   tensor_out.data = PaddleBuf();
   tensor_out.dtype = PaddleDType::FLOAT32;

-  std::vector<PaddleTensor> outputs;
-  outputs.emplace_back(std::move(tensor_out));
+  std::vector<PaddleTensor> outputs(1, tensor_out);

   ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));

   float* data_o = static_cast<float*>(outputs[0].data.data());
-  for (size_t j = 0; j < 1000; ++j) {
+  for (size_t j = 0; j < outputs[0].data.length(); ++j) {
     LOG(INFO) << "output[" << j << "]: " << data_o[j];
   }
 }
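A note on the new loop bound: judging from the Run() implementation above, which resizes the buffer to valid_size() * sizeof(float), PaddleBuf::length() counts bytes while data_o indexes floats, so iterating up to length() overruns the array by a factor of sizeof(float). A corrected sketch:

// length() counts bytes but data_o indexes floats: divide by sizeof(float).
size_t num_floats = outputs[0].data.length() / sizeof(float);
for (size_t j = 0; j < num_floats; ++j) {
  LOG(INFO) << "output[" << j << "]: " << data_o[j];
}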
[Diff 6/6: paddle_inference_api.h]

@@ -126,9 +126,11 @@
 // Configurations for Anakin engine.
 struct AnakinConfig : public PaddlePredictor::Config {
+  enum TargetType { NVGPU = 0, X86 };
   int device;
   std::string model_file;
   int max_batch_size{-1};
+  TargetType target_type;
 };
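One caveat with the new field: target_type has no in-class initializer, so a stack-constructed AnakinConfig that never sets it passes an indeterminate value to the factory dispatch. A defensive default (sketch, not in this patch) would keep old callers on the GPU path that the pre-patch factory took unconditionally:

// Defensive default (sketch): configs that never set target_type stay on
// the NVGPU path, matching the old factory's behavior.
TargetType target_type{NVGPU};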
 struct TensorRTConfig : public NativeConfig {