Unverified commit 76c95af0, authored by Zhaolong Xing, committed via GitHub

Fix bug: Mask RCNN inference diff when using AnalysisPredictor. (#19213)

* Fix Mask RCNN inference bugs:
  1. affine channel fuse pass (output diff)
  2. conditional_block op (memory leak)
  3. merge_lod_tensor op (output diff)
  4. memory optim (output diff)
  test=develop

* Fix CI failures around PADDLE_ENFORCE and
  fix the merge_lod_tensor infer op unit test.
  test=develop
Parent: 5fc8de44
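For context, the user-facing path these fixes touch is the C++ AnalysisPredictor. A minimal sketch, assuming a local Mask RCNN model directory (the path and GPU memory pool size are placeholders, not part of this change):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mask_rcnn_model");  // placeholder model directory
  config.EnableUseGpu(100 /* initial GPU memory pool in MB */, 0);
  config.SwitchIrOptim(true);   // run the analysis/IR passes changed below
  config.EnableMemoryOptim();   // exercise memory_optimize_pass
  auto predictor = paddle::CreatePaddlePredictor(config);
  // Feed PaddleTensor inputs and call predictor->Run(inputs, &outputs) as usual.
  return 0;
}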
...@@ -52,7 +52,6 @@ pass_library(graph_viz_pass base) ...@@ -52,7 +52,6 @@ pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base) pass_library(lock_free_optimize_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
......
@@ -31,6 +31,9 @@ void Analyzer::RunAnalysis(Argument *argument) {
                  "analsis_passes is not valid in the argument.");
   for (auto &pass : argument->analysis_passes()) {
     string::PrettyLogH1("--- Running analysis [%s]", pass);
+    if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
+      continue;
+
     auto *ptr = PassRegistry::Global().Retreive(pass);
     PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass);
     ptr->Run(argument);
...
@@ -30,7 +30,7 @@ using namespace framework;  // NOLINT
 TEST(Analyzer, analysis_without_tensorrt) {
   Argument argument;
   argument.SetModelDir(FLAGS_inference_model_dir);
-  argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
+  argument.SetEnableAnalysisOptim(false);
   argument.SetUseGPU(false);
   argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
                               "ir_params_sync_among_devices_pass"});
@@ -41,10 +41,10 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   Argument argument;
+  argument.SetEnableAnalysisOptim(false);
   argument.SetTensorRtMaxBatchSize(3);
   argument.SetTensorRtWorkspaceSize(1 << 20);
   argument.SetModelDir(FLAGS_inference_model_dir);
-  argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
   argument.SetUseGPU(false);
   argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
                               "ir_params_sync_among_devices_pass"});
...
@@ -62,6 +62,9 @@ struct Argument {
   using anakin_max_shape_t = std::map<std::string, std::vector<int>>;

   bool Has(const std::string& key) const { return valid_fields_.count(key); }
+  // If we set the model using config.SetModelBuffer,
+  // the model and parameter will occupy additional CPU resources.
+  // Use this interface to release these resources.
   void PartiallyRelease() {
     if (Has("model_program_path")) {
       if (Has("model_from_memory") && model_from_memory()) {
@@ -130,6 +133,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
   DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
   DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
+  DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);

   // The overall graph to work on.
   DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
...
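The comment added above concerns the model-from-memory path. A rough sketch of that configuration, assuming the caller has already read the serialized program and parameters into buffers (ReadFileToString is a hypothetical helper, not a Paddle API):

std::string prog_buf = ReadFileToString("__model__");     // hypothetical helper
std::string params_buf = ReadFileToString("__params__");  // hypothetical helper

paddle::AnalysisConfig config;
// Load the model from memory instead of from a directory; the buffers are
// copied into the Argument and can later be freed via Argument::PartiallyRelease.
config.SetModelBuffer(prog_buf.data(), prog_buf.size(),
                      params_buf.data(), params_buf.size());
auto predictor = paddle::CreatePaddlePredictor(config);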
@@ -5,6 +5,7 @@ cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_p
 cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
 cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass)
 cc_library(inference_op_replace_pass SRCS inference_op_replace_pass.cc DEPS analysis_pass graph_to_program_pass)
+cc_library(ir_graph_clean_pass SRCS ir_graph_clean_pass.cc DEPS analysis_pass)

 cc_library(analysis_passes SRCS passes.cc DEPS
             ir_graph_build_pass
@@ -14,6 +15,7 @@ cc_library(analysis_passes SRCS passes.cc DEPS
             memory_optim_pass
             inference_op_replace_pass
             ir_graph_to_program_pass
+            ir_graph_clean_pass
 )

 set(analysis_deps ${analysis_deps}
...
@@ -20,9 +20,9 @@ namespace inference {
 namespace analysis {

 void InferenceOpReplacePass::RunImpl(Argument* argument) {
-  if (!argument->use_gpu()) return;
   std::unordered_map<std::string, std::string> replaced_map{
       {"conditional_block", "conditional_block_infer"},
+      {"merge_lod_tensor", "merge_lod_tensor_infer"},
   };

   auto& graph = argument->main_graph();
...
@@ -12,56 +12,36 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include <algorithm>
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/node.h"

 namespace paddle {
-namespace framework {
-namespace ir {
-
-class InferCleanGraphPass : public FusePassBase {
- public:
-  virtual ~InferCleanGraphPass() {}
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const {
-    FusePassBase::Init("original_graph", graph);
-    PADDLE_ENFORCE(graph);
-
-    auto is_valid_node = [](Node* x) {
-      return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
-    };
-
-    std::unordered_set<const Node*> invalid_nodes;
-    int valid_op = 0;
-    for (auto* node : graph->Nodes()) {
-      PADDLE_ENFORCE_NOT_NULL(node);
-      if (is_valid_node(node)) {
-        invalid_nodes.insert(node);
-      } else if (node->IsOp()) {
-        // Collect all the operators to help tracking number of operators.
-        ++valid_op;
-      }
-    }
-
-    GraphSafeRemoveNodes(graph, invalid_nodes);
-
-    AddStatis(valid_op);
-  }
-
-  void CleanEdges(std::vector<Node*>* nodes,
-                  const std::unordered_set<Node*>& to_remove) const {
-    auto it = std::remove_if(nodes->begin(), nodes->end(),
-                             [&](Node* x) { return to_remove.count(x); });
-    nodes->erase(it, nodes->end());
-  }
-};
-
-}  // namespace ir
-}  // namespace framework
+namespace inference {
+namespace analysis {
+
+void IrInferCleanGraphPass::RunImpl(Argument* argument) {
+  auto& graph = argument->main_graph();
+  auto is_valid_node = [](framework::ir::Node* x) {
+    return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
+  };
+
+  std::unordered_set<const framework::ir::Node*> invalid_nodes;
+  int valid_op = 0;
+  for (auto* node : graph.Nodes()) {
+    PADDLE_ENFORCE_NOT_NULL(node);
+    if (is_valid_node(node)) {
+      invalid_nodes.insert(node);
+    } else if (node->IsOp()) {
+      ++valid_op;
+    }
+  }
+
+  GraphSafeRemoveNodes(&graph, invalid_nodes);
+}
+
+}  // namespace analysis
+}  // namespace inference
 }  // namespace paddle
-
-REGISTER_PASS(infer_clean_graph_pass,
-              paddle::framework::ir::InferCleanGraphPass);
New file: paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class IrInferCleanGraphPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir_graph_clean_pass"; }
};
} // namespace analysis
} // namespace inference
} // namespace paddle
@@ -109,10 +109,16 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
 void MemoryOptimizePass::CollectVarMemorySize(
     space_table_t* space_table) const {
   const int fake_batch_size = 1;
+
   auto valid_var = [&](framework::ir::Node* node) -> bool {
-    std::set<std::string> invalid_op = {"while", "conditional_block",
-                                        "tensorrt_engine",
-                                        "conditional_block_infer"};
+    std::set<std::string> invalid_op = {"while",
+                                        "conditional_block",
+                                        "tensorrt_engine",
+                                        "conditional_block_infer",
+                                        "merge_lod_tensor_infer",
+                                        "merge_lod_tensor",
+                                        "equal",
+                                        "lod_reset"};
     for (auto* tmp : node->inputs) {
       CHECK(tmp->IsOp());
       std::string op_type = tmp->Op()->Type();
...
@@ -75,6 +75,7 @@ class MemoryOptimizePass : public AnalysisPass {
                             int sort_kind) const;

   void CollectVarMemorySize(space_table_t *space_table) const;
+  void CollectVarMemorySize0(space_table_t *space_table) const;

   void CollectVarMemorySize(
       const std::unordered_map<std::string, size_t> &batch_var_ave_dim,
...
@@ -17,6 +17,7 @@
 #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
+#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
@@ -32,6 +33,8 @@ PassRegistry::PassRegistry() {
                   std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
   passes_.emplace("ir_graph_build_pass",
                   std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
+  passes_.emplace("ir_graph_clean_pass",
+                  std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
   passes_.emplace("memory_optimize_pass",
                   std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
   passes_.emplace(
...
...@@ -135,7 +135,6 @@ bool AnalysisPredictor::PrepareProgram( ...@@ -135,7 +135,6 @@ bool AnalysisPredictor::PrepareProgram(
const std::shared_ptr<framework::ProgramDesc> &program) { const std::shared_ptr<framework::ProgramDesc> &program) {
if (!program) { if (!program) {
if (!LoadProgramDesc()) return false; if (!LoadProgramDesc()) return false;
// If not cloned, the parameters should be loaded. // If not cloned, the parameters should be loaded.
// If config_.ir_optim() is True, parameters is loaded in // If config_.ir_optim() is True, parameters is loaded in
// OptimizeInferenceProgram(), but other persistable variables // OptimizeInferenceProgram(), but other persistable variables
...@@ -145,17 +144,10 @@ bool AnalysisPredictor::PrepareProgram( ...@@ -145,17 +144,10 @@ bool AnalysisPredictor::PrepareProgram(
// So in both case, create persistable variables at first. // So in both case, create persistable variables at first.
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_); executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
// Optimize the program, and load parameters and modify them in the // if enable_ir_optim_ is false,
// scope_. // the analysis pass(op fuse, graph analysis, trt subgraph, mkldnn etc) will
// This will change the scope_ address. // not be executed.
if (config_.ir_optim()) { OptimizeInferenceProgram();
status_ir_optim_enabled_ = true;
OptimizeInferenceProgram();
} else {
// Load parameters
LOG(INFO) << "load parameters ";
LoadParameters();
}
} else { } else {
// If the program is passed from external, no need to optimize it, this // If the program is passed from external, no need to optimize it, this
// logic is used in the clone scenario. // logic is used in the clone scenario.
...@@ -396,6 +388,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -396,6 +388,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
void AnalysisPredictor::PrepareArgument() { void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu()); argument_.SetUseGPU(config_.use_gpu());
argument_.SetGPUDeviceId(config_.gpu_device_id()); argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetStaticMemoryOptim(config_.static_memory_optim_); argument_.SetStaticMemoryOptim(config_.static_memory_optim_);
argument_.SetStaticMemoryOptimForceUpdate( argument_.SetStaticMemoryOptimForceUpdate(
...@@ -467,8 +460,6 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -467,8 +460,6 @@ void AnalysisPredictor::PrepareArgument() {
// NOTE All the members in AnalysisConfig should be copied to Argument. // NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() { void AnalysisPredictor::OptimizeInferenceProgram() {
status_program_optimized_ = true;
PrepareArgument(); PrepareArgument();
Analyzer().Run(&argument_); Analyzer().Run(&argument_);
......
@@ -178,10 +178,8 @@ class AnalysisPredictor : public PaddlePredictor {
  private:
   // Some status here that help to determine the status inside the predictor.
-  bool status_program_optimized_{false};
   bool status_is_cloned_{false};
   bool status_use_gpu_{false};
-  bool status_ir_optim_enabled_{false};
 };

 }  // namespace paddle
@@ -44,7 +44,6 @@ TEST(AnalysisPredictor, analysis_off) {
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
   // ir is turned off, so program shouldn't be optimized.
-  ASSERT_FALSE(predictor->status_program_optimized_);
   LOG(INFO) << "scope parameters " << predictor->scope_->LocalVarNames().size();

   // 2. Dummy Input Data
@@ -76,8 +75,6 @@ TEST(AnalysisPredictor, analysis_on) {
   ASSERT_TRUE(predictor->sub_scope_);
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
-  // ir is turned on, so program should be optimized.
-  ASSERT_TRUE(predictor->status_program_optimized_);
   // 2. Dummy Input Data
   int64_t data[4] = {1, 2, 3, 4};
   PaddleTensor tensor;
...
@@ -400,13 +400,14 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   auto* builder = predictor_.config_.pass_builder();
   builder->SetPasses({
-      "infer_clean_graph_pass", "cpu_quantize_pass", "cpu_quantize_squash_pass",
+      "cpu_quantize_pass", "cpu_quantize_squash_pass",
   });
   if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
   auto passes = builder->AllPasses();
   predictor_.argument_.SetIrAnalysisPasses(passes);
   predictor_.argument_.SetAnalysisPasses(
-      {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
+      {"ir_graph_clean_pass", "ir_analysis_pass", "memory_optimize_pass",
+       "ir_graph_to_program_pass"});
   predictor_.argument_.SetQuantVarScales(scales_);
 }
...
@@ -71,8 +71,7 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
 void PaddlePassBuilder::ClearPasses() { passes_.clear(); }

 const std::vector<std::string> kTRTSubgraphPasses({
-  "infer_clean_graph_pass",                    //
   "conv_affine_channel_fuse_pass",             //
   "conv_eltwiseadd_affine_channel_fuse_pass",  //
   "shuffle_channel_detect_pass",               //
   "quant_conv2d_dequant_fuse_pass",            //
@@ -91,7 +90,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
 // The following passes works for Anakin sub-graph engine.
 const std::vector<std::string> kAnakinSubgraphPasses({
-  "infer_clean_graph_pass",                       //
   "quant_conv2d_dequant_fuse_pass",               //
   "simplify_anakin_priorbox_detection_out_pass",  //
   "fillconstant_elementwisemul_fuse",             //
@@ -105,9 +103,8 @@ const std::vector<std::string> kAnakinSubgraphPasses({
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
-      "infer_clean_graph_pass",                    //
       // "identity_scale_op_clean_pass",           //
       "conv_affine_channel_fuse_pass",             //
       "conv_eltwiseadd_affine_channel_fuse_pass",  //
       "conv_bn_fuse_pass",                         //
       "conv_eltwiseadd_bn_fuse_pass",              //
@@ -141,8 +138,7 @@ void GpuPassStrategy::EnableNgraph() {
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
-  passes_.assign({"infer_clean_graph_pass",         //
-                  "attention_lstm_fuse_pass",       //
+  passes_.assign({"attention_lstm_fuse_pass",       //
                   "seqconv_eltadd_relu_fuse_pass",  //
                   // "seqpool_concat_fuse_pass",    //
                   "seqpool_cvm_concat_fuse_pass",   //
...
@@ -72,7 +72,7 @@ class PaddlePassBuilder {
  protected:
   std::vector<std::string> analysis_passes_{
-      {"ir_graph_build_pass", "ir_analysis_pass",
+      {"ir_graph_build_pass", "ir_graph_clean_pass", "ir_analysis_pass",
        "ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass",
        "inference_op_replace_pass"}};
   std::vector<std::string> passes_;
...
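The pass lists in the two hunks above are what the user-facing pass builder manages; a small sketch of adjusting them from application code (the deleted pass name is only an example):

paddle::AnalysisConfig config;
config.SetModel("./model_dir");  // placeholder
// ir_graph_clean_pass now runs as an analysis pass, so it no longer needs to
// appear in these IR pass lists; individual fusions can still be removed.
auto *pass_builder = config.pass_builder();
pass_builder->DeletePass("conv_bn_fuse_pass");  // example: drop one fusion pass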
@@ -51,8 +51,8 @@ void TensorRTEngine::FreezeNetwork() {
   // build engine.
   infer_builder_->setMaxBatchSize(max_batch_);
   infer_builder_->setMaxWorkspaceSize(max_workspace_);
-#if IS_TRT_VERSION_GE(5000)
   bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
+#if IS_TRT_VERSION_GE(5000)
   if (enable_fp16) {
     bool support_fp16 = infer_builder_->platformHasFastFp16();
     infer_builder_->setFp16Mode(support_fp16);
@@ -62,9 +62,10 @@ void TensorRTEngine::FreezeNetwork() {
     }
   }
 #else
-  LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT "
-               "is at least 5."
-               "So, use FP32 to run.";
+  if (enable_fp16)
+    LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT "
+                 "is at least 5."
+                 "So, use FP32 to run.";
 #endif
   bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
...
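The FP16 guard above only matters when TensorRT half precision is requested. A sketch of requesting it through the public config (the trailing EnableTensorRtEngine parameters keep their defaults, which differ slightly across Paddle versions):

paddle::AnalysisConfig config;
config.SetModel("./model_dir");  // placeholder
config.EnableUseGpu(100, 0);
// workspace size, max batch size, min subgraph size, precision
config.EnableTensorRtEngine(1 << 20, 1, 3,
                            paddle::AnalysisConfig::Precision::kHalf);
// With TensorRT older than 5, the engine logs the message above and falls
// back to FP32.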
@@ -28,9 +28,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
                    const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}

- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
+ protected:
+  void RunBase(const framework::Scope &scope,
+               const platform::Place &dev_place) const {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
@@ -125,6 +125,33 @@ class MergeLoDTensorOp : public framework::OperatorBase {
       out_lod->insert(out_lod->begin(), x.lod()[i]);
     }
   }
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    RunBase(scope, dev_place);
+  }
+};
+
+class MergeLoDTensorInferOp : public MergeLoDTensorOp {
+ public:
+  MergeLoDTensorInferOp(const std::string &type,
+                        const framework::VariableNameMap &inputs,
+                        const framework::VariableNameMap &outputs,
+                        const framework::AttributeMap &attrs)
+      : MergeLoDTensorOp(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    RunBase(scope, dev_place);
+    framework::Variable *in_true_var = scope.FindVar(Input("InTrue"));
+    framework::Variable *in_false_var = scope.FindVar(Input("InFalse"));
+    in_true_var->Clear();
+    in_false_var->Clear();
+    in_true_var->GetMutable<framework::LoDTensor>();
+    in_false_var->GetMutable<framework::LoDTensor>();
+  }
 };

 class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
@@ -196,3 +223,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp,
                   ops::MergeLoDTensorOpProtoMaker,
                   ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker);
+REGISTER_OPERATOR(merge_lod_tensor_infer, ops::MergeLoDTensorInferOp,
+                  ops::MergeLoDTensorOpProtoMaker,
+                  ops::MergeLoDTensorInferShape,
+                  paddle::framework::EmptyGradOpMaker);
@@ -23,6 +23,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.backward import append_backward
 from paddle.fluid.layers.control_flow import split_lod_tensor
 from paddle.fluid.layers.control_flow import merge_lod_tensor
+from paddle.fluid.layer_helper import LayerHelper


 class TestCPULoDTensorArrayOps(unittest.TestCase):
@@ -57,7 +58,7 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             expect_false=expect_false,
             expect_out=tensor)

-    def test_split_and_merge_lod_tensor_level_0(self):
+    def split_and_merge_lod_tensor_level_0(self, use_merge_lod_infer=False):
         tensor = core.LoDTensor()
         tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place())
         tensor.set_recursive_sequence_lengths([[3, 6, 1]])
@@ -87,10 +88,23 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             mask=mask,
             expect_true=expect_true,
             expect_false=expect_false,
-            expect_out=tensor)
+            expect_out=tensor,
+            use_merge_lod_infer=use_merge_lod_infer)

-    def main(self, tensor, mask, expect_true, expect_false, expect_out,
-             level=0):
+    def test_split_and_merge_lod_tensor_1(self):
+        self.split_and_merge_lod_tensor_level_0()
+
+    def test_split_and_merge_lod_tensor_2(self):
+        self.split_and_merge_lod_tensor_level_0(True)
+
+    def main(self,
+             tensor,
+             mask,
+             expect_true,
+             expect_false,
+             expect_out,
+             level=0,
+             use_merge_lod_infer=False):
         place = self.place()
         program = Program()
         with program_guard(program):
@@ -103,11 +117,36 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             out_true, out_false = split_lod_tensor(input=x, mask=y, level=level)
             out_true.persistable = True
             out_false.persistable = True
-
-            out = merge_lod_tensor(
-                in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
-
-            out.persistable = True
+            if use_merge_lod_infer:
+                input_dict = {
+                    'X': x,
+                    'Mask': mask,
+                    'InTrue': out_true,
+                    'InFalse': out_false,
+                    'level': level
+                }
+                helper = LayerHelper('merge_lod_tensor_infer')
+                out = helper.create_variable_for_type_inference(
+                    dtype=out_true.dtype)
+                helper.append_op(
+                    type='merge_lod_tensor_infer',
+                    inputs={
+                        'X': x,
+                        'Mask': y,
+                        'InTrue': out_true,
+                        'InFalse': out_false
+                    },
+                    outputs={'Out': out},
+                    attrs={'level': level})
+                out.persistable = True
+            else:
+                out = merge_lod_tensor(
+                    in_true=out_true,
+                    in_false=out_false,
+                    mask=y,
+                    x=x,
+                    level=level)
+                out.persistable = True

             exe = Executor(place)
             scope = core.Scope()
@@ -122,9 +161,9 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             var_false = scope.find_var(out_false.name).get_tensor()

             var_out = scope.find_var(out.name).get_tensor()
-
-            self.check_tensor_same(var_true, expect_true)
-            self.check_tensor_same(var_false, expect_false)
+            if not use_merge_lod_infer:
+                self.check_tensor_same(var_true, expect_true)
+                self.check_tensor_same(var_false, expect_false)
             self.check_tensor_same(var_out, expect_out)

     def check_tensor_same(self, actual, expect):
...