提交 32a9e050 编写于 作者: N nhzlx

mapping the variable name inside the subgraph

上级 d50f776b
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
namespace paddle { namespace paddle {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false, DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true,
"Enable subgraph to TensorRT engine for acceleration"); "Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./", DEFINE_string(inference_analysis_graphviz_log_root, "./",
...@@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
// TODO(Superjomn) set the key with pass reprs. // TODO(Superjomn) set the key with pass reprs.
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass); AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
auto trt_teller = [](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu"});
if (!node->IsFunction()) return false; if (!node->IsFunction()) return false;
return static_cast<const Function*>(node)->func_type() == "mul";
const auto* func = static_cast<const Function*>(node);
if (teller_set.count(func->func_type()))
return true;
else {
return false;
}
}; };
AddPass("tensorrt-subgraph-marker", AddPass("tensorrt-subgraph-marker",
new TensorRTSubgraphNodeMarkPass(trt_teller)); new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace paddle { namespace paddle {
namespace inference { namespace inference {
DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size"); DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size");
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
namespace analysis { namespace analysis {
...@@ -87,27 +87,90 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { ...@@ -87,27 +87,90 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
} }
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
const framework::proto::BlockDesc &block) { framework::proto::BlockDesc &block) {
static int counter{0}; static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock()); PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc; framework::OpDesc desc;
auto *func = static_cast<FunctionBlock *>(node); auto *func = static_cast<FunctionBlock *>(node);
// collect inputs // collect inputs
std::vector<std::string> io; std::unordered_set<std::string> input_names;
for (auto *x : func->inlinks) { for (auto *x : func->inlinks) {
io.push_back(x->name()); input_names.insert(x->name());
} }
desc.SetInput("Xs", io); desc.SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
// collect outputs std::unordered_set<std::string> output_names;
io.clear();
for (auto *x : func->outlinks) { for (auto *x : func->outlinks) {
io.push_back(x->name()); output_names.insert(x->name());
} }
desc.SetOutput("Ys", io);
std::vector<std::string> output_temp(output_names.begin(),
output_names.end());
desc.SetOutput("Ys", output_temp);
desc.SetType("tensorrt_engine"); desc.SetType("tensorrt_engine");
std::unordered_map<std::string, std::string> output_name_map;
auto subgraph_nodes = func->subgraph;
for (int index = 0; index < block.ops_size(); index++) {
framework::proto::OpDesc *op = block.mutable_ops(index);
// auto &op = block.mutable_ops(index);
auto correspond_node = subgraph_nodes[index];
PADDLE_ENFORCE_EQ(correspond_node->name(), op->type());
std::unordered_map<std::string, size_t> var2id;
for (auto *in_var : correspond_node->inlinks) {
var2id[in_var->name()] = in_var->id();
}
// TODO(zhaolong): add comments
for (int i = 0; i < op->inputs_size(); i++) {
framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i);
// auto &in_var = op->mutable_inputs(i);
std::vector<std::string> replaced_names;
for (int k = 0; k < in_var->arguments_size(); k++) {
std::string arg_value = in_var->arguments(k);
if (input_names.count(arg_value)) {
replaced_names.push_back(arg_value);
} else {
replaced_names.push_back(arg_value +
std::to_string(var2id[arg_value]));
}
}
in_var->clear_arguments();
for (size_t k = 0; k < replaced_names.size(); k++) {
in_var->add_arguments(replaced_names[k]);
}
}
var2id.clear();
for (auto out_var : correspond_node->outlinks) {
var2id[out_var->name()] = out_var->id();
}
for (int i = 0; i < op->outputs_size(); i++) {
framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
std::vector<std::string> replaced_names;
for (int k = 0; k < out_var->arguments_size(); k++) {
std::string arg_value = out_var->arguments(k);
if (output_names.count(arg_value)) {
output_name_map[arg_value] =
arg_value + std::to_string(var2id[arg_value]);
}
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
}
out_var->clear_arguments();
for (size_t k = 0; k < replaced_names.size(); k++) {
out_var->add_arguments(replaced_names[k]);
}
}
}
std::vector<std::string> output_mapping;
for (auto name : output_names) {
PADDLE_ENFORCE(output_name_map.count(name) != 0);
output_mapping.push_back(output_name_map[name]);
}
PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc"); PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
// Set attrs // Set attrs
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString()); SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
...@@ -115,6 +178,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, ...@@ -115,6 +178,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize); SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size); SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
node->SetPbMsg(desc.Proto()->SerializeAsString()); node->SetPbMsg(desc.Proto()->SerializeAsString());
} }
...@@ -146,11 +210,13 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { ...@@ -146,11 +210,13 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
LOG(INFO) << "transformed variable size: " LOG(INFO) << "transformed variable size: "
<< block_desc.Proto()->vars().size(); << block_desc.Proto()->vars().size();
// copy ops. // copy ops.
for (auto *node : block_node->subgraph) { for (auto *node : block_node->subgraph) {
auto *op = block_desc.AppendOp(); auto *op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty()); PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg()); op->Proto()->ParseFromString(node->pb_msg());
} }
*block_desc.Proto()->mutable_vars() = *block_desc.Proto()->mutable_vars() =
argument_->origin_program_desc->blocks(0).vars(); argument_->origin_program_desc->blocks(0).vars();
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty()); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
......
...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { ...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) { for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) {
if (node.attr(kMarkerAttrName).Bool()) { if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
......
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
activation_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
...@@ -55,7 +55,6 @@ class OpConverter { ...@@ -55,7 +55,6 @@ class OpConverter {
it = Registry<OpConverter>::Lookup("fc"); it = Registry<OpConverter>::Lookup("fc");
} }
} }
if (op_desc.Type().find("elementwise") != std::string::npos) { if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set<std::string> add_tensor_op_set{ static std::unordered_set<std::string> add_tensor_op_set{
"add", "mul", "sub", "div", "max", "min", "pow"}; "add", "mul", "sub", "div", "max", "min", "pow"};
...@@ -72,6 +71,8 @@ class OpConverter { ...@@ -72,6 +71,8 @@ class OpConverter {
"Unsupported elementwise type" + op_type); "Unsupported elementwise type" + op_type);
it = it =
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight"); Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight");
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
} else { } else {
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0, PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type); "Unsupported elementwise type" + op_type);
......
...@@ -55,18 +55,8 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) { ...@@ -55,18 +55,8 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
"TensorRT' tensor input requires at least 2 dimensions"); "TensorRT' tensor input requires at least 2 dimensions");
PADDLE_ENFORCE_LE(shape.size(), 4UL, PADDLE_ENFORCE_LE(shape.size(), 4UL,
"TensorRT' tensor input requires at most 4 dimensions"); "TensorRT' tensor input requires at most 4 dimensions");
PADDLE_ENFORCE_EQ(shape.size(), 4UL);
switch (shape.size()) { return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
case 2:
return nvinfer1::Dims2(1, shape[1]);
case 3:
return nvinfer1::Dims3(1, shape[1], shape[2]);
case 4:
return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
default:
return nvinfer1::Dims();
}
return nvinfer1::Dims();
} }
} // namespace } // namespace
...@@ -86,6 +76,9 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare( ...@@ -86,6 +76,9 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
parameters.insert(param); parameters.insert(param);
} }
std::vector<std::string> output_maps =
context.Attr<std::vector<std::string>>("output_name_mapping");
// TODO(Superjomn) replace this with a different stream // TODO(Superjomn) replace this with a different stream
auto *engine = Singleton<TRT_EngineManager>::Global().Create( auto *engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch, max_workspace, nullptr /*engine hold its own stream*/, max_batch, max_workspace, nullptr /*engine hold its own stream*/,
...@@ -97,6 +90,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare( ...@@ -97,6 +90,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
// Add inputs // Add inputs
VLOG(4) << "declare inputs"; VLOG(4) << "declare inputs";
for (auto &input : context.Inputs("Xs")) { for (auto &input : context.Inputs("Xs")) {
if (parameters.count(input)) continue;
VLOG(4) << "declare input " << input; VLOG(4) << "declare input " << input;
auto *var = block.FindVar(input); auto *var = block.FindVar(input);
// TensorRT engine need to create parameters. The parameter's description // TensorRT engine need to create parameters. The parameter's description
...@@ -122,7 +116,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare( ...@@ -122,7 +116,7 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
block_desc, parameters, context.scope(), engine); block_desc, parameters, context.scope(), engine);
// Add outputs // Add outputs
for (auto &output : context.Outputs("Ys")) { for (auto &output : output_maps) {
engine->DeclareOutput(output); engine->DeclareOutput(output);
} }
......
...@@ -66,8 +66,17 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -66,8 +66,17 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
context.Attr<int>("max_batch")); context.Attr<int>("max_batch"));
std::vector<std::string> output_maps =
context.Attr<std::vector<std::string>>("output_name_mapping");
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto& param : params) {
parameters.insert(param);
}
// Convert input tensor from fluid to engine. // Convert input tensor from fluid to engine.
for (const auto& x : context.Inputs("Xs")) { for (const auto& x : context.Inputs("Xs")) {
if (parameters.count(x)) continue;
// convert input and copy to TRT engine's buffer // convert input and copy to TRT engine's buffer
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>( auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
context.scope(), x); context.scope(), x);
...@@ -82,10 +91,12 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -82,10 +91,12 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// Execute the engine. // Execute the engine.
PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0); PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0);
engine->Execute(FLAGS_tensorrt_engine_batch_size); engine->Execute(FLAGS_tensorrt_engine_batch_size);
// Convert output tensor from engine to fluid // Convert output tensor from engine to fluid
int output_index = 0;
for (const auto& y : context.Outputs("Ys")) { for (const auto& y : context.Outputs("Ys")) {
// convert output and copy to fluid. // convert output and copy to fluid.
nvinfer1::ITensor* trt_t = engine->GetITensor(y); nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]);
auto dims = trt_t->getDimensions(); auto dims = trt_t->getDimensions();
// Use the output ITensor's dims to reshape the Fluid Tensor. // Use the output ITensor's dims to reshape the Fluid Tensor.
std::vector<int> ddim(dims.d, dims.d + dims.nbDims); std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
...@@ -102,7 +113,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -102,7 +113,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// TODO(Superjomn) change this float to dtype size. // TODO(Superjomn) change this float to dtype size.
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
FLAGS_tensorrt_engine_batch_size; FLAGS_tensorrt_engine_batch_size;
engine->GetOutputInCPU(y, engine->GetOutputInCPU(output_maps[output_index],
fluid_t->mutable_data<float>(platform::CPUPlace()), fluid_t->mutable_data<float>(platform::CPUPlace()),
size * sizeof(float)); size * sizeof(float));
//} else { //} else {
...@@ -110,6 +121,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -110,6 +121,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()), // y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
// size * sizeof(float)); // size * sizeof(float));
//} //}
output_index += 1;
} }
cudaStreamSynchronize(*engine->stream()); cudaStreamSynchronize(*engine->stream());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册