未验证 提交 5082642b 编写于 作者: Y Yan Chunwei 提交者: GitHub

feature/analysis to support sub-graph for TRT engine (#11538)

上级 bc28cf61
...@@ -18,7 +18,7 @@ if(APPLE) ...@@ -18,7 +18,7 @@ if(APPLE)
endif(APPLE) endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api) set(inference_deps paddle_inference_api paddle_fluid_api paddle_inference_tensorrt_subgraph_engine)
function(inference_api_test TARGET_NAME) function(inference_api_test TARGET_NAME)
if (WITH_TESTING) if (WITH_TESTING)
...@@ -50,6 +50,14 @@ cc_test(test_paddle_inference_api ...@@ -50,6 +50,14 @@ cc_test(test_paddle_inference_api
inference_api_test(test_paddle_inference_api_impl inference_api_test(test_paddle_inference_api_impl
ARGS test_word2vec test_image_classification) ARGS test_word2vec test_image_classification)
if(WITH_GPU AND TENSORRT_FOUND)
cc_library(paddle_inference_tensorrt_subgraph_engine
SRCS paddle_inference_api_tensorrt_subgraph_engine.cc
DEPS paddle_inference_api analysis tensorrt_engine paddle_inference_api paddle_fluid_api)
inference_api_test(test_paddle_inference_api_tensorrt_subgraph_engine ARGS test_word2vec)
endif()
if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI
# Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's, # Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's,
# so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to # so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to
......
...@@ -73,12 +73,12 @@ struct PaddleTensor { ...@@ -73,12 +73,12 @@ struct PaddleTensor {
}; };
enum class PaddleEngineKind { enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility. kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference. kAnakin, // Use Anakin for inference.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
// TODO(Superjomn) support following engines latter. // TODO(Superjomn) support following engines latter.
// kTensorRT, // Use TensorRT for inference. // kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin. // kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
}; };
/* /*
...@@ -130,6 +130,11 @@ struct AnakinConfig : public PaddlePredictor::Config { ...@@ -130,6 +130,11 @@ struct AnakinConfig : public PaddlePredictor::Config {
int max_batch_size{-1}; int max_batch_size{-1};
}; };
struct TensorRTConfig : public NativeConfig {
// Determine whether a subgraph will be executed by TRT.
int min_subgraph_size{1};
};
// A factory to help create different predictors. // A factory to help create different predictors.
// //
// FOR EXTENSION DEVELOPER: // FOR EXTENSION DEVELOPER:
......
...@@ -89,6 +89,7 @@ bool NativePaddlePredictor::Init( ...@@ -89,6 +89,7 @@ bool NativePaddlePredictor::Init(
LOG(ERROR) << "fail to load inference model."; LOG(ERROR) << "fail to load inference model.";
return false; return false;
} }
ctx_ = executor_->Prepare(*inference_program_, 0); ctx_ = executor_->Prepare(*inference_program_, 0);
executor_->CreateVariables( executor_->CreateVariables(
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0); *inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
...@@ -119,6 +120,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -119,6 +120,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
return false; return false;
} }
for (size_t i = 0; i < feed_target_names_.size(); ++i) { for (size_t i = 0; i < feed_target_names_.size(); ++i) {
VLOG(4) << "setting " << i << "-th target";
feed_targets[feed_target_names_[i]] = &feeds[i]; feed_targets[feed_target_names_[i]] = &feeds[i];
} }
// get fetch variable // get fetch variable
...@@ -130,14 +132,16 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -130,14 +132,16 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
} }
// Run the inference program // Run the inference program
// if share variables, we need not create variables // if share variables, we need not create variables
VLOG(4) << "Run prepared context";
executor_->RunPreparedContext( executor_->RunPreparedContext(
ctx_.get(), ctx_.get(),
sub_scope_ != nullptr ? sub_scope_ : scope_.get(), sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
&feed_targets, &feed_targets,
&fetch_targets, &fetch_targets,
false /* don't create variable eatch time */); false /* don't create variable eatch time */);
VLOG(4) << "Finish prepared context";
if (!GetFetch(fetchs, output_data)) { if (!GetFetch(fetchs, output_data)) {
LOG(ERROR) << "fail to get fetchs"; LOG(ERROR) << "fail to get fetches";
return false; return false;
} }
VLOG(3) << "predict cost: " << timer.toc() << "ms"; VLOG(3) << "predict cost: " << timer.toc() << "ms";
......
...@@ -44,7 +44,7 @@ class NativePaddlePredictor : public PaddlePredictor { ...@@ -44,7 +44,7 @@ class NativePaddlePredictor : public PaddlePredictor {
~NativePaddlePredictor() override; ~NativePaddlePredictor() override;
private: protected:
bool SetFeed(const std::vector<PaddleTensor> &input_datas, bool SetFeed(const std::vector<PaddleTensor> &input_datas,
std::vector<framework::LoDTensor> *feeds); std::vector<framework::LoDTensor> *feeds);
bool GetFetch(const std::vector<framework::LoDTensor> &fetchs, bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/contrib/inference/paddle_inference_api.h"
#include "paddle/contrib/inference/paddle_inference_api_impl.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
class TensorRTSubgraphPredictor : public NativePaddlePredictor {
public:
explicit TensorRTSubgraphPredictor(const TensorRTConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
// Analyze inference_program
Argument argument;
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "transformed program:\n"
<< argument.transformed_program_desc->SerializeAsString();
VLOG(5) << "to prepare executor";
*inference_program_->Proto() = *argument.transformed_program_desc;
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
executor_->CreateVariables(
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true;
}
private:
TensorRTConfig config_;
};
template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
const TensorRTConfig& config) {
VLOG(3) << "create TensorRTSubgraphPredictor";
if (config.use_gpu) {
// 1. GPU memeroy
PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory,
0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictor> predictor(
new TensorRTSubgraphPredictor(config));
if (!dynamic_cast<TensorRTSubgraphPredictor*>(predictor.get())
->Init(nullptr)) {
return nullptr;
}
return std::move(predictor);
}
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle {
DEFINE_string(dirname, "", "Directory of the inference model.");
void Main(bool use_gpu) {
//# 1. Create PaddlePredictor with a config.
TensorRTConfig config;
config.model_dir = FLAGS_dirname + "word2vec.inference.model";
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
auto predictor =
CreatePaddlePredictor<TensorRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config);
for (int batch_id = 0; batch_id < 3; batch_id++) {
//# 2. Prepare input.
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}),
.data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64};
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> slots(4, tensor);
//# 3. Run
std::vector<PaddleTensor> outputs;
CHECK(predictor->Run(slots, &outputs));
//# 4. Get output.
ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
}
}
}
TEST(paddle_inference_api_tensorrt_subgraph_engine, main) { Main(true); }
} // namespace paddle
\ No newline at end of file
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
fluid_to_data_flow_graph_pass.cc fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc data_flow_graph_to_fluid_pass.cc
tensorrt_subgraph_pass.cc
dfg_graphviz_draw_pass.cc dfg_graphviz_draw_pass.cc
DEPS framework_proto) tensorrt_subgraph_pass.cc
tensorrt_subgraph_node_mark_pass.cc
analyzer.cc
helper.cc
DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
...@@ -28,5 +30,7 @@ inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_ ...@@ -28,5 +30,7 @@ inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs.");
class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs.
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
auto trt_teller = [](const Node* node) {
if (!node->IsFunction()) return false;
return static_cast<const Function*>(node)->func_type() == "mul";
};
AddPass("tensorrt-subgraph-marker",
new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
}
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
LOG(INFO) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
}
// Add the graphviz debuger pass if the parent pass has one.
void AddGraphvizDebugerPass(Pass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) {
LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
Register(debuger_pass->repr(), debuger_pass);
}
}
};
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll();
PADDLE_ENFORCE(x->Finalize());
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains Analyzer, an class that exposed as a library that analyze
* and optimize
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
* control whether
* an process is applied on the program.
*
* The processes are called Passes in analysis, the Passes are placed in a
* pipeline, the first
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
* a data flow
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
* graph to a
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
* which take a
* node or data flow graph as input.
*
* The Analyzer can be used in two methods, the first is a executable file which
* can be used to
* pre-process the inference model and can be controlled by passing difference
* command flags;
* the other way is to compose inside the inference API as a runtime pre-process
* phase in the
* inference service.
*/
#include <gflags/gflags.h>
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle {
namespace inference {
namespace analysis {
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root);
class Analyzer : public OrderedRegistry<PassManager> {
public:
// Register all the pass-managers.
Analyzer();
void Run(Argument* argument);
DISABLE_COPY_AND_ASSIGN(Analyzer);
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, main) {
Analyzer analyser;
analyser.Run(&argument);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -41,6 +41,9 @@ struct Argument { ...@@ -41,6 +41,9 @@ struct Argument {
// The original program desc. // The original program desc.
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc; std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
// The processed program desc.
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
// It is a better idea that the inputs and outputs of this graph is set manully // It is a better idea that the inputs and outputs of this graph is set manually
// before, but there must be a Pass that helps to prune the unnecessary ops that // before, but there must be a Pass that helps to prune the unnecessary ops that
// do not contribute to the given targets, so in this pass, analysis and get the // do not contribute to the given targets, so in this pass, analysis and get the
// inputs and outputs is OK. // inputs and outputs is OK.
...@@ -50,6 +50,25 @@ void DataFlowGraph::Build() { ...@@ -50,6 +50,25 @@ void DataFlowGraph::Build() {
outputs.push_back(out); outputs.push_back(out);
} }
} }
Clean();
}
void DataFlowGraph::Clean() {
for (auto &node : nodes.nodes()) {
std::unordered_set<Node *> inlinks_set(node->inlinks.begin(),
node->inlinks.end());
std::unordered_set<Node *> outlinks_set(node->outlinks.begin(),
node->outlinks.end());
if (inlinks_set.size() < node->inlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->inlinks.assign(inlinks_set.begin(), inlinks_set.end());
}
if (outlinks_set.size() < node->outlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->outlinks.assign(outlinks_set.begin(), outlinks_set.end());
}
}
} }
std::string DataFlowGraph::DotString() const { std::string DataFlowGraph::DotString() const {
......
...@@ -47,6 +47,10 @@ struct DataFlowGraph { ...@@ -47,6 +47,10 @@ struct DataFlowGraph {
// Output a DOT graph file for debug. // Output a DOT graph file for debug.
std::string DotString() const; std::string DotString() const;
private:
// Remove duplicate edges and so on.
void Clean();
}; };
/* /*
...@@ -133,17 +137,24 @@ struct GraphTraits<DataFlowGraph> { ...@@ -133,17 +137,24 @@ struct GraphTraits<DataFlowGraph> {
// Extract the inputs and outputs of a graph. The inputs and outputs of a // Extract the inputs and outputs of a graph. The inputs and outputs of a
// sub-graph is the inputs nodes and output nodes that doesn't inside the // sub-graph is the inputs nodes and output nodes that doesn't inside the
// sub-graph. // sub-graph.
std::pair< static std::pair<std::vector<Node *>, std::vector<Node *>>
std::vector<Node *>, ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) {
std::vector<
Node *>> static ExtractInputAndOutputOfSubGraph(std::vector<Node *>
&graph) {
std::unordered_set<Node *> nodes(graph.begin(), graph.end()); std::unordered_set<Node *> nodes(graph.begin(), graph.end());
std::unordered_set<Node *> inputs; std::unordered_set<Node *> inputs;
std::unordered_set<Node *> outputs; std::unordered_set<Node *> outputs;
// Input a Value, check whether its inlink is in the subgraph.
auto inlink_in_subgraph = [&](Node *n) {
for (auto *in : n->inlinks) {
if (nodes.count(in)) return true;
}
return false;
};
for (auto &node : graph) { for (auto &node : graph) {
for (auto *in : node->inlinks) { for (auto *in : node->inlinks) {
if (!nodes.count(in) && in->type() == Node::Type::kValue) { // The Value that is written by nodes inside a sub-graph shouldn't be the
// input of the sub-graph.
if (!nodes.count(in) && in->type() == Node::Type::kValue &&
!inlink_in_subgraph(in)) {
inputs.insert(in); inputs.insert(in);
} }
} }
......
...@@ -13,21 +13,34 @@ ...@@ -13,21 +13,34 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using framework::proto::ProgramDesc;
std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>>& nodes);
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) { bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument) ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
desc_ = argument->origin_program_desc.get(); PADDLE_ENFORCE(!argument->transformed_program_desc);
// Here some logic from program_desc.cc and will not add new interfaces into // The transformed_program_desc should inherit all the VarDesc and BlockDesc
// framework::ProgramDesc class, use some UT to assure the correctness. // from the original program desc. The operators of the main block(the first
auto* block = desc_->mutable_blocks()->Add(); // block) should rewritten by data flow graph.
block->set_idx(framework::kRootBlockIndex); argument->transformed_program_desc.reset(
block->set_parent_idx(framework::kNoneBlockIndex); new ProgramDesc(*argument->origin_program_desc));
argument->transformed_program_desc->mutable_blocks(framework::kRootBlockIndex)
->clear_ops();
desc_ = argument->transformed_program_desc.get();
argument_ = argument;
return true; return true;
} }
...@@ -37,14 +50,17 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { ...@@ -37,14 +50,17 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
auto traits = GraphTraits<DataFlowGraph>(graph); auto traits = GraphTraits<DataFlowGraph>(graph);
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) { for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
if (it->deleted()) continue; if (it->deleted()) continue;
switch (it->type()) { switch (it->type()) {
case Node::Type::kFunction: case Node::Type::kFunction: {
LOG(INFO) << "add function " << it->name(); LOG(INFO) << "add function " << it->repr();
AddFluidOp(&(*it)); AddFluidOp(&(*it));
break; } break;
case Node::Type::kFunctionBlock: case Node::Type::kFunctionBlock: {
LOG(INFO) << "add engine op " << it->repr() << " , "
<< static_cast<FunctionBlock*>(&(*it))->subgraph.size();
AddEngineOp(&(*it)); AddEngineOp(&(*it));
break; } break;
default: default:
continue; continue;
} }
...@@ -52,12 +68,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { ...@@ -52,12 +68,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
} }
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
LOG(INFO) << "processing func " << node->name();
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc()); auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
// currently only the main block is analyzed. // currently only the main block is analyzed.
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops(); auto* op = main_block->add_ops();
LOG(INFO) << "to copy the op";
*op = *ori_op; // copy the attributes, by default, these will not be changed *op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase. // by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt // The inputs and outputs of the existing ops are not changed by tensorrt
...@@ -65,11 +79,89 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { ...@@ -65,11 +79,89 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
// NOTE It might be changed by other passes in the long run. // NOTE It might be changed by other passes in the long run.
} }
void CreateTrtEngineOp(Node* node, const DataFlowGraph& graph,
const framework::proto::BlockDesc& block) {
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
auto* func = static_cast<FunctionBlock*>(node);
// collect inputs
std::vector<std::string> io;
for (auto* x : func->inlinks) {
io.push_back(x->name());
}
desc.SetInput("Xs", io);
// collect outputs
io.clear();
for (auto* x : func->outlinks) {
io.push_back(x->name());
}
desc.SetOutput("Ys", io);
desc.SetType("tensorrt_engine");
// Set attrs
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
SetAttr(desc.Proto(), "engine_unique_key",
"trt-" + std::to_string(counter++));
SetAttr(desc.Proto(), "max_batch", 100); // TODO(Superjomn) add config latter
SetAttr(desc.Proto(), "max_workspace",
1024); // TODO(Superjomn) add config latter
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
node->SetPbMsg(desc.Proto()->SerializeAsString());
}
std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>>& nodes) {
std::vector<std::string> parameters;
for (const auto& node : nodes) {
if (!node->IsValue()) continue;
PADDLE_ENFORCE(!node->pb_msg().empty(), "pb_msg should be set first");
framework::proto::VarDesc var;
var.ParseFromString(node->pb_msg());
if (var.persistable()) {
parameters.push_back(var.name());
}
}
return parameters;
}
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) { void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
// auto* op = main_block->add_ops();
// TODO(Superjomn) Here need to expose some arguments for default setting. // TODO(Superjomn) Here need to expose some arguments for default setting.
PADDLE_ENFORCE(node->IsFunctionBlock());
auto* block_node = static_cast<FunctionBlock*>(node);
framework::proto::BlockDesc proto;
framework::BlockDesc block_desc(nullptr, &proto);
// copy ops.
for (auto* node : block_node->subgraph) {
auto* op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg());
}
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
op->ParseFromString(node->pb_msg());
}
namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
DFG_DebuggerPass(const Config& config) : DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "dfg-to-fluid-debuger-pass"; }
bool Finalize() override { return true; }
};
}
Pass* DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger"));
} }
} // namespace analysis } // namespace analysis
......
...@@ -40,10 +40,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass { ...@@ -40,10 +40,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
return "Transform a DFG to a Fluid ProgramDesc"; return "Transform a DFG to a Fluid ProgramDesc";
} }
Pass *CreatePrinterPass(std::ostream &os, Pass *CreateGraphvizDebugerPass() const override;
const std::string &banner) const override {
return nullptr;
}
protected: protected:
// Add a Fluid Op into the ProgramDesc. // Add a Fluid Op into the ProgramDesc.
...@@ -53,6 +50,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass { ...@@ -53,6 +50,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
private: private:
framework::proto::ProgramDesc *desc_; framework::proto::ProgramDesc *desc_;
Argument *argument_;
}; };
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -18,12 +18,19 @@ namespace paddle { ...@@ -18,12 +18,19 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
int DFG_GraphvizDrawPass::counter_{0};
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph); auto content = Draw(graph);
std::ofstream file(GenDotPath()); auto dot_path = GenDotPath();
std::ofstream file(dot_path);
file.write(content.c_str(), content.size()); file.write(content.c_str(), content.size());
file.close(); file.close();
LOG(INFO) << "draw dot to " << GenDotPath();
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message;
LOG(INFO) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
} }
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
...@@ -41,9 +48,7 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { ...@@ -41,9 +48,7 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
if (!config_.display_deleted_node && node.deleted()) continue; if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) { for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue; if (!config_.display_deleted_node && in->deleted()) continue;
for (auto &in : node.inlinks) { dot.AddEdge(in->repr(), node.repr(), {});
dot.AddEdge(in->repr(), node.repr(), {});
}
} }
} }
return dot.Build(); return dot.Build();
......
...@@ -50,20 +50,25 @@ class DFG_GraphvizDrawPass : public DataFlowGraphPass { ...@@ -50,20 +50,25 @@ class DFG_GraphvizDrawPass : public DataFlowGraphPass {
bool Initialize(Argument *argument) override { return true; } bool Initialize(Argument *argument) override { return true; }
void Run(DataFlowGraph *graph) override; void Run(DataFlowGraph *graph) override;
bool Finalize() override { return Pass::Finalize(); } bool Finalize() override { return true; }
std::string repr() const override { return "DFG graphviz drawer"; } std::string repr() const override { return "DFG graphviz drawer"; }
std::string description() const override { std::string description() const override {
return "Debug a DFG by draw with graphviz"; return "Debug a DFG by draw with graphviz";
} }
private: protected:
// A counter to add a number prefix to the debugger image output so that they
// will sort in the triggered order.
static int counter_;
// Path of the dot file to output. // Path of the dot file to output.
std::string GenDotPath() const { std::string GenDotPath() const {
return config_.dir + "/" + "graph_" + config_.id + ".dot"; return config_.dir + "/" + std::to_string(counter_++) + "-graph_" +
config_.id + ".dot";
} }
std::string Draw(DataFlowGraph *graph); virtual std::string Draw(DataFlowGraph *graph);
Config config_; Config config_;
}; };
......
...@@ -31,7 +31,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { ...@@ -31,7 +31,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
pass.Run(&dfg); pass.Run(&dfg);
// test content // test content
std::ifstream file("./graph_test.dot"); std::ifstream file("./0-graph_test.dot");
ASSERT_TRUE(file.is_open()); ASSERT_TRUE(file.is_open());
std::string line; std::string line;
...@@ -40,7 +40,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { ...@@ -40,7 +40,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
no++; no++;
} }
// DFG is sensitive to ProgramDesc, be careful to change the existing models. // DFG is sensitive to ProgramDesc, be careful to change the existing models.
ASSERT_EQ(no, 112); ASSERT_EQ(no, 82);
} }
} // namespace analysis } // namespace analysis
......
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle { namespace paddle {
...@@ -33,7 +35,7 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { ...@@ -33,7 +35,7 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
return true; return true;
} }
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); } bool FluidToDataFlowGraphPass::Finalize() { return true; }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
...@@ -46,6 +48,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -46,6 +48,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
auto *v = graph->nodes.Create(Node::Type::kValue); auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name()); v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var))); v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id(); var2id[var.name()] = v->id();
} }
for (int i = 0; i < main_block.ops_size(); i++) { for (int i = 0; i < main_block.ops_size(); i++) {
...@@ -56,6 +59,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -56,6 +59,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
// Link to the original protobuf message's memory, make it easier to // Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc. // generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op))); o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs // set inputs and outputs
// TODO(Superjomn) make sure the InputNames is the real variable name. // TODO(Superjomn) make sure the InputNames is the real variable name.
for (int j = 0; j < op.inputs_size(); j++) { for (int j = 0; j < op.inputs_size(); j++) {
...@@ -79,9 +84,19 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -79,9 +84,19 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
graph->Build(); graph->Build();
} }
Pass *FluidToDataFlowGraphPass::CreatePrinterPass( namespace {
std::ostream &os, const std::string &banner) const { class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
return nullptr; public:
using Config = DFG_GraphvizDrawPass::Config;
DFG_DebuggerPass(const Config &config) : DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "fluid-to-dfg-debuger-pass"; }
bool Finalize() override { return true; }
};
}
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, "fluid-to-dfg-debuger"));
} }
} // namespace analysis } // namespace analysis
......
...@@ -46,8 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass { ...@@ -46,8 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
return "transform a fluid ProgramDesc to a data flow graph."; return "transform a fluid ProgramDesc to a data flow graph.";
} }
Pass *CreatePrinterPass(std::ostream &os, Pass *CreateGraphvizDebugerPass() const override;
const std::string &banner) const override;
private: private:
framework::proto::ProgramDesc const *desc_; framework::proto::ProgramDesc const *desc_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace inference {
namespace analysis {
template <>
void SetAttr<std::string>(framework::proto::OpDesc *op, const std::string &name,
const std::string &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(data);
}
template <>
void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
const int &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc *op, const std::string &name,
const int64_t &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc *op,
const std::string &name,
const std::vector<std::string> &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto &s : data) {
attr->add_strings(s.c_str());
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -14,10 +14,12 @@ limitations under the License. */ ...@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include <cstdio>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -26,6 +28,10 @@ namespace paddle { ...@@ -26,6 +28,10 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
template <typename T>
void SetAttr(framework::proto::OpDesc *op, const std::string &name,
const T &data);
template <typename Vec> template <typename Vec>
int AccuDims(Vec &&vec, int size) { int AccuDims(Vec &&vec, int size) {
int res = 1; int res = 1;
...@@ -93,7 +99,7 @@ template <typename T> ...@@ -93,7 +99,7 @@ template <typename T>
class OrderedRegistry { class OrderedRegistry {
public: public:
T *Register(const std::string &name, T *x) { T *Register(const std::string &name, T *x) {
PADDLE_ENFORCE(!dic_.count(name)); PADDLE_ENFORCE(!dic_.count(name), "duplicate key [%s]", name);
dic_[name] = data_.size(); dic_[name] = data_.size();
data_.emplace_back(std::unique_ptr<T>(x)); data_.emplace_back(std::unique_ptr<T>(x));
return data_.back().get(); return data_.back().get();
...@@ -117,6 +123,20 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) { ...@@ -117,6 +123,20 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>(); return *var->GetMutable<T>();
} }
static void ExecShellCommand(const std::string &cmd, std::string *message) {
char buffer[128];
std::shared_ptr<FILE> pipe(popen(cmd.c_str(), "r"), pclose);
if (!pipe) {
LOG(ERROR) << "error running command: " << cmd;
return;
}
while (!feof(pipe.get())) {
if (fgets(buffer, 128, pipe.get()) != nullptr) {
*message += buffer;
}
}
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
...@@ -20,6 +20,17 @@ namespace paddle { ...@@ -20,6 +20,17 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
template <>
std::string &NodeAttr::As<std::string>() {
if (data_.empty()) {
type_hash_ = typeid(std::string).hash_code();
}
PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code());
return data_;
}
std::string &NodeAttr::String() { return As<std::string>(); }
std::vector<Dot::Attr> Value::dot_attrs() const { std::vector<Dot::Attr> Value::dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"), return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
Dot::Attr("shape", "box"), Dot::Attr("shape", "box"),
......
...@@ -35,6 +35,44 @@ namespace analysis { ...@@ -35,6 +35,44 @@ namespace analysis {
class NodeMap; class NodeMap;
// A helper class to maintain the status from Pass.
struct NodeAttr {
// NOTE T should be a primary type or a struct combined by several primary
// types.
// NOTE the STL containers should not use here.
// Some usages
// Attr attr;
// attr.Bool() = true;
bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
std::string &String();
private:
template <typename T>
T &As() {
// init storage in the first usage.
if (data_.empty()) {
VLOG(4) << "resize data to " << sizeof(T);
type_hash_ = typeid(T).hash_code();
data_.resize(sizeof(T));
}
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_hash_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
}
private:
std::string data_;
size_t type_hash_{std::numeric_limits<size_t>::max()};
};
/* /*
* Node Representation. * Node Representation.
* *
...@@ -50,8 +88,6 @@ class Node { ...@@ -50,8 +88,6 @@ class Node {
Node() = default; Node() = default;
struct Attr;
// Cast to a subclass type, Function for example. // Cast to a subclass type, Function for example.
template <typename Subclass> template <typename Subclass>
Subclass &As() { Subclass &As() {
...@@ -71,7 +107,7 @@ class Node { ...@@ -71,7 +107,7 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will // Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists. // silently create a new attribute if not exists.
Attr &attr(const std::string &name) const { return attrs_[name]; } NodeAttr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; } int id() const { return id_; }
...@@ -80,6 +116,9 @@ class Node { ...@@ -80,6 +116,9 @@ class Node {
void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; } void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; }
void *pb_desc() const { return attr("pb_desc").Pointer(); } void *pb_desc() const { return attr("pb_desc").Pointer(); }
void SetPbMsg(const std::string &s) { attr("pb_msg").String() = s; }
const std::string &pb_msg() const { return attr("pb_msg").String(); }
void SetDeleted() { deleted_ = true; } void SetDeleted() { deleted_ = true; }
bool deleted() const { return deleted_; } bool deleted() const { return deleted_; }
...@@ -94,43 +133,6 @@ class Node { ...@@ -94,43 +133,6 @@ class Node {
// Output links. // Output links.
std::vector<Node *> outlinks; std::vector<Node *> outlinks;
// A helper class to maintain the status from Pass.
struct Attr {
// NOTE T should be a primary type or a struct combined by several primary
// types.
// NOTE the STL containers should not use here.
// Some usages
// Attr attr;
// attr.Bool() = true;
bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
private:
template <typename T>
T &As() {
// init storage in the first usage.
if (data_.empty()) {
VLOG(4) << "resize data to " << sizeof(T);
type_hash_ = typeid(T).hash_code();
data_.resize(sizeof(T));
}
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_hash_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
}
private:
std::string data_;
size_t type_hash_{std::numeric_limits<size_t>::max()};
};
// Type checks. // Type checks.
bool IsFunction() const { return type_ == Node::Type::kFunction; } bool IsFunction() const { return type_ == Node::Type::kFunction; }
bool IsValue() const { return type_ == Node::Type::kValue; } bool IsValue() const { return type_ == Node::Type::kValue; }
...@@ -150,7 +152,7 @@ class Node { ...@@ -150,7 +152,7 @@ class Node {
Type type_{Type::kNone}; Type type_{Type::kNone};
// Mark this node is deleted by some pass. // Mark this node is deleted by some pass.
bool deleted_{false}; bool deleted_{false};
mutable std::unordered_map<std::string, Attr> attrs_; mutable std::unordered_map<std::string, NodeAttr> attrs_;
}; };
class Function; class Function;
...@@ -213,6 +215,10 @@ class Function : public Node { ...@@ -213,6 +215,10 @@ class Function : public Node {
struct FunctionBlock : public Node { struct FunctionBlock : public Node {
std::string repr() const override { return "block-" + std::to_string(id()); } std::string repr() const override { return "block-" + std::to_string(id()); }
std::vector<Node *> subgraph; std::vector<Node *> subgraph;
protected:
FunctionBlock() { SetType(Node::Type::kFunctionBlock); }
friend class NodeMap;
}; };
class NodeMap { class NodeMap {
...@@ -227,7 +233,7 @@ class NodeMap { ...@@ -227,7 +233,7 @@ class NodeMap {
void Delete(size_t id); void Delete(size_t id);
const std::vector<std::unique_ptr<Node>> &nodes() { return nodes_; } const std::vector<std::unique_ptr<Node>> &nodes() const { return nodes_; }
size_t size() const { return nodes_.size(); } size_t size() const { return nodes_.size(); }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file contains all the flags that declared in Node::Attr.
*
* The Node::Attr is designed to share information between different passes, one
* can get other's attributes in a Node by the flags in this file.
*/
#pragma once
namespace paddle {
namespace inference {
namespace analysis {
#define DECLARE_NODE_ATTR(flag__) const char ATTR_##flag__[] = #flag__;
DECLARE_NODE_ATTR(supported_by_tensorrt) // bool
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -60,6 +60,9 @@ class Pass { ...@@ -60,6 +60,9 @@ class Pass {
return nullptr; return nullptr;
} }
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; }
// Run on a single Node. // Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function. // Run on a single Function.
......
...@@ -19,6 +19,18 @@ namespace paddle { ...@@ -19,6 +19,18 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
bool PassManager::Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
for (auto& pass : data_) { for (auto& pass : data_) {
......
...@@ -50,17 +50,7 @@ class PassManager : public OrderedRegistry<Pass> { ...@@ -50,17 +50,7 @@ class PassManager : public OrderedRegistry<Pass> {
// globally shared, so pass them as the arguemnts for all the pass managers. // globally shared, so pass them as the arguemnts for all the pass managers.
virtual bool Initialize(const Argument& argument) { return false; } virtual bool Initialize(const Argument& argument) { return false; }
virtual bool Initialize(Argument* argument) { virtual bool Initialize(Argument* argument);
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Call all the passes' Finalize methods. // Call all the passes' Finalize methods.
virtual bool Finalize() { virtual bool Finalize() {
......
...@@ -64,6 +64,7 @@ TEST_F(DFG_Tester, DFG_pass_manager) { ...@@ -64,6 +64,7 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
manager.Register("graphviz", new DFG_GraphvizDrawPass(config)); manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass); manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
ASSERT_TRUE(&argument);
ASSERT_TRUE(manager.Initialize(&argument)); ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll(); manager.RunAll();
} }
......
...@@ -119,10 +119,12 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); } ...@@ -119,10 +119,12 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
void SubGraphFuse::ReplaceNodesWithSubGraphs() { void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)(); auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) { for (auto &subgraph : subgraphs) {
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block // replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
// as deleted. 3. Replace the deleted node with the new Block Node. // as deleted. 3. Replace the deleted node with the new Block Node.
auto *block_node = graph_->nodes.Create(Node::Type::kFunctionBlock); auto *block_node = static_cast<FunctionBlock *>(
graph_->nodes.Create(Node::Type::kFunctionBlock));
auto io = ExtractInputAndOutputOfSubGraph(subgraph); auto io = ExtractInputAndOutputOfSubGraph(subgraph);
block_node->inlinks = std::move(io.first); block_node->inlinks = std::move(io.first);
block_node->outlinks = std::move(io.second); block_node->outlinks = std::move(io.second);
...@@ -130,21 +132,25 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { ...@@ -130,21 +132,25 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
// TODO(Superjomn) need a unified mechanism to treat deleted node in each // TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass. // pass.
node->SetDeleted(); node->SetDeleted();
block_node->subgraph.push_back(node);
} }
std::unordered_map<Node *, Node *> // Change all the sub-graph's inputs and outputs corresponding inlink and
delelte_node_map; // deleted node to BlockNode // outlink to this sub-graph node.
for (auto *n : block_node->inlinks) { auto inlink_or_outlink_cleaner = [&](std::vector<Node *> &nodes) {
n->inlinks.clear(); for (auto *&n : nodes) {
} if (subgraph_uniq.count(n)) {
for (auto *n : block_node->outlinks) { n = block_node;
n->outlinks.clear(); }
} }
for (auto *n : block_node->inlinks) { std::unordered_set<Node *> uniq(nodes.begin(), nodes.end());
n->outlinks.push_back(block_node); nodes.assign(uniq.begin(), uniq.end());
};
for (auto *i : block_node->inlinks) {
inlink_or_outlink_cleaner(i->outlinks);
} }
for (auto *n : block_node->outlinks) { for (auto *&o : block_node->outlinks) {
n->inlinks.push_back(n); inlink_or_outlink_cleaner(o->inlinks);
} }
} }
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/node_attr_flags.h"
namespace paddle {
namespace inference {
namespace analysis {
void TensorRTSubgraphNodeMarkPass::Run(DataFlowGraph *graph) {
for (auto &node : graph->nodes.nodes()) {
node->attr(ATTR_supported_by_tensorrt).Bool() = teller_(node.get());
}
}
class DfgDebuggerPass : public DFG_GraphvizDrawPass {
public:
DfgDebuggerPass(const DFG_GraphvizDrawPass::Config &config)
: DFG_GraphvizDrawPass(config) {}
std::string repr() const override {
return "tensorrt-subgraph-node-mark-debugger";
}
bool Finalize() override { return true; }
protected:
std::string Draw(DataFlowGraph *graph) override {
Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
auto dot_attr = node.dot_attrs();
if (node.attr(ATTR_supported_by_tensorrt).Bool()) {
dot_attr.assign(
{Dot::Attr{"color", "green"}, Dot::Attr{"style", "filled"}});
}
dot.AddNode(node.repr(), dot_attr);
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
dot.AddEdge(in->repr(), node.repr(), {});
}
}
return dot.Build();
}
};
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config(
FLAGS_inference_analysis_graphviz_log_root, "tensorrt_marked_node");
return new DfgDebuggerPass(config);
}
bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; }
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines TensorRTSubgraphNodeMarkPass which helps to mark the ops
* that supported by TensorRT engine.
*/
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Mark the operators that TensorRT engine supports.
*/
class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
public:
using teller_t = SubGraphSplitter::NodeInsideSubgraphTeller;
TensorRTSubgraphNodeMarkPass(const teller_t& teller) : teller_(teller) {}
bool Initialize(Argument* argument) override { return true; }
// This class get a sub-graph as input and determine whether to transform this
// sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override;
std::string repr() const { return "tensorrt-sub-subgraph-mark"; }
std::string description() const { return "tensorrt sub-graph mark pass"; }
Pass* CreateGraphvizDebugerPass() const override;
bool Finalize() override;
private:
teller_t teller_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/node_attr_flags.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
// init
FluidToDataFlowGraphPass pass;
ASSERT_TRUE(pass.Initialize(&argument));
argument.main_dfg.reset(new DataFlowGraph);
pass.Run(argument.main_dfg.get());
TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) {
return node->IsFunction() &&
static_cast<const Function*>(node)->func_type() == "mul";
};
TensorRTSubgraphNodeMarkPass pass1(teller);
ASSERT_TRUE(pass1.Initialize(&argument));
pass1.Run(argument.main_dfg.get());
int counter{0};
for (auto& node : argument.main_dfg->nodes.nodes()) {
counter += node->attr(ATTR_supported_by_tensorrt).Bool();
}
LOG(INFO) << counter << " nodes marked";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass( ...@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
: node_inside_subgraph_teller_(teller) {} : node_inside_subgraph_teller_(teller) {}
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_); SubGraphFuse(graph, node_inside_subgraph_teller_)();
} }
} // namespace analysis } // namespace analysis
......
...@@ -38,6 +38,11 @@ class TensorRTSubGraphPass : public DataFlowGraphPass { ...@@ -38,6 +38,11 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
// sub-graph into TensorRT. // sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override; void Run(DataFlowGraph* graph) override;
bool Finalize() override { return true; }
std::string repr() const { return "tensorrt-sub-graph"; }
std::string description() const { return "tensorrt sub graph pass"; }
private: private:
NodeInsideSubgraphTeller node_inside_subgraph_teller_; NodeInsideSubgraphTeller node_inside_subgraph_teller_;
}; };
......
...@@ -23,49 +23,48 @@ namespace paddle { ...@@ -23,49 +23,48 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
DEFINE_string(model_dir, "", "inference test model dir"); DEFINE_string(dot_dir, "./", "");
TEST(TensorRTSubGraph, single_pass) { TEST_F(DFG_Tester, tensorrt_single_pass) {
auto desc = LoadProgramDesc(); std::unordered_set<std::string> teller_set(
auto dfg = ProgramDescToDFG(desc); {"elementwise_add", "mul", "sigmoid"});
SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) {
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false; if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node); const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || if (teller_set.count(func->func_type())) return true;
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
return false; return false;
}; };
DFG_GraphvizDrawPass::Config config{"./", "test"}; LOG(INFO) << "init";
DFG_GraphvizDrawPass dfg_pass(config); DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
dfg_pass.Initialize(); DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
DFG_GraphvizDrawPass dfg_pass1(config);
dfg_pass1.Initialize();
dfg_pass.Run(&dfg);
DFG_GraphvizDrawPass dfg_pass(config);
DFG_GraphvizDrawPass dfg_pass1(config1);
FluidToDataFlowGraphPass pass0;
TensorRTSubGraphPass trt_pass(std::move(teller)); TensorRTSubGraphPass trt_pass(std::move(teller));
trt_pass.Initialize();
trt_pass.Run(&dfg); LOG(INFO) << "Initialize";
dfg_pass.Initialize(&argument);
dfg_pass1.Initialize(&argument);
pass0.Initialize(&argument);
trt_pass.Initialize(&argument);
dfg_pass1.Run(&dfg); LOG(INFO) << "Run";
argument.main_dfg.reset(new DataFlowGraph);
pass0.Run(argument.main_dfg.get());
dfg_pass.Run(argument.main_dfg.get());
trt_pass.Run(argument.main_dfg.get());
dfg_pass1.Run(argument.main_dfg.get());
// Check the TRT op's block desc // Check the TRT op's block desc
for (auto node : dfg.nodes.nodes()) { for (auto& node : argument.main_dfg->nodes.nodes()) {
if (node->IsFunctionBlock()) { if (node->IsFunctionBlock()) {
LOG(INFO) << "get function block";
} }
} }
} }
TEST(TensorRTSubGraph, pass_manager) {}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -226,7 +226,8 @@ op_library(sequence_softmax_op DEPS softmax) ...@@ -226,7 +226,8 @@ op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND) if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine) op_library(tensorrt_engine_op DEPS tensorrt_engine)
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter) DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter
analysis)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
......
...@@ -53,6 +53,7 @@ template <typename DeviceContext, typename T> ...@@ -53,6 +53,7 @@ template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> { class TensorRTEngineKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
VLOG(4) << "TensorRTEngineKernel executing";
auto engine_name = context.Attr<std::string>("engine_uniq_key"); auto engine_name = context.Attr<std::string>("engine_uniq_key");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) { if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context); Prepare(context);
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
...@@ -51,48 +52,10 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block, ...@@ -51,48 +52,10 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
*var = *desc.Proto(); *var = *desc.Proto();
} }
template <typename T>
void SetAttr(framework::proto::OpDesc* op, const std::string& name,
const T& data);
template <>
void SetAttr<std::string>(framework::proto::OpDesc* op, const std::string& name,
const std::string& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(data);
}
template <>
void SetAttr<int>(framework::proto::OpDesc* op, const std::string& name,
const int& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc* op, const std::string& name,
const int64_t& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc* op,
const std::string& name,
const std::vector<std::string>& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto& s : data) {
attr->add_strings(s.c_str());
}
}
} // namespace } // namespace
using inference::analysis::SetAttr;
TEST(TensorRTEngineOp, manual) { TEST(TensorRTEngineOp, manual) {
framework::ProgramDesc program; framework::ProgramDesc program;
auto* block_ = program.Proto()->add_blocks(); auto* block_ = program.Proto()->add_blocks();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册