未验证 提交 76c95af0 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Fix BUG: Mask RCNN inference diff When using AnalysisPredictor. (#19213)

* fix mask rcnn bug:
1. affine channel fuse (diff)
2. condition block op (memory leak)
3. merge lod tensor op (diff)
4. memroy optim (diff)
test=develop

* fix ci aboud PADDLE_ENFOCE
fix merge lod infer op ut
test=develop
上级 5fc8de44
......@@ -52,7 +52,6 @@ pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
......
......@@ -31,6 +31,9 @@ void Analyzer::RunAnalysis(Argument *argument) {
"analsis_passes is not valid in the argument.");
for (auto &pass : argument->analysis_passes()) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
continue;
auto *ptr = PassRegistry::Global().Retreive(pass);
PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass);
ptr->Run(argument);
......
......@@ -30,7 +30,7 @@ using namespace framework; // NOLINT
TEST(Analyzer, analysis_without_tensorrt) {
Argument argument;
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
argument.SetEnableAnalysisOptim(false);
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass"});
......@@ -41,10 +41,10 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
Argument argument;
argument.SetEnableAnalysisOptim(false);
argument.SetTensorRtMaxBatchSize(3);
argument.SetTensorRtWorkspaceSize(1 << 20);
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass"});
......
......@@ -62,6 +62,9 @@ struct Argument {
using anakin_max_shape_t = std::map<std::string, std::vector<int>>;
bool Has(const std::string& key) const { return valid_fields_.count(key); }
// If we set the model using config.SetModelBuffer,
// the model and parameter will occupy additional CPU resources.
// Use this interface to release these resources.
void PartiallyRelease() {
if (Has("model_program_path")) {
if (Has("model_from_memory") && model_from_memory()) {
......@@ -130,6 +133,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
// The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
......@@ -5,6 +5,7 @@ cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_p
cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass)
cc_library(inference_op_replace_pass SRCS inference_op_replace_pass.cc DEPS analysis_pass graph_to_program_pass)
cc_library(ir_graph_clean_pass SRCS ir_graph_clean_pass.cc DEPS analysis_pass)
cc_library(analysis_passes SRCS passes.cc DEPS
ir_graph_build_pass
......@@ -14,6 +15,7 @@ cc_library(analysis_passes SRCS passes.cc DEPS
memory_optim_pass
inference_op_replace_pass
ir_graph_to_program_pass
ir_graph_clean_pass
)
set(analysis_deps ${analysis_deps}
......
......@@ -20,9 +20,9 @@ namespace inference {
namespace analysis {
void InferenceOpReplacePass::RunImpl(Argument* argument) {
if (!argument->use_gpu()) return;
std::unordered_map<std::string, std::string> replaced_map{
{"conditional_block", "conditional_block_infer"},
{"merge_lod_tensor", "merge_lod_tensor_infer"},
};
auto& graph = argument->main_graph();
......
......@@ -12,56 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
class InferCleanGraphPass : public FusePassBase {
public:
virtual ~InferCleanGraphPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("original_graph", graph);
PADDLE_ENFORCE(graph);
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<const Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node);
if (is_valid_node(node)) {
invalid_nodes.insert(node);
} else if (node->IsOp()) {
// Collect all the operators to help tracking number of operators.
++valid_op;
}
namespace inference {
namespace analysis {
void IrInferCleanGraphPass::RunImpl(Argument* argument) {
auto& graph = argument->main_graph();
auto is_valid_node = [](framework::ir::Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<const framework::ir::Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph.Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node);
if (is_valid_node(node)) {
invalid_nodes.insert(node);
} else if (node->IsOp()) {
++valid_op;
}
GraphSafeRemoveNodes(graph, invalid_nodes);
AddStatis(valid_op);
}
void CleanEdges(std::vector<Node*>* nodes,
const std::unordered_set<Node*>& to_remove) const {
auto it = std::remove_if(nodes->begin(), nodes->end(),
[&](Node* x) { return to_remove.count(x); });
nodes->erase(it, nodes->end());
}
};
GraphSafeRemoveNodes(&graph, invalid_nodes);
}
} // namespace ir
} // namespace framework
} // namespace analysis
} // namespace inference
} // namespace paddle
REGISTER_PASS(infer_clean_graph_pass,
paddle::framework::ir::InferCleanGraphPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class IrInferCleanGraphPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir_graph_clean_pass"; }
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -109,10 +109,16 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
void MemoryOptimizePass::CollectVarMemorySize(
space_table_t* space_table) const {
const int fake_batch_size = 1;
auto valid_var = [&](framework::ir::Node* node) -> bool {
std::set<std::string> invalid_op = {"while", "conditional_block",
std::set<std::string> invalid_op = {"while",
"conditional_block",
"tensorrt_engine",
"conditional_block_infer"};
"conditional_block_infer",
"merge_lod_tensor_infer",
"merge_lod_tensor",
"equal",
"lod_reset"};
for (auto* tmp : node->inputs) {
CHECK(tmp->IsOp());
std::string op_type = tmp->Op()->Type();
......
......@@ -75,6 +75,7 @@ class MemoryOptimizePass : public AnalysisPass {
int sort_kind) const;
void CollectVarMemorySize(space_table_t *space_table) const;
void CollectVarMemorySize0(space_table_t *space_table) const;
void CollectVarMemorySize(
const std::unordered_map<std::string, size_t> &batch_var_ave_dim,
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
......@@ -32,6 +33,8 @@ PassRegistry::PassRegistry() {
std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
passes_.emplace("ir_graph_build_pass",
std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
passes_.emplace("ir_graph_clean_pass",
std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
passes_.emplace("memory_optimize_pass",
std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
passes_.emplace(
......
......@@ -135,7 +135,6 @@ bool AnalysisPredictor::PrepareProgram(
const std::shared_ptr<framework::ProgramDesc> &program) {
if (!program) {
if (!LoadProgramDesc()) return false;
// If not cloned, the parameters should be loaded.
// If config_.ir_optim() is True, parameters is loaded in
// OptimizeInferenceProgram(), but other persistable variables
......@@ -145,17 +144,10 @@ bool AnalysisPredictor::PrepareProgram(
// So in both case, create persistable variables at first.
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
// Optimize the program, and load parameters and modify them in the
// scope_.
// This will change the scope_ address.
if (config_.ir_optim()) {
status_ir_optim_enabled_ = true;
OptimizeInferenceProgram();
} else {
// Load parameters
LOG(INFO) << "load parameters ";
LoadParameters();
}
// if enable_ir_optim_ is false,
// the analysis pass(op fuse, graph analysis, trt subgraph, mkldnn etc) will
// not be executed.
OptimizeInferenceProgram();
} else {
// If the program is passed from external, no need to optimize it, this
// logic is used in the clone scenario.
......@@ -396,6 +388,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu());
argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetStaticMemoryOptim(config_.static_memory_optim_);
argument_.SetStaticMemoryOptimForceUpdate(
......@@ -467,8 +460,6 @@ void AnalysisPredictor::PrepareArgument() {
// NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() {
status_program_optimized_ = true;
PrepareArgument();
Analyzer().Run(&argument_);
......
......@@ -178,10 +178,8 @@ class AnalysisPredictor : public PaddlePredictor {
private:
// Some status here that help to determine the status inside the predictor.
bool status_program_optimized_{false};
bool status_is_cloned_{false};
bool status_use_gpu_{false};
bool status_ir_optim_enabled_{false};
};
} // namespace paddle
......@@ -44,7 +44,6 @@ TEST(AnalysisPredictor, analysis_off) {
ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
// ir is turned off, so program shouldn't be optimized.
ASSERT_FALSE(predictor->status_program_optimized_);
LOG(INFO) << "scope parameters " << predictor->scope_->LocalVarNames().size();
// 2. Dummy Input Data
......@@ -76,8 +75,6 @@ TEST(AnalysisPredictor, analysis_on) {
ASSERT_TRUE(predictor->sub_scope_);
ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
// ir is turned on, so program should be optimized.
ASSERT_TRUE(predictor->status_program_optimized_);
// 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor;
......
......@@ -400,13 +400,14 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({
"infer_clean_graph_pass", "cpu_quantize_pass", "cpu_quantize_squash_pass",
"cpu_quantize_pass", "cpu_quantize_squash_pass",
});
if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses(
{"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
{"ir_graph_clean_pass", "ir_analysis_pass", "memory_optimize_pass",
"ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_);
}
......
......@@ -71,8 +71,7 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
const std::vector<std::string> kTRTSubgraphPasses({
"infer_clean_graph_pass", //
"conv_affine_channel_fuse_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
......@@ -91,7 +90,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
// The following passes works for Anakin sub-graph engine.
const std::vector<std::string> kAnakinSubgraphPasses({
"infer_clean_graph_pass", //
"quant_conv2d_dequant_fuse_pass", //
"simplify_anakin_priorbox_detection_out_pass", //
"fillconstant_elementwisemul_fuse", //
......@@ -105,9 +103,8 @@ const std::vector<std::string> kAnakinSubgraphPasses({
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"infer_clean_graph_pass", //
// "identity_scale_op_clean_pass", //
"conv_affine_channel_fuse_pass", //
// "identity_scale_op_clean_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
......@@ -141,8 +138,7 @@ void GpuPassStrategy::EnableNgraph() {
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
passes_.assign({"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
passes_.assign({"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
// "seqpool_concat_fuse_pass", //
"seqpool_cvm_concat_fuse_pass", //
......
......@@ -72,7 +72,7 @@ class PaddlePassBuilder {
protected:
std::vector<std::string> analysis_passes_{
{"ir_graph_build_pass", "ir_analysis_pass",
{"ir_graph_build_pass", "ir_graph_clean_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass",
"inference_op_replace_pass"}};
std::vector<std::string> passes_;
......
......@@ -51,8 +51,8 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
#if IS_TRT_VERSION_GE(5000)
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
#if IS_TRT_VERSION_GE(5000)
if (enable_fp16) {
bool support_fp16 = infer_builder_->platformHasFastFp16();
infer_builder_->setFp16Mode(support_fp16);
......@@ -62,9 +62,10 @@ void TensorRTEngine::FreezeNetwork() {
}
}
#else
LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT "
"is at least 5."
"So, use FP32 to run.";
if (enable_fp16)
LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT "
"is at least 5."
"So, use FP32 to run.";
#endif
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
......
......@@ -28,9 +28,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
protected:
void RunBase(const framework::Scope &scope,
const platform::Place &dev_place) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
......@@ -125,6 +125,33 @@ class MergeLoDTensorOp : public framework::OperatorBase {
out_lod->insert(out_lod->begin(), x.lod()[i]);
}
}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
RunBase(scope, dev_place);
}
};
class MergeLoDTensorInferOp : public MergeLoDTensorOp {
public:
MergeLoDTensorInferOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: MergeLoDTensorOp(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
RunBase(scope, dev_place);
framework::Variable *in_true_var = scope.FindVar(Input("InTrue"));
framework::Variable *in_false_var = scope.FindVar(Input("InFalse"));
in_true_var->Clear();
in_false_var->Clear();
in_true_var->GetMutable<framework::LoDTensor>();
in_false_var->GetMutable<framework::LoDTensor>();
}
};
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
......@@ -196,3 +223,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp,
ops::MergeLoDTensorOpProtoMaker,
ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker);
REGISTER_OPERATOR(merge_lod_tensor_infer, ops::MergeLoDTensorInferOp,
ops::MergeLoDTensorOpProtoMaker,
ops::MergeLoDTensorInferShape,
paddle::framework::EmptyGradOpMaker);
......@@ -23,6 +23,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.backward import append_backward
from paddle.fluid.layers.control_flow import split_lod_tensor
from paddle.fluid.layers.control_flow import merge_lod_tensor
from paddle.fluid.layer_helper import LayerHelper
class TestCPULoDTensorArrayOps(unittest.TestCase):
......@@ -57,7 +58,7 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
expect_false=expect_false,
expect_out=tensor)
def test_split_and_merge_lod_tensor_level_0(self):
def split_and_merge_lod_tensor_level_0(self, use_merge_lod_infer=False):
tensor = core.LoDTensor()
tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place())
tensor.set_recursive_sequence_lengths([[3, 6, 1]])
......@@ -87,10 +88,23 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
mask=mask,
expect_true=expect_true,
expect_false=expect_false,
expect_out=tensor)
def main(self, tensor, mask, expect_true, expect_false, expect_out,
level=0):
expect_out=tensor,
use_merge_lod_infer=use_merge_lod_infer)
def test_split_and_merge_lod_tensor_1(self):
self.split_and_merge_lod_tensor_level_0()
def test_split_and_merge_lod_tensor_2(self):
self.split_and_merge_lod_tensor_level_0(True)
def main(self,
tensor,
mask,
expect_true,
expect_false,
expect_out,
level=0,
use_merge_lod_infer=False):
place = self.place()
program = Program()
with program_guard(program):
......@@ -103,11 +117,36 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
out_true, out_false = split_lod_tensor(input=x, mask=y, level=level)
out_true.persistable = True
out_false.persistable = True
out = merge_lod_tensor(
in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
out.persistable = True
if use_merge_lod_infer:
input_dict = {
'X': x,
'Mask': mask,
'InTrue': out_true,
'InFalse': out_false,
'level': level
}
helper = LayerHelper('merge_lod_tensor_infer')
out = helper.create_variable_for_type_inference(
dtype=out_true.dtype)
helper.append_op(
type='merge_lod_tensor_infer',
inputs={
'X': x,
'Mask': y,
'InTrue': out_true,
'InFalse': out_false
},
outputs={'Out': out},
attrs={'level': level})
out.persistable = True
else:
out = merge_lod_tensor(
in_true=out_true,
in_false=out_false,
mask=y,
x=x,
level=level)
out.persistable = True
exe = Executor(place)
scope = core.Scope()
......@@ -122,9 +161,9 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
var_false = scope.find_var(out_false.name).get_tensor()
var_out = scope.find_var(out.name).get_tensor()
self.check_tensor_same(var_true, expect_true)
self.check_tensor_same(var_false, expect_false)
if not use_merge_lod_infer:
self.check_tensor_same(var_true, expect_true)
self.check_tensor_same(var_false, expect_false)
self.check_tensor_same(var_out, expect_out)
def check_tensor_same(self, actual, expect):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册