Unverified commit 6fbd224e authored by Zhaolong Xing, committed by GitHub

CHERRY PICK FROM 18941, 18860, 19213: Fix Mask RCNN bug and add Paddle-TRT FP16 support (#19378)

* CHERRY_PICK 18941, 18860: Paddle-TRT FP16 support.

test=release/1.5

* CHERRY_PICK 19213: Fix bug: Mask RCNN inference diff when using AnalysisPredictor.
    1. fix affine channel fuse pass.
    2. fix conditional block op.
    3. fix merge lod tensor op bug.
    4. fix memory optim bug caused by the lod_reset op.

    test=release/1.5
Parent 2656e90b
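For context, a minimal sketch of how the new FP16 path is requested through the public AnalysisConfig API once this commit is in place. The model path and the exact EnableTensorRtEngine argument order are assumptions based on the release/1.5-era interface, not part of this diff:

```cpp
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet");                 // hypothetical model directory
  config.EnableUseGpu(100 /*initial pool MB*/, 0 /*gpu id*/);
  // Ask the TensorRT subgraph engine for the new kHalf (FP16) precision.
  config.EnableTensorRtEngine(1 << 20 /*workspace*/, 1 /*max batch*/,
                              3 /*min subgraph size*/,
                              paddle::AnalysisConfig::Precision::kHalf,
                              false /*use_static*/, false /*use_calib_mode*/);
  auto predictor = paddle::CreatePaddlePredictor(config);
  return 0;
}
```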
@@ -52,7 +52,6 @@ pass_library(graph_viz_pass base)
 pass_library(lock_free_optimize_pass base)
 pass_library(fc_fuse_pass inference)
 pass_library(attention_lstm_fuse_pass inference)
-pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
 pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
......
@@ -31,6 +31,9 @@ void Analyzer::RunAnalysis(Argument *argument) {
                  "analsis_passes is not valid in the argument.");
   for (auto &pass : argument->analysis_passes()) {
     string::PrettyLogH1("--- Running analysis [%s]", pass);
+    if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
+      continue;
     auto *ptr = PassRegistry::Global().Retreive(pass);
     PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass);
     ptr->Run(argument);
......
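As a hedged illustration of what the new enable_analysis_optim flag means for users (the mapping below assumes AnalysisConfig::SwitchIrOptim still feeds enable_ir_optim_, as the predictor change later in this diff suggests):

```cpp
paddle::AnalysisConfig config;
config.SetModel("./model_dir");  // hypothetical path
// SwitchIrOptim(false) is forwarded to Argument::enable_analysis_optim,
// so Analyzer::RunAnalysis now skips only "ir_analysis_pass" while the
// remaining analysis passes (graph build, param sync, op replace) still run.
config.SwitchIrOptim(false);
auto predictor = paddle::CreatePaddlePredictor(config);
```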
@@ -30,7 +30,7 @@ using namespace framework;  // NOLINT
 TEST(Analyzer, analysis_without_tensorrt) {
   Argument argument;
   argument.SetModelDir(FLAGS_inference_model_dir);
-  argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
+  argument.SetEnableAnalysisOptim(false);
   argument.SetUseGPU(false);
   argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
                               "ir_params_sync_among_devices_pass"});
@@ -41,10 +41,10 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   Argument argument;
+  argument.SetEnableAnalysisOptim(false);
   argument.SetTensorRtMaxBatchSize(3);
   argument.SetTensorRtWorkspaceSize(1 << 20);
   argument.SetModelDir(FLAGS_inference_model_dir);
-  argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
   argument.SetUseGPU(false);
   argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
                               "ir_params_sync_among_devices_pass"});
......
@@ -62,6 +62,9 @@ struct Argument {
   using anakin_max_shape_t = std::map<std::string, std::vector<int>>;
   bool Has(const std::string& key) const { return valid_fields_.count(key); }
+  // If we set the model using config.SetModelBuffer,
+  // the model and parameter will occupy additional CPU resources.
+  // Use this interface to release these resources.
   void PartiallyRelease() {
     if (Has("model_program_path")) {
       if (Has("model_from_memory") && model_from_memory()) {
@@ -130,6 +133,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
   DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
   DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
+  DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
   // The overall graph to work on.
   DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
@@ -84,13 +84,15 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("program",
                 new framework::ProgramDesc *(&argument->main_program()));
-      bool enable_int8 = argument->tensorrt_precision_mode() ==
-                         AnalysisConfig::Precision::kInt8;
+      auto precision_mode = argument->tensorrt_precision_mode();
+      bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
       pass->Set("predictor_id", new int(argument->predictor_id()));
       bool use_calib_mode = argument->tensorrt_use_calib_mode();
       pass->Set("enable_int8", new bool(enable_int8));
       pass->Set("use_calib_mode", new bool(use_calib_mode));
+      pass->Set("precision_mode",
+                new AnalysisConfig::Precision(precision_mode));
       bool use_static_engine = argument->tensorrt_use_static_engine();
       bool model_from_memory = argument->model_from_memory();
......
@@ -149,6 +149,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       graph_var_map[node->Name()] = node;
     }
   }
+  auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
+  bool enable_fp16 = false;
+  if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true;
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
   auto &subgraph_nodes = *Agent(node).subgraph();
@@ -216,6 +219,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
   SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
+  SetAttr(op_desc->Proto(), "enable_fp16", enable_fp16);
   SetAttr(op_desc->Proto(), "use_calib_mode", use_calib_mode);
   SetAttr(op_desc->Proto(), "engine_key", engine_key);
   SetAttr(op_desc->Proto(), "predictor_id", predictor_id);
@@ -244,7 +248,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
           .Create(engine_key + std::to_string(predictor_id),
                   Get<int>("max_batch_size"), Get<int>("workspace_size"),
-                  enable_int8, calibrator.get(), Get<int>("gpu_device_id"));
+                  precision_mode, calibrator.get(), Get<int>("gpu_device_id"));
   bool need_serialize = (use_static_engine && !load_from_memory);
   if (need_serialize) {
......
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
 namespace paddle {
 namespace inference {
......
@@ -5,6 +5,7 @@ cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_p
 cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
 cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass)
 cc_library(inference_op_replace_pass SRCS inference_op_replace_pass.cc DEPS analysis_pass graph_to_program_pass)
+cc_library(ir_graph_clean_pass SRCS ir_graph_clean_pass.cc DEPS analysis_pass)
 cc_library(analysis_passes SRCS passes.cc DEPS
             ir_graph_build_pass
@@ -14,6 +15,7 @@ cc_library(analysis_passes SRCS passes.cc DEPS
             memory_optim_pass
             inference_op_replace_pass
             ir_graph_to_program_pass
+            ir_graph_clean_pass
 )
 set(analysis_deps ${analysis_deps}
......
@@ -20,9 +20,9 @@ namespace inference {
 namespace analysis {
 void InferenceOpReplacePass::RunImpl(Argument* argument) {
-  if (!argument->use_gpu()) return;
   std::unordered_map<std::string, std::string> replaced_map{
       {"conditional_block", "conditional_block_infer"},
+      {"merge_lod_tensor", "merge_lod_tensor_infer"},
   };
   auto& graph = argument->main_graph();
......
@@ -12,56 +12,36 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include <algorithm>
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/node.h"
 namespace paddle {
-namespace framework {
-namespace ir {
+namespace inference {
+namespace analysis {
-class InferCleanGraphPass : public FusePassBase {
- public:
-  virtual ~InferCleanGraphPass() {}
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const {
-    FusePassBase::Init("original_graph", graph);
-    PADDLE_ENFORCE(graph);
-
-    auto is_valid_node = [](Node* x) {
+void IrInferCleanGraphPass::RunImpl(Argument* argument) {
+  auto& graph = argument->main_graph();
+  auto is_valid_node = [](framework::ir::Node* x) {
     return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
   };
-  std::unordered_set<const Node*> invalid_nodes;
+  std::unordered_set<const framework::ir::Node*> invalid_nodes;
   int valid_op = 0;
-  for (auto* node : graph->Nodes()) {
+  for (auto* node : graph.Nodes()) {
     PADDLE_ENFORCE_NOT_NULL(node);
     if (is_valid_node(node)) {
       invalid_nodes.insert(node);
     } else if (node->IsOp()) {
-      // Collect all the operators to help tracking number of operators.
       ++valid_op;
     }
   }
-  GraphSafeRemoveNodes(graph, invalid_nodes);
+  GraphSafeRemoveNodes(&graph, invalid_nodes);
+}
-  AddStatis(valid_op);
-  }
-
-  void CleanEdges(std::vector<Node*>* nodes,
-                  const std::unordered_set<Node*>& to_remove) const {
-    auto it = std::remove_if(nodes->begin(), nodes->end(),
-                             [&](Node* x) { return to_remove.count(x); });
-    nodes->erase(it, nodes->end());
-  }
-};
-}  // namespace ir
-}  // namespace framework
+}  // namespace analysis
+}  // namespace inference
 }  // namespace paddle
-
-REGISTER_PASS(infer_clean_graph_pass,
-              paddle::framework::ir::InferCleanGraphPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class IrInferCleanGraphPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir_graph_clean_pass"; }
};
} // namespace analysis
} // namespace inference
} // namespace paddle
@@ -109,10 +109,16 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
 void MemoryOptimizePass::CollectVarMemorySize(
     space_table_t* space_table) const {
   const int fake_batch_size = 1;
   auto valid_var = [&](framework::ir::Node* node) -> bool {
-    std::set<std::string> invalid_op = {"while", "conditional_block",
-                                        "tensorrt_engine",
-                                        "conditional_block_infer"};
+    std::set<std::string> invalid_op = {"while",
+                                        "conditional_block",
+                                        "tensorrt_engine",
+                                        "conditional_block_infer",
+                                        "merge_lod_tensor_infer",
+                                        "merge_lod_tensor",
+                                        "equal",
+                                        "lod_reset"};
     for (auto* tmp : node->inputs) {
       CHECK(tmp->IsOp());
       std::string op_type = tmp->Op()->Type();
......
@@ -75,6 +75,7 @@ class MemoryOptimizePass : public AnalysisPass {
                            int sort_kind) const;
   void CollectVarMemorySize(space_table_t *space_table) const;
+  void CollectVarMemorySize0(space_table_t *space_table) const;
   void CollectVarMemorySize(
       const std::unordered_map<std::string, size_t> &batch_var_ave_dim,
......
@@ -17,6 +17,7 @@
 #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
+#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
@@ -32,6 +33,8 @@ PassRegistry::PassRegistry() {
                   std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
   passes_.emplace("ir_graph_build_pass",
                   std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
+  passes_.emplace("ir_graph_clean_pass",
+                  std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
   passes_.emplace("memory_optimize_pass",
                   std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
   passes_.emplace(
......
@@ -135,7 +135,6 @@ bool AnalysisPredictor::PrepareProgram(
     const std::shared_ptr<framework::ProgramDesc> &program) {
   if (!program) {
     if (!LoadProgramDesc()) return false;
     // If not cloned, the parameters should be loaded.
     // If config_.ir_optim() is True, parameters is loaded in
     // OptimizeInferenceProgram(), but other persistable variables
@@ -145,17 +144,10 @@ bool AnalysisPredictor::PrepareProgram(
     // So in both case, create persistable variables at first.
     executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
-    // Optimize the program, and load parameters and modify them in the
-    // scope_.
-    // This will change the scope_ address.
-    if (config_.ir_optim()) {
-      status_ir_optim_enabled_ = true;
-      OptimizeInferenceProgram();
-    } else {
-      // Load parameters
-      LOG(INFO) << "load parameters ";
-      LoadParameters();
-    }
+    // if enable_ir_optim_ is false,
+    // the analysis pass(op fuse, graph analysis, trt subgraph, mkldnn etc) will
+    // not be executed.
+    OptimizeInferenceProgram();
   } else {
     // If the program is passed from external, no need to optimize it, this
     // logic is used in the clone scenario.
@@ -363,6 +355,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
 void AnalysisPredictor::PrepareArgument() {
   argument_.SetUseGPU(config_.use_gpu());
   argument_.SetGPUDeviceId(config_.gpu_device_id());
+  argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
   argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
   argument_.SetStaticMemoryOptim(config_.static_memory_optim_);
   argument_.SetStaticMemoryOptimForceUpdate(
@@ -434,8 +427,6 @@ void AnalysisPredictor::PrepareArgument() {
 // NOTE All the members in AnalysisConfig should be copied to Argument.
 void AnalysisPredictor::OptimizeInferenceProgram() {
-  status_program_optimized_ = true;
   PrepareArgument();
   Analyzer().Run(&argument_);
......
@@ -175,10 +175,8 @@ class AnalysisPredictor : public PaddlePredictor {
  private:
   // Some status here that help to determine the status inside the predictor.
-  bool status_program_optimized_{false};
   bool status_is_cloned_{false};
   bool status_use_gpu_{false};
-  bool status_ir_optim_enabled_{false};
 };
 }  // namespace paddle
@@ -44,7 +44,6 @@ TEST(AnalysisPredictor, analysis_off) {
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
   // ir is turned off, so program shouldn't be optimized.
-  ASSERT_FALSE(predictor->status_program_optimized_);
   LOG(INFO) << "scope parameters " << predictor->scope_->LocalVarNames().size();
   // 2. Dummy Input Data
@@ -76,8 +75,6 @@ TEST(AnalysisPredictor, analysis_on) {
   ASSERT_TRUE(predictor->sub_scope_);
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
   ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
-  // ir is turned on, so program should be optimized.
-  ASSERT_TRUE(predictor->status_program_optimized_);
   // 2. Dummy Input Data
   int64_t data[4] = {1, 2, 3, 4};
   PaddleTensor tensor;
......
@@ -367,13 +367,14 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   auto* builder = predictor_.config_.pass_builder();
   builder->SetPasses({
-      "infer_clean_graph_pass", "cpu_quantize_pass", "cpu_quantize_squash_pass",
+      "cpu_quantize_pass", "cpu_quantize_squash_pass",
   });
   if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
   auto passes = builder->AllPasses();
   predictor_.argument_.SetIrAnalysisPasses(passes);
   predictor_.argument_.SetAnalysisPasses(
-      {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
+      {"ir_graph_clean_pass", "ir_analysis_pass", "memory_optimize_pass",
+       "ir_graph_to_program_pass"});
   predictor_.argument_.SetQuantVarScales(scales_);
 }
......
@@ -46,6 +46,7 @@ struct AnalysisConfig {
   enum class Precision {
     kFloat32 = 0,
     kInt8,
+    kHalf,
   };
   /** Set model with a directory.
......
@@ -71,7 +71,6 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
 void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
 const std::vector<std::string> kTRTSubgraphPasses({
-  "infer_clean_graph_pass",                    //
   "conv_affine_channel_fuse_pass",             //
   "conv_eltwiseadd_affine_channel_fuse_pass",  //
   "quant_conv2d_dequant_fuse_pass",            //
@@ -90,7 +89,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
 // The following passes works for Anakin sub-graph engine.
 const std::vector<std::string> kAnakinSubgraphPasses({
-  "infer_clean_graph_pass",                       //
   "quant_conv2d_dequant_fuse_pass",               //
   "simplify_anakin_priorbox_detection_out_pass",  //
   "fillconstant_elementwisemul_fuse",             //
@@ -104,7 +102,6 @@ const std::vector<std::string> kAnakinSubgraphPasses({
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
-    "infer_clean_graph_pass",                    //
     // "identity_scale_op_clean_pass",           //
     "conv_affine_channel_fuse_pass",             //
     "conv_eltwiseadd_affine_channel_fuse_pass",  //
@@ -140,8 +137,7 @@ void GpuPassStrategy::EnableNgraph() {
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
-  passes_.assign({"infer_clean_graph_pass",    //
-                  "attention_lstm_fuse_pass",  //
+  passes_.assign({"attention_lstm_fuse_pass",  //
                   "seqconv_eltadd_relu_fuse_pass",  //
                   // "seqpool_concat_fuse_pass",    //
                   "seqpool_cvm_concat_fuse_pass",   //
......
@@ -72,7 +72,7 @@ class PaddlePassBuilder {
  protected:
   std::vector<std::string> analysis_passes_{
-      {"ir_graph_build_pass", "ir_analysis_pass",
+      {"ir_graph_build_pass", "ir_graph_clean_pass", "ir_analysis_pass",
        "ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass",
        "inference_op_replace_pass"}};
   std::vector<std::string> passes_;
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
-#include <gtest/gtest.h>
+#include <gtest/gtest.h>  // NOLINT
 #include "paddle/fluid/framework/program_desc.h"
 namespace paddle {
@@ -27,10 +27,8 @@ TEST(OpConverter, ConvertBlock) {
   auto* conv2d_op = block->AppendOp();
   // init trt engine
-  cudaStream_t stream_;
   std::unique_ptr<TensorRTEngine> engine_;
-  PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
-  engine_.reset(new TensorRTEngine(5, 1 << 15, stream_));
+  engine_.reset(new TensorRTEngine(5, 1 << 15));
   engine_->InitNetwork();
   engine_->DeclareInput("conv2d-X", nvinfer1::DataType::kFLOAT,
......
@@ -80,8 +80,7 @@ class TRTConvertValidation {
         if_add_batch_(if_add_batch),
         max_batch_size_(max_batch_size) {
     PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
-    engine_.reset(
-        new TensorRTEngine(max_batch_size, workspace_size, false, nullptr, 0));
+    engine_.reset(new TensorRTEngine(max_batch_size, workspace_size));
     engine_->InitNetwork();
   }
......
@@ -51,7 +51,27 @@ void TensorRTEngine::FreezeNetwork() {
   // build engine.
   infer_builder_->setMaxBatchSize(max_batch_);
   infer_builder_->setMaxWorkspaceSize(max_workspace_);
-  if (enable_int8_) {
+  bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
+#if IS_TRT_VERSION_GE(5000)
+  if (enable_fp16) {
+    bool support_fp16 = infer_builder_->platformHasFastFp16();
+    infer_builder_->setFp16Mode(support_fp16);
+    if (!support_fp16) {
+      LOG(INFO) << "You specify FP16 mode, but the hardware do not support "
+                   "FP16 speed up, use FP32 instead.";
+    } else {
+      LOG(INFO) << "Run Paddle-TRT FP16 mode. ";
+    }
+  }
+#else
+  if (enable_fp16)
+    LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT "
+                 "is at least 5."
+                 " So, use FP32 to run.";
+#endif
+  bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
+  if (enable_int8) {
     infer_builder_->setInt8Mode(true);
     if (calibrator_) {
       infer_builder_->setInt8Calibrator(calibrator_);
......
@@ -22,6 +22,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
@@ -61,12 +62,14 @@ class TensorRTEngine {
     nvinfer1::Weights w_;
   };
-  TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false,
+  TensorRTEngine(
+      int max_batch, int max_workspace,
+      AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
       TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
       nvinfer1::ILogger& logger = NaiveLogger::Global())
       : max_batch_(max_batch),
         max_workspace_(max_workspace),
-        enable_int8_(enable_int8),
+        precision_(precision),
         calibrator_(calibrator),
         device_id_(device_id),
         logger_(logger) {}
@@ -168,7 +171,7 @@ class TensorRTEngine {
   // the max memory size the engine uses
   int max_workspace_;
-  bool enable_int8_;
+  AnalysisConfig::Precision precision_;
   TRTInt8Calibrator* calibrator_;
   // batch size of the current data, will be updated each Executation.
   int batch_size_{-1};
@@ -231,12 +234,12 @@ class TRTEngineManager {
     return engines_.at(name).get();
   }
-  TensorRTEngine* Create(std::string name, int max_batch, int max_workspace,
-                         bool enable_int8 = false,
-                         TRTInt8Calibrator* calibrator = nullptr,
-                         int device_id = 0,
+  TensorRTEngine* Create(
+      std::string name, int max_batch, int max_workspace,
+      AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
+      TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
       nvinfer1::ILogger& logger = NaiveLogger::Global()) {
-    auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8,
+    auto* p = new TensorRTEngine(max_batch, max_workspace, precision,
                                  calibrator, device_id, logger);
     engines_[name].reset(p);
     return p;
......
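A short usage sketch for the revised TensorRTEngine constructor shown above (assumes the TensorRT headers and a CUDA-capable device are available; the numbers are illustrative, not taken from this diff):

```cpp
#include "paddle/fluid/inference/tensorrt/engine.h"

using paddle::AnalysisConfig;
using paddle::inference::tensorrt::TensorRTEngine;

// Precision now travels as an enum instead of a bare enable_int8 flag.
TensorRTEngine engine(/*max_batch=*/4, /*max_workspace=*/1 << 20,
                      AnalysisConfig::Precision::kHalf);
engine.InitNetwork();
// ... declare inputs/outputs and add layers via the op converters, then:
// engine.FreezeNetwork();  // selects FP16 mode when the GPU supports it
```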
@@ -28,9 +28,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
                    const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
+ protected:
+  void RunBase(const framework::Scope &scope,
+               const platform::Place &dev_place) const {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
@@ -125,6 +125,33 @@ class MergeLoDTensorOp : public framework::OperatorBase {
       out_lod->insert(out_lod->begin(), x.lod()[i]);
     }
   }
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    RunBase(scope, dev_place);
+  }
+};
+
+class MergeLoDTensorInferOp : public MergeLoDTensorOp {
+ public:
+  MergeLoDTensorInferOp(const std::string &type,
+                        const framework::VariableNameMap &inputs,
+                        const framework::VariableNameMap &outputs,
+                        const framework::AttributeMap &attrs)
+      : MergeLoDTensorOp(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    RunBase(scope, dev_place);
+    framework::Variable *in_true_var = scope.FindVar(Input("InTrue"));
+    framework::Variable *in_false_var = scope.FindVar(Input("InFalse"));
+    in_true_var->Clear();
+    in_false_var->Clear();
+    in_true_var->GetMutable<framework::LoDTensor>();
+    in_false_var->GetMutable<framework::LoDTensor>();
+  }
 };
 class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
@@ -196,3 +223,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp,
                   ops::MergeLoDTensorOpProtoMaker,
                   ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker);
+REGISTER_OPERATOR(merge_lod_tensor_infer, ops::MergeLoDTensorInferOp,
+                  ops::MergeLoDTensorOpProtoMaker,
+                  ops::MergeLoDTensorInferShape,
+                  paddle::framework::EmptyGradOpMaker);
@@ -48,12 +48,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
   int workspace_size_;
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
   bool enable_int8_;
+  bool enable_fp16_;
   bool use_calib_mode_;
   std::string calibration_data_;
   std::string engine_key_;
   bool calibration_mode_;
   int predictor_id_;
   int device_id_;
+  AnalysisConfig::Precision precision_mode_;
  public:
   TensorRTEngineOp(const std::string &type,
@@ -66,6 +68,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
     workspace_size_ = Attr<int>("workspace_size");
     device_id_ = Attr<int>("gpu_id");
     enable_int8_ = Attr<bool>("enable_int8");
+    enable_fp16_ = Attr<bool>("enable_fp16");
     use_calib_mode_ = Attr<bool>("use_calib_mode");
     calibration_data_ = Attr<std::string>("calibration_data");
     engine_key_ = Attr<std::string>("engine_key");
@@ -93,6 +96,13 @@ class TensorRTEngineOp : public framework::OperatorBase {
       inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
           .Get(engine_key_ + std::to_string(predictor_id_));
     }
+    precision_mode_ = AnalysisConfig::Precision::kFloat32;
+    if (enable_int8_) {
+      precision_mode_ = AnalysisConfig::Precision::kInt8;
+    }
+    if (enable_fp16_) {
+      precision_mode_ = AnalysisConfig::Precision::kHalf;
+    }
   }
  protected:
@@ -141,7 +151,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
                 calib_buffers, runtime_batch, engine_key_, dev_place));
     calib_res->thr_.reset(new std::thread([&]() {
       calib_res->engine_.reset(new TensorRTEngine(
-          max_batch_size_, workspace_size_, enable_int8_,
+          max_batch_size_, workspace_size_, precision_mode_,
           calib_res->calib_.get(),
           boost::get<platform::CUDAPlace>(dev_place).device));
       VLOG(3) << "start the calib trt engine thread";
@@ -241,7 +251,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       trt_engine_ =
           inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
               .Create(engine_key_ + std::to_string(predictor_id_),
-                      max_batch_size_, workspace_size_, enable_int8_,
+                      max_batch_size_, workspace_size_, precision_mode_,
                       calibrator_.get(), device_id_);
       PrepareTRTEngine(scope, trt_engine_);
     }
......
@@ -105,6 +105,7 @@ TEST(TensorRTEngineOp, manual) {
   engine_op_desc.SetAttr("predictor_id", 1);
   engine_op_desc.SetAttr("calibration_data", std::string(""));
   engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
+  engine_op_desc.SetAttr("enable_fp16", static_cast<bool>(false));
   engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
   engine_op_desc.SetAttr("output_name_mapping",
                          std::vector<std::string>({"z0"}));
@@ -205,6 +206,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
   engine_op_desc.SetAttr("predictor_id", 1);
   engine_op_desc.SetAttr("calibration_data", std::string(""));
   engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
+  engine_op_desc.SetAttr("enable_fp16", static_cast<bool>(false));
   engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
   engine_op_desc.SetAttr("output_name_mapping",
                          std::vector<std::string>({"z3"}));
......
@@ -199,6 +199,7 @@ void BindAnalysisConfig(py::module *m) {
   py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
       .value("Float32", AnalysisConfig::Precision::kFloat32)
       .value("Int8", AnalysisConfig::Precision::kInt8)
+      .value("Half", AnalysisConfig::Precision::kHalf)
       .export_values();
   analysis_config.def(py::init<const AnalysisConfig &>())
......
@@ -23,6 +23,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.backward import append_backward
 from paddle.fluid.layers.control_flow import split_lod_tensor
 from paddle.fluid.layers.control_flow import merge_lod_tensor
+from paddle.fluid.layer_helper import LayerHelper
 class TestCPULoDTensorArrayOps(unittest.TestCase):
@@ -57,7 +58,7 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             expect_false=expect_false,
             expect_out=tensor)
-    def test_split_and_merge_lod_tensor_level_0(self):
+    def split_and_merge_lod_tensor_level_0(self, use_merge_lod_infer=False):
         tensor = core.LoDTensor()
         tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place())
         tensor.set_recursive_sequence_lengths([[3, 6, 1]])
@@ -87,10 +88,23 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             mask=mask,
             expect_true=expect_true,
             expect_false=expect_false,
-            expect_out=tensor)
+            expect_out=tensor,
+            use_merge_lod_infer=use_merge_lod_infer)
-    def main(self, tensor, mask, expect_true, expect_false, expect_out,
-             level=0):
+    def test_split_and_merge_lod_tensor_1(self):
+        self.split_and_merge_lod_tensor_level_0()
+
+    def test_split_and_merge_lod_tensor_2(self):
+        self.split_and_merge_lod_tensor_level_0(True)
+
+    def main(self,
+             tensor,
+             mask,
+             expect_true,
+             expect_false,
+             expect_out,
+             level=0,
+             use_merge_lod_infer=False):
         place = self.place()
         program = Program()
         with program_guard(program):
@@ -103,10 +117,35 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             out_true, out_false = split_lod_tensor(input=x, mask=y, level=level)
             out_true.persistable = True
             out_false.persistable = True
+            if use_merge_lod_infer:
+                input_dict = {
+                    'X': x,
+                    'Mask': mask,
+                    'InTrue': out_true,
+                    'InFalse': out_false,
+                    'level': level
+                }
+                helper = LayerHelper('merge_lod_tensor_infer')
+                out = helper.create_variable_for_type_inference(
+                    dtype=out_true.dtype)
+                helper.append_op(
+                    type='merge_lod_tensor_infer',
+                    inputs={
+                        'X': x,
+                        'Mask': y,
+                        'InTrue': out_true,
+                        'InFalse': out_false
+                    },
+                    outputs={'Out': out},
+                    attrs={'level': level})
+                out.persistable = True
+            else:
-            out = merge_lod_tensor(
-                in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
+                out = merge_lod_tensor(
+                    in_true=out_true,
+                    in_false=out_false,
+                    mask=y,
+                    x=x,
+                    level=level)
             out.persistable = True
         exe = Executor(place)
@@ -122,7 +161,7 @@ class TestCPULoDTensorArrayOps(unittest.TestCase):
             var_false = scope.find_var(out_false.name).get_tensor()
             var_out = scope.find_var(out.name).get_tensor()
+        if not use_merge_lod_infer:
             self.check_tensor_same(var_true, expect_true)
             self.check_tensor_same(var_false, expect_false)
         self.check_tensor_same(var_out, expect_out)
......