Commit 050a68dd authored by nhzlx

fix comments

test=develop
Parent fcc93d96
@@ -63,7 +63,6 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
 void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
                                             Graph *graph) const {
   auto *op_desc = node->Op();
-  static int counter{0};
   auto &subgraph = *Agent(node).subgraph();
   PADDLE_ENFORCE(!subgraph.empty());
@@ -192,8 +191,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
               block_desc.Proto()->SerializeAsString());
   SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
   SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
-  SetAttr(op_desc->Proto(), "engine_uniq_key",
-          "trt-" + std::to_string(counter++));
   SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
   SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
 }
......
@@ -29,7 +29,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
   AddInput("Xs", "A list of inputs.").AsDuplicable();
   AddOutput("Ys", "A list of outputs").AsDuplicable();
   AddAttr<std::string>("subgraph", "the subgraph.");
-  AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
   AddAttr<int>("max_batch_size", "the maximum batch size.");
   AddAttr<int>("workspace_size", "the workspace size.");
   AddComment("TensorRT engine operator.");
......
@@ -65,7 +65,6 @@ using inference::tensorrt::TensorRTEngine;
 class TensorRTEngineOp : public framework::OperatorBase {
  private:
-  std::string engine_name_;
   std::vector<std::string> input_names_;
   std::unordered_set<std::string> param_names_;
   mutable std::unique_ptr<TensorRTEngine> trt_engine_;
@@ -78,7 +77,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
       : framework::OperatorBase(type, inputs, outputs, attrs) {
-    engine_name_ = Attr<std::string>("engine_uniq_key");
     input_names_ = Inputs("Xs");
     max_batch_size_ = Attr<int>("max_batch_size");
     workspace_size_ = Attr<int>("workspace_size");
......