Commit 69d37f81 authored by nhzlx

cherry-pick from feature/anakin-engine: refine anakin subgraph. #16157

Support changing the input size.
Parent a1d200a5
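
For orientation only, here is a minimal hedged sketch of how the Anakin sub-graph engine wired up by this change is meant to be enabled from the C++ inference API. The model directory, GPU memory pool size and device id are placeholders, and the exact predictor-creation call may differ slightly between Paddle versions:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model_dir");   // placeholder model directory
  config.EnableUseGpu(100 /* MB */, 0);    // the Anakin sub-graph engine requires GPU
  config.EnableAnakinEngine();             // added by this commit; swaps in kAnakinSubgraphPasses
  auto predictor = paddle::CreatePaddlePredictor(config);
  // run inference through predictor->Run(...) as usual
  return 0;
}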
...@@ -1527,6 +1527,16 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
          ->assert_is_op_output("box_coder")
          ->AsIntermediate();
auto transpose_before_nms =
pattern->NewNode(GetNodeName("transpose_before_nms"))
->assert_is_op("transpose2");
auto transpose_before_nms_out =
pattern->NewNode(GetNodeName("transpose_before_nms_out"))
->assert_is_op_output("transpose2")
->assert_is_op_input("multiclass_nms", "Scores")
->AsIntermediate();
  auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
          ->assert_is_op("multiclass_nms")
          ->assert_op_has_n_inputs("multiclass_nms", 2);
...@@ -1577,8 +1587,10 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
      {concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset]});
  box_coder_out->LinksFrom({box_coder_op});
- multiclass_nms_op
-     ->LinksFrom({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset]})
+ transpose_before_nms->LinksFrom({conv_in[kMultiClassSecondInputNmsOffset]});
+ transpose_before_nms_out->LinksFrom({transpose_before_nms});
+ multiclass_nms_op->LinksFrom({box_coder_out, transpose_before_nms_out})
      .LinksTo({multiclass_nms_out});
  return multiclass_nms_out;
......
...@@ -45,7 +45,7 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
  input_nodes.push_back(gpd.mutable_pattern()
                            ->NewNode("x" + std::to_string(times + 1))
-                           ->assert_is_op_input("multiclass_nms", "Scores")
+                           ->assert_is_op_input("transpose2")
                            ->AsInput());
  patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name);
...@@ -106,6 +106,11 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
    Node *box_coder_out = subgraph.at(pattern.GetPDNode("box_coder_out"));
    Node *multiclass_nms_second_input = subgraph.at(input_nodes[times + 1]);
Node *transpose_before_nms =
subgraph.at(pattern.GetPDNode("transpose_before_nms"));
Node *transpose_before_nms_out =
subgraph.at(pattern.GetPDNode("transpose_before_nms_out"));
    Node *multiclass_nms = subgraph.at(pattern.GetPDNode("multiclass_nms"));
    Node *multiclass_nms_out =
        subgraph.at(pattern.GetPDNode("multiclass_nms_out"));
...@@ -133,11 +138,11 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
          nodes[i * kNumFields + kPriorBoxLocOffset]->Name());
    }
-   int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
+   // int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
    framework::OpDesc concat1_desc;
    concat1_desc.SetType("concat");
    concat1_desc.SetInput("X", concat1_input_names);
-   concat1_desc.SetAttr("axis", axis);
+   concat1_desc.SetAttr("axis", 2);
    concat1_desc.SetOutput("Out", {concat_out1->Name()});
    auto *new_add_concat_op = graph->CreateOpNode(&concat1_desc);
...@@ -184,6 +189,8 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
    delete_nodes.insert(concat_out2);
    delete_nodes.insert(box_coder_op);
    delete_nodes.insert(box_coder_out);
delete_nodes.insert(transpose_before_nms);
delete_nodes.insert(transpose_before_nms_out);
    delete_nodes.insert(multiclass_nms);
    new_add_concat_op->outputs.push_back(concat_out1);
......
cc_library(anakin_engine SRCS engine.cc)
-nv_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
+cc_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
target_link_libraries(anakin_engine anakin anakin_saber_common)
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
add_subdirectory(convert)
...@@ -43,11 +43,13 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
  auto output = op_desc.Output("Y").front();
  auto op_name = op_desc.Type() + ":" + op_desc.Output("Y").front();
  auto epsilon = boost::get<float>(op_desc.GetAttr("epsilon"));
  // auto momentum = boost::get<float>(op_desc.GetAttr("momentum"));

  auto bn_op_name = op_name + ":bn";
  auto bn_output = bn_op_name + "_output";
  engine_->AddOp(bn_op_name, "BatchNorm", {inputs["X"]}, {bn_output});
  engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
  engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));

  auto scale_op_name = op_name + ":scale";
  auto get_lod_tensor = [this, &scope, &op_name](const std::string &var_name,
......
...@@ -27,8 +27,8 @@ namespace paddle {
namespace inference {
namespace anakin {

-void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc &op,
-                                            const framework::Scope &scope,
+void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
+                                            const framework::Scope& scope,
                                             bool test_mode) {
  framework::OpDesc op_desc(op, nullptr);
  auto input_name = op_desc.Input("Input").front();
...@@ -42,34 +42,45 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc &op,
  auto fixed_ratios =
      boost::get<std::vector<float>>(op_desc.GetAttr("fixed_ratios"));
  auto densities = boost::get<std::vector<int>>(op_desc.GetAttr("densities"));
  std::vector<float> dens;
  for (auto& ele : densities) {
    dens.push_back(static_cast<float>(ele));
  }

  // lack flip
- auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
+ // auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
  auto variances = boost::get<std::vector<float>>(op_desc.GetAttr("variances"));
  for (auto& ele : variances) {
    LOG(INFO) << ele;
  }

  // lack img_h, img_w
  auto step_h = boost::get<float>(op_desc.GetAttr("step_h"));
  auto step_w = boost::get<float>(op_desc.GetAttr("step_w"));
  auto offset = boost::get<float>(op_desc.GetAttr("offset"));
- std::vector<std::string> order = {"MIN", "COM", "MAX"};
+ PTuple<std::string> t_order;
+ t_order.push_back("MIN");
+ t_order.push_back("COM");
+ t_order.push_back("MAX");
  std::vector<float> temp_v = {};

  engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
  engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", temp_v);
  engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", temp_v);
  engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", temp_v);
- engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_sizes", fixed_sizes);
- engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratios", fixed_ratios);
- engine_->AddOpAttr<PTuple<int>>(op_name, "density", densities);
- engine_->AddOpAttr(op_name, "is_flip", false);
- engine_->AddOpAttr(op_name, "is_clip", clip);
+ engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes);
+ engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios);
+ engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens);
+ engine_->AddOpAttr(op_name, "is_flip", static_cast<bool>(false));
+ engine_->AddOpAttr(op_name, "is_clip", static_cast<bool>(false));
  engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
  engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
  engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
  engine_->AddOpAttr(op_name, "step_h", step_h);
  engine_->AddOpAttr(op_name, "step_w", step_w);
  engine_->AddOpAttr(op_name, "offset", offset);
- engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", order);
+ engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", t_order);
}

} // namespace anakin
......
...@@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "framework/core/types.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
...@@ -68,6 +69,35 @@ class AnakinOpConverter {
      ConvertOp(op, parameters, scope, engine);
    }
  }
// The scope here should be inited with the parameter vars.
void ConvertBlockToAnakinEngine(
framework::BlockDesc *block_desc, const framework::Scope &scope,
const std::vector<std::string> &inputs,
const std::unordered_set<std::string> &parameters,
const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
framework::proto::BlockDesc *block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine);
engine->Freeze();
for (auto &input : inputs) {
if (parameters.count(input)) continue;
auto *var = block_desc->FindVar(input);
PADDLE_ENFORCE(var, "no variable called %s", input);
auto var_shape = var->GetShape();
PADDLE_ENFORCE(var_shape.size() == 4);
std::vector<int> input_shape;
for (int i = 0; i < var_shape.size(); i++) {
input_shape.push_back(var_shape[i]);
}
input_shape[0] = 1;
engine->SetInputShape(input, input_shape);
}
engine->Optimize();
engine->InitGraph();
}
  void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
  virtual ~AnakinOpConverter() {}
......
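
As a hedged usage sketch (not part of this diff), the new ConvertBlockToAnakinEngine helper is expected to be driven roughly as below, mirroring the call the Anakin subgraph pass makes further down; the wrapper function name is illustrative and every argument is assumed to be prepared by the caller:

// Sketch only: argument names are placeholders supplied by the caller.
void ConvertSubgraphSketch(paddle::framework::BlockDesc *block_desc,
                           const paddle::framework::Scope &scope,
                           const std::vector<std::string> &input_names,
                           const std::unordered_set<std::string> &param_set,
                           const std::vector<std::string> &output_names,
                           paddle::inference::anakin::AnakinNvEngine *engine) {
  // The scope must already hold the parameter tensors (see the comment in the header).
  paddle::inference::Singleton<paddle::inference::anakin::AnakinOpConverter>::Global()
      .ConvertBlockToAnakinEngine(block_desc, scope, input_names, param_set,
                                  output_names, engine);
}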
...@@ -55,7 +55,7 @@ void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
  if (pool_type == "max") {
    anakin_pool_type = "MAX";
  } else if (pool_type == "avg") {
-   anakin_pool_type = "AVG";
+   anakin_pool_type = "AVGEXC";
  } else {
    PADDLE_THROW("TensorRT unsupported pooling type!");
  }
......
...@@ -33,7 +33,7 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
  auto output = op_desc.Output("Out").front();
  auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
  engine_->AddOp(op_name, "Softmax", {input}, {output});
- engine_->AddOpAttr(op_name, "axis", 1);
+ engine_->AddOpAttr(op_name, "axis", 2);
}

} // namespace anakin
......
...@@ -52,8 +52,9 @@ TEST(batch_norm_op, test) {
  desc.SetOutput("SavedVariance", {"batch_norm_save_variance"});

  float eps = 1e-5f;
  bool is_test = true;
  desc.SetAttr("epsilon", eps);
- desc.SetAttr("is_test", true);
+ desc.SetAttr("is_test", is_test);

  validator.SetOp(*desc.Proto());
......
...@@ -64,11 +64,52 @@ void test_pool2d(bool global_pooling, bool ceil_mode,
  validator.Execute(1);
}
void test_pool2d2(bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
auto* pool2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("pool2d");
ASSERT_TRUE(pool2d_converter);
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, scope);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d_x", {1, 1, 17, 17});
validator.DeclOutputVar("pool2d_out", {1, 1, 17, 17});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pool2d");
desc.SetInput("X", {"pool2d_x"});
desc.SetOutput("Out", {"pool2d_out"});
std::vector<int> ksize({3, 3});
std::vector<int> strides({1, 1});
std::vector<int> paddings({1, 1});
std::string pooling_t = pool_type;
desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", true);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(1);
}
TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); }
TEST(Pool2dOpConverter, avg_ceil_test2) { test_pool2d2(false, true, "avg"); }
} // namespace anakin
} // namespace inference
......
...@@ -168,7 +168,7 @@ class AnakinConvertValidation {
      outputs.insert({output, tensor});
    }

-   engine_->Execute(inputs, outputs);
+   engine_->Execute(inputs, outputs, stream_);
    int i_output = 0;
    for (const auto& output : op_desc_->OutputArgumentNames()) {
      if (neglected_output.count(output)) continue;
......
...@@ -33,9 +33,12 @@ namespace inference {
namespace anakin {

template <typename TargetT, Precision PrecisionType, OpRunType RunType>
-AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary)
+AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary,
+                                                            int device)
    : graph_(new AnakinGraphT<TargetT, PrecisionType>()),
-     net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {}
+     net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {
+  device_ = device;
+}

template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::~AnakinEngine() {}
...@@ -63,33 +66,44 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::AddOp(
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
    const std::map<std::string, framework::LoDTensor *> &inputs,
-   const std::map<std::string, framework::LoDTensor *> &outputs) {
+   const std::map<std::string, framework::LoDTensor *> &outputs,
+   cudaStream_t stream) {
  for (const auto &input : inputs) {
    auto *tensor = input.second;
    auto *data = tensor->data<float>();
-   auto shape = framework::vectorize2int(tensor->dims());
+   auto fluid_input_shape = framework::vectorize2int(tensor->dims());
    auto *anakin_input = net_->get_in(input.first);
-   auto anakin_input_shape = anakin_input->valid_shape();
-   PADDLE_ENFORCE(tensor->numel(), anakin_input_shape.count(),
-                  "the fluid input size should be equal to anakin");
+   auto net_shape = anakin_input->shape();
+   if (tensor->numel() > net_shape.count()) {
+     graph_->Reshape(input.first, fluid_input_shape);
+     net_.reset(new AnakinNetT<TargetT, PrecisionType, RunType>(true));
+     net_->init(*graph_);
+     anakin_input = net_->get_in(input.first);
+   }
+   anakin_input->reshape(fluid_input_shape);
+   net_shape = anakin_input->shape();
    ::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
-                                                      anakin_input_shape);
-   anakin_input->copy_from(tmp_anakin_tensor);
+                                                      net_shape);
+   anakin_input->share_from(tmp_anakin_tensor);
  }
+ net_->prediction();

  for (const auto &output : outputs) {
+   platform::CUDAPlace gpu_place(device_);
    auto *tensor = output.second;
-   auto *data = tensor->data<float>();
-   auto shape = framework::vectorize2int(tensor->dims());
    auto *anakin_output = net_->get_out(output.first);
+   auto *anakin_data = anakin_output->data();
    auto anakin_output_shape = anakin_output->valid_shape();
-   PADDLE_ENFORCE(tensor->numel(), anakin_output_shape.count(),
-                  "the fluid output size should be equal to anakin");
-   ::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
-                                                      anakin_output_shape);
-   anakin_output->share_from(tmp_anakin_tensor);
+   tensor->Resize(framework::make_ddim(anakin_output_shape));
+   auto *fluid_data = tensor->mutable_data<float>(gpu_place);

+   memory::Copy(gpu_place, static_cast<void *>(fluid_data), gpu_place,
+                static_cast<void *>(anakin_data),
+                tensor->numel() * sizeof(float), stream);
  }
- net_->prediction();
  cudaDeviceSynchronize();
}
......
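
A hedged side note, not part of the patch: the new Execute() overload takes a cudaStream_t for the asynchronous output copy, so a caller should hand it a valid stream, roughly as in the sketch below; the helper name is illustrative and the engine and tensor maps are assumed to be prepared elsewhere:

#include <cuda_runtime.h>
#include "paddle/fluid/inference/anakin/engine.h"

using AnakinNvEngineT = paddle::inference::anakin::AnakinEngine<
    ::anakin::saber::NV, ::anakin::Precision::FP32>;

void ExecuteOnStream(
    AnakinNvEngineT *engine,
    const std::map<std::string, paddle::framework::LoDTensor *> &inputs,
    const std::map<std::string, paddle::framework::LoDTensor *> &outputs) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);                 // create a real stream before use
  engine->Execute(inputs, outputs, stream);  // Execute() synchronizes the device at the end
  cudaStreamDestroy(stream);
}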
...@@ -18,6 +18,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/engine.h"
...@@ -26,8 +27,12 @@
#include "framework/core/net/net.h"
#include "framework/core/types.h"
#include "framework/graph/graph.h"
#include "framework/graph/graph_global_mem.h"
#include "saber/saber_types.h"

using anakin::Precision;
using anakin::saber::NV;

namespace anakin {
template <typename, Precision, OpRunType>
...@@ -50,7 +55,7 @@ class AnakinEngine {
  using GraphT = ::anakin::graph::Graph<TargetT, PrecisionType>;

 public:
- explicit AnakinEngine(bool need_summary = false);
+ explicit AnakinEngine(bool need_summary = false, int device = 0);
  ~AnakinEngine();
  void InitGraph();
  void SetInputShape(const std::string &name, std::vector<int> shape);
...@@ -69,14 +74,50 @@ class AnakinEngine {
  void Freeze();
  void Optimize();
  void Save(std::string path) { graph_->save(path); }
  // void SaveSerializedData(std::string& data) { graph_->save_to_string(data);
  // }
  // void LoadSerializedData(const std::string& data) {
  //   graph_->load_from_string(data); }
  void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
-              const std::map<std::string, framework::LoDTensor *> &outputs);
+              const std::map<std::string, framework::LoDTensor *> &outputs,
+              cudaStream_t stream);

 private:
  int device_;
  std::unique_ptr<GraphT> graph_;
  std::unique_ptr<NetT> net_;
};
class AnakinEngineManager {
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
public:
bool HasEngine(const std::string &name) const {
if (engines_.count(name) == 0) return false;
return engines_.at(name).get() != nullptr;
}
AnakinNvEngineT *Get(const std::string &name) const {
return engines_.at(name).get();
}
AnakinNvEngineT *Create(bool need_summary, int device,
std::string engine_name) {
std::unique_lock<std::mutex> lk(mut_);
auto *p = new AnakinEngine<NV, Precision::FP32>(need_summary, device);
engines_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto &item : engines_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<AnakinNvEngineT>> engines_;
std::mutex mut_;
};
} // namespace anakin
} // namespace inference
} // namespace paddle
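
For orientation, a hedged sketch of how the new AnakinEngineManager is meant to tie the analysis-time pass and the run-time op together; the engine key is a placeholder and the Singleton utility header is assumed to be available:

#include "paddle/fluid/inference/anakin/engine.h"

void EngineManagerSketch() {
  auto &manager = paddle::inference::Singleton<
      paddle::inference::anakin::AnakinEngineManager>::Global();
  // Analysis time (anakin_subgraph_pass): create the engine once under a unique key.
  auto *engine = manager.Create(true /* need_summary */, 0 /* device id */, "engine_key_0");
  // ... the op converter then populates `engine` via ConvertBlockToAnakinEngine ...
  // Run time (AnakinEngineOp::GetEngine): look up the same engine by its key.
  if (manager.HasEngine("engine_key_0")) {
    auto *same_engine = manager.Get("engine_key_0");
    (void)same_engine;
  }
  (void)engine;
}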
...@@ -17,9 +17,6 @@ limitations under the License. */
#include <map>

-#include "framework/core/net/net.h"
-#include "framework/graph/graph.h"
-#include "framework/graph/graph_global_mem.h"
#include "paddle/fluid/inference/anakin/engine.h"

using anakin::graph::GraphGlobalMem;
...@@ -84,7 +81,9 @@ TEST_F(TestAnakinEngine, Execute) {
  auto *y_data = y.mutable_data<float>(platform::CUDAPlace());
  std::map<std::string, framework::LoDTensor *> outputs = {{"y", &y}};

- engine_->Execute(inputs, outputs);
+ cudaStream_t stream;
+ engine_->Execute(inputs, outputs, stream);

  auto *y_data_gpu = y_data;
  float y_data_cpu[2];
  cudaMemcpy(y_data_cpu, y_data_gpu, sizeof(float) * 2, cudaMemcpyDeviceToHost);
......
...@@ -23,6 +23,7 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
...@@ -55,6 +56,7 @@ struct Argument {
  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  using fusion_statis_t = std::unordered_map<std::string, int>;
  using engine_opt_info_t = std::map<std::string, std::string>;

  bool Has(const std::string& key) const { return valid_fields_.count(key); }
...@@ -107,12 +109,14 @@ struct Argument {
 private:             \
  unique_ptr_t field__##_;

  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
  DECL_ARGUMENT_FIELD(engine_opt_info, EngineOptInfo, engine_opt_info_t);

  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
...@@ -146,6 +150,8 @@ struct Argument {
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
                      bool);
  DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool);

  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
  DECL_ARGUMENT_FIELD(static_memory_optim, StaticMemoryOptim, bool);
......
...@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
...@@ -71,6 +72,11 @@ void IRPassManager::CreatePasses(Argument *argument,
    if (pass_name == "anakin_subgraph_pass") {
      pass->Set("program",
                new framework::ProgramDesc *(&argument->main_program()));
      pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
      pass->Set("model_from_memory", new bool(argument->model_from_memory()));
      pass->Set("engine_opt_info", new std::map<std::string, std::string>(
                                       argument->engine_opt_info()));
      pass->Set("predictor_id", new int(argument->predictor_id()));
    }

    if (pass_name == "tensorrt_subgraph_pass") {
...@@ -95,6 +101,9 @@ void IRPassManager::CreatePasses(Argument *argument,
      pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
      pass->Set("use_static_engine",
                new bool(argument->tensorrt_use_static_engine()));
      pass->Set("model_from_memory", new bool(argument->model_from_memory()));
      pass->Set("engine_opt_info", new std::map<std::string, std::string>(
                                       argument->engine_opt_info()));
    }

    pre_pass = pass_name;
......
...@@ -21,6 +21,7 @@
#include <vector>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/op_teller.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h"
...@@ -45,12 +46,20 @@ std::unique_ptr<framework::ir::Graph> analysis::AnakinSubgraphPass::ApplyImpl(
    return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
  };

- SubGraphFuser fuser(graph.get(), teller, 3 /* min_subgraph_size */);
+ SubGraphFuser fuser(graph.get(), teller, 0 /* min_subgraph_size */);
  fuser();
std::vector<std::string> graph_param_names =
ExtractAnakinParameters(graph->Nodes());
// those parameter already exist in anakin, and should not have another copy
// in
// fluid.
std::vector<std::string> repetitive_params;
  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && !Agent(node).subgraph()->empty()) {
-     CreateAnakinOp(node, graph.get());
+     CreateAnakinOp(node, graph.get(), graph_param_names, &repetitive_params);
      std::unordered_set<const Node *> nodes2remove(
          Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
      framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
...@@ -64,13 +73,15 @@ std::unique_ptr<framework::ir::Graph> analysis::AnakinSubgraphPass::ApplyImpl(
    }
  }
  framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
graph->Set(framework::ir::kRepetitiveParamAttr,
new std::vector<std::string>(repetitive_params));
  return graph;
}

-std::string GenerateAnakinEngineKey(
-    const std::set<std::string> &engine_inputs,
-    const std::set<std::string> &engine_outputs) {
+std::string GenerateAnakinEngineKey(const std::set<std::string> &engine_inputs,
+                                    const std::set<std::string> &engine_outputs,
+                                    std::string id) {
  std::string engine_hash_key = "";
  for (auto name : engine_inputs) {
    engine_hash_key += name;
...@@ -78,12 +89,15 @@ std::string GenerateAnakinEngineKey(
  for (auto name : engine_outputs) {
    engine_hash_key += name;
  }
  engine_hash_key += id;
  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
  return engine_key;
}

-void AnakinSubgraphPass::CreateAnakinOp(framework::ir::Node *node,
-                                        Graph *graph) const {
+void AnakinSubgraphPass::CreateAnakinOp(
+    framework::ir::Node *node, Graph *graph,
+    const std::vector<std::string> &graph_params,
+    std::vector<std::string> *repetitive_params) const {
  auto *op_desc = node->Op();
  auto &subgraph = *Agent(node).subgraph();
  PADDLE_ENFORCE(!subgraph.empty());
...@@ -117,10 +131,16 @@ void AnakinSubgraphPass::CreateAnakinOp(framework::ir::Node *node,
  // is unique.
  std::set<std::string> input_names;
  std::set<std::string> input_names_with_id;
  std::vector<std::string> params;
  for (auto *x : node->inputs) {
    input_names.insert(x->Name());
    input_names_with_id.insert(x->Name() + std::to_string(x->id()));
if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
params.push_back(x->Name());
}
  }
  std::copy(params.begin(), params.end(),
            std::back_inserter(*repetitive_params));
  op_desc->SetInput(
      "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
...@@ -231,10 +251,25 @@ void AnakinSubgraphPass::CreateAnakinOp(framework::ir::Node *node,
  SetAttr(op_desc->Proto(), "parameters",
          ExtractAnakinParameters(graph->Nodes()));
  SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
- auto engine_key =
-     GenerateAnakinEngineKey(input_names_with_id, output_names_with_id);
+ int predictor_id = Get<int>("predictor_id");
+ auto engine_key = GenerateAnakinEngineKey(
+     input_names_with_id, output_names_with_id, std::to_string(predictor_id));

  SetAttr(op_desc->Proto(), "engine_key", engine_key);
auto *anakin_engine =
inference::Singleton<anakin::AnakinEngineManager>::Global().Create(
true, Get<int>("gpu_device_id"), engine_key);
auto *scope = param_scope();
std::unordered_set<std::string> param_set(params.begin(), params.end());
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
inference::Singleton<inference::anakin::AnakinOpConverter>::Global()
.ConvertBlockToAnakinEngine(
&block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, anakin_engine);
}

std::vector<std::string> ExtractAnakinParameters(
...@@ -246,7 +281,7 @@ std::vector<std::string> ExtractAnakinParameters(
  for (const auto &node : nodes) {
    if (!node->IsOp()) continue;
    std::string op_type = node->Op()->Type();
-   if (op_type == "feed") {
+   if (op_type == "feed" || op_type == "fetch") {
      std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
      std::copy(output_names.begin(), output_names.end(),
                std::back_inserter(feed_outputs));
......
...@@ -15,8 +15,13 @@
#pragma once
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/anakin/engine.h"

using anakin::Precision;
using anakin::saber::NV;

namespace paddle {
namespace inference {
namespace analysis {
...@@ -27,8 +32,9 @@ class AnakinSubgraphPass : public framework::ir::FusePassBase {
      std::unique_ptr<framework::ir::Graph> graph) const override;

 private:
- void CreateAnakinOp(framework::ir::Node *x,
-                     framework::ir::Graph *graph) const;
+ void CreateAnakinOp(framework::ir::Node *x, framework::ir::Graph *graph,
+                     const std::vector<std::string> &graph_params,
+                     std::vector<std::string> *repetitive_params) const;
  void CleanIntermediateOutputs(framework::ir::Node *node);
};
......
...@@ -13,6 +13,7 @@
// limitations under the License.

#include <algorithm>
#include <map>
#include <set>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -219,7 +220,17 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
  SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
  SetAttr(op_desc->Proto(), "engine_key", engine_key);
- SetAttr(op_desc->Proto(), "engine_serialized_data", std::string(""));
+ bool load_from_memory = Get<bool>("model_from_memory");
std::string trt_engine_serialized_data = "";
if (load_from_memory) {
std::map<std::string, std::string> engine_opt_info =
Get<std::map<std::string, std::string>>("engine_opt_info");
if (engine_opt_info.count(engine_key)) {
trt_engine_serialized_data = engine_opt_info[engine_key];
}
}
SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data);
  std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
  if (enable_int8 && calibration_data.size() != 0) {
...@@ -230,10 +241,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
  // When in int8 mode and calibration_mode, the program just produce the
  // calibration table data.
  bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
- if (!calibration_mode && use_static_engine) {
+ if (!calibration_mode && use_static_engine &&
+     trt_engine_serialized_data.empty()) {
    std::copy(params.begin(), params.end(),
              std::back_inserter(*repetitive_params));
-   std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
+   trt_engine_serialized_data = GetTrtEngineSerializedData(
        Get<std::string>("model_opt_cache_dir"), engine_key);
    if (trt_engine_serialized_data.empty()) {
......
...@@ -64,8 +64,3 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
  anakin_target(inference_anakin_api)
  anakin_target(inference_anakin_api_shared)
endif()
-if (WITH_ANAKIN_SUBGRAPH)
-  inference_analysis_test(test_anakin_model SRCS mobilenet_test.cc EXTRA_DEPS paddle_fluid)
-  inference_analysis_test(anakin_conv_model SRCS conv_anakin_test.cc EXTRA_DEPS paddle_fluid)
-  inference_analysis_test(life_feature_test SRCS life_feature_test.cc EXTRA_DEPS paddle_fluid)
-endif()
...@@ -21,6 +21,7 @@
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
extern const std::vector<std::string> kAnakinSubgraphPasses;

PassStrategy *AnalysisConfig::pass_builder() const {
  if (!pass_builder_.get()) {
...@@ -230,6 +231,20 @@ void AnalysisConfig::Update() {
    }
  }
if (use_anakin_) {
PADDLE_ENFORCE(!use_tensorrt_,
"Anakin sub-graph and TensorRT sub-graph are not allowed to "
"run at the same time!");
PADDLE_ENFORCE(
use_gpu_,
"Anakin sub-graph engine need gpu, please use the EnableGpu API.");
pass_builder()->ClearPasses();
for (const auto &pass : kAnakinSubgraphPasses) {
pass_builder()->AppendPass(pass);
}
}
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
...@@ -266,7 +281,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
  ss << specify_input_name_;

  ss << cpu_math_library_num_threads_;
  ss << use_anakin_;

  return ss.str();
}
void AnalysisConfig::SetEngineOptInfo(
std::map<std::string, std::string> engine_opt_info) {
engine_opt_info_ = engine_opt_info;
}
NativeConfig AnalysisConfig::ToNativeConfig() const {
  NativeConfig config;
  config.model_dir = model_dir_;
...@@ -332,5 +352,8 @@ void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}
void AnalysisConfig::EnableAnakinEngine() {
use_anakin_ = true;
Update();
}
} // namespace paddle
...@@ -351,7 +351,10 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
  argument_.SetStaticMemoryOptimForceUpdate(
      config_.static_memory_optim_force_update_);
  argument_.SetModelFromMemory(config_.model_from_memory_);
  argument_.SetEngineOptInfo(config_.engine_opt_info_);
  // Analyze inference_program
argument_.SetUseAnakin(config_.anakin_engine_enabled());
argument_.SetPredictorID(predictor_id_);
  if (!config_.model_dir().empty()) {
    argument_.SetModelDir(config_.model_dir());
  } else {
...@@ -375,6 +378,10 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
    argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
  }
if (config_.use_gpu() && config_.anakin_engine_enabled()) {
LOG(INFO) << "Anakin subgraph engine is enabled";
}
  if (config_.use_mkldnn_) {
    LOG(INFO) << "MKLDNN is enabled";
    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
...@@ -404,7 +411,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
  VLOG(3) << "create AnalysisConfig";
  if (config.use_gpu()) {
    // 1. GPU memory
-   PADDLE_ENFORCE_GT(config.memory_pool_init_size_mb(), 0.f);
+   PADDLE_ENFORCE_GE(config.memory_pool_init_size_mb(), 0.f);
    PADDLE_ENFORCE_GE(config.gpu_device_id(), 0, "Invalid device id %d",
                      config.gpu_device_id());
    std::vector<std::string> flags;
......
...@@ -45,7 +45,9 @@ using framework::NaiveExecutor;
 */
class AnalysisPredictor : public PaddlePredictor {
 public:
- explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {}
+ explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {
+   predictor_id_ = inference::GetUniqueId();
+ }
  ~AnalysisPredictor();

  bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
...@@ -152,6 +154,7 @@ class AnalysisPredictor : public PaddlePredictor {
  const size_t max_shape_collect_count_{1000};
  int need_collect_var_shapes_{-1};  // -1 for default, 0 for false, 1 for true.
  std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_;
  int predictor_id_;

 private:
  // Some status here that help to determine the status inside the predictor.
......
...@@ -14,9 +14,11 @@
#pragma once

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

/*! \file */
...@@ -140,6 +142,14 @@ struct AnalysisConfig {
  /** A boolean state telling whether the TensorRT engine is used.
   */
  bool tensorrt_engine_enabled() const { return use_tensorrt_; }
/**
* \brief Turn on the usage of Anakin sub-graph engine.
*/
void EnableAnakinEngine();
/** A boolean state indicating whether the Anakin sub-graph engine is used.
*/
bool anakin_engine_enabled() const { return use_anakin_; }
  /** \brief Control whether to debug IR graph analysis phase.
   *
...@@ -185,6 +195,7 @@ struct AnalysisConfig {
  /** A boolean state telling whether the model is set from the CPU memory.
   */
  bool model_from_memory() const { return model_from_memory_; }
  void SetEngineOptInfo(std::map<std::string, std::string> engine_opt_info);

  /** Turn on memory optimize
   * NOTE still in development, will release latter.
...@@ -258,6 +269,8 @@ struct AnalysisConfig {
  std::string serialized_info_cache_;

  mutable std::unique_ptr<PassStrategy> pass_builder_;
  bool use_anakin_{false};
  std::map<std::string, std::string> engine_opt_info_;
};

} // namespace paddle
...@@ -68,6 +68,17 @@ void GpuPassStrategy::EnableMKLDNN() {
  LOG(ERROR) << "GPU not support MKLDNN yet";
}
// The following passes works for Anakin sub-graph engine.
const std::vector<std::string> kAnakinSubgraphPasses({
"infer_clean_graph_pass", //
"simplify_anakin_detection_pattern_pass3", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"anakin_subgraph_pass",
});
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
  passes_.assign({
      "infer_clean_graph_pass",  //
...@@ -120,4 +131,5 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
  });
  use_gpu_ = false;
}

void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
} // namespace paddle
...@@ -45,6 +45,7 @@ class PaddlePassBuilder {
  /** Delete all the passes that has type `pass_type`. */
  void DeletePass(const std::string &pass_type);

  void ClearPasses();

  /** Append an analysis pass. */
  void AppendAnalysisPass(const std::string &pass);
...@@ -142,4 +143,6 @@ class GpuPassStrategy : public PassStrategy {
  virtual ~GpuPassStrategy() = default;
};

extern const std::vector<std::string> kAnakinSubgraphPasses;

} // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA

#include <fstream>
#include <map>
#include <memory>
#include <string>
...@@ -52,8 +53,9 @@ class AnakinEngineOp : public framework::OperatorBase {
 private:
  std::vector<std::string> input_names_;
  std::unordered_set<std::string> param_names_;
- mutable std::unique_ptr<AnakinNvEngineT> anakin_engine_;
+ mutable AnakinNvEngineT *anakin_engine_;
  std::string engine_key_;
  std::string engine_serialized_data_;

 public:
  AnakinEngineOp(const std::string &type,
...@@ -67,6 +69,7 @@ class AnakinEngineOp : public framework::OperatorBase {
    for (const auto &param : params) {
      param_names_.insert(param);
    }
anakin_engine_ = nullptr;
  }

 protected:
...@@ -77,12 +80,12 @@ class AnakinEngineOp : public framework::OperatorBase {
  void RunAnakin(const framework::Scope &scope,
                 const platform::Place &dev_place) const {
-   if (anakin_engine_.get() == nullptr) {
-     anakin_engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
-     Prepare(scope, dev_place, anakin_engine_.get());
-   }
-   auto *engine = anakin_engine_.get();
+   auto *engine = GetEngine(scope, dev_place);
+   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+   auto &dev_ctx = *pool.Get(dev_place);
+   auto stream =
+       reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();

    PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");

    std::vector<std::string> output_maps =
...@@ -95,24 +98,48 @@ class AnakinEngineOp : public framework::OperatorBase {
      auto &t =
          inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
      auto t_shape = framework::vectorize(t.dims());
auto *anakin_input = engine->Net()->get_in(x);
auto net_shape = anakin_input->shape();
size_t anakin_net_input_size = net_shape.count() * sizeof(float);
size_t fluid_input_size = t.memory_size();
if (fluid_input_size < anakin_net_input_size) {
framework::LoDTensor temp_t;
auto t_dims = t.dims();
temp_t.Resize(t_dims);
TensorCopySync(t, dev_place, &temp_t);
t.Resize(framework::make_ddim(net_shape));
t.mutable_data<float>(dev_place);
TensorCopySync(temp_t, dev_place, &t);
}
      inputs.insert({x, &t});
    }

    std::map<std::string, framework::LoDTensor *> outputs;
    int output_index = 0;
    for (const auto &y : Outputs("Ys")) {
-     std::vector<int> ddim =
-         engine->Net()->get_out(output_maps[output_index])->valid_shape();
+     // std::vector<int> ddim =
+     //     engine->Net()->get_out(output_maps[output_index])->valid_shape();
      // we need get the output anakin output shape.
      auto *fluid_v = scope.FindVar(y);
      PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
      auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
-     fluid_t->Resize(framework::make_ddim(ddim));
-     fluid_t->mutable_data<float>(boost::get<platform::CUDAPlace>(dev_place));
+     // fluid_t->Resize(framework::make_ddim(ddim));
+     // fluid_t->mutable_data<float>(boost::get<platform::CUDAPlace>(dev_place));
      outputs.insert({output_maps[output_index], fluid_t});
      output_index += 1;
    }
-   engine->Execute(inputs, outputs);
+   engine->Execute(inputs, outputs, stream);
}
AnakinNvEngineT *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (anakin_engine_ == nullptr) {
anakin_engine_ =
inference::Singleton<inference::anakin::AnakinEngineManager>::Global()
.Get(engine_key_);
}
return anakin_engine_;
} }
void Prepare(const framework::Scope &scope, const platform::Place &dev_place, void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
...@@ -128,8 +155,6 @@ class AnakinEngineOp : public framework::OperatorBase {
    inference::Singleton<inference::anakin::AnakinOpConverter>::Global()
        .ConvertBlock(block_desc, param_names_, scope, engine);
    engine->Freeze();
-   engine->Optimize();
    for (const auto &x : Inputs("Xs")) {
      if (param_names_.count(x)) continue;
      auto &t =
...@@ -142,6 +167,9 @@ class AnakinEngineOp : public framework::OperatorBase {
      }
      engine->SetInputShape(x, t_shape);
    }
+   engine->Optimize();
    engine->InitGraph();
  }
};
......