未验证 提交 c0034b5b 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Inference] optimize some code and fix some bug (#48780)

* clean ir_pass_manager and fix map_depthwise_conv_to_conv_pass

* fix unitest timeout
上级 d341ce9d
......@@ -29,6 +29,11 @@ void FillConstData(phi::DenseTensor* out_t, T value) {
}
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
// Not support
if (with_dynamic_shape) {
return;
}
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op =
......
......@@ -16,7 +16,12 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
......@@ -620,34 +625,45 @@ void FloatToHalfPass::ConvertWeightsData() const {
for (const auto& var_name : var_names) {
if (vars_convert_to_half_.count(var_name)) {
VLOG(4) << var_name << "'s data type was convert to half";
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
half_tensor.set_type(DTYPE); \
auto* half_data = half_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int64_t i = 0; i < origin_tensor->numel(); i++) { \
half_data[i] = static_cast<dtype>(origin_data[i]); \
} \
origin_tensor->clear(); \
paddle::framework::TensorCopySync( \
half_tensor, platform::CPUPlace(), origin_tensor)
auto* var = scope->FindLocalVar(var_name);
CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
if (var->IsType<phi::DenseTensor>()) {
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor half_tensor;
half_tensor.Resize(origin_tensor->dims());
auto* origin_data =
origin_tensor->mutable_data<float>(platform::CPUPlace());
half_tensor.set_type(half_precision_);
if (half_precision_ == phi::DataType::FLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
auto* half_data =
half_tensor.mutable_data<phi::dtype::float16>(phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
}
}
} else if (half_precision_ == phi::DataType::BFLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
auto* half_data =
half_tensor.mutable_data<phi::dtype::bfloat16>(phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
}
}
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
half_tensor, phi::CPUPlace{}, origin_tensor);
}
#undef CONVERT_TENSOR_DTYPE
}
}
......
......@@ -22,9 +22,6 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace framework {
......
......@@ -41,6 +41,7 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
std::string op_type = op_desc->Type();
if (!replaced_map.count(op_type)) continue;
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
op_desc->Flush();
++found_count;
}
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace inference {
......@@ -305,42 +306,18 @@ void IRPassManager::CreatePasses(Argument *argument,
}
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE_NOT_NULL(
graph.get(),
platform::errors::PreconditionNotMet("Graph cannot be NULL."));
graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
}
// delete_fill_constant_op_pass is not apply under trt dynamic shape
if (pass->Type() == "delete_fill_constant_op_pass") {
bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
if (use_dynamic) continue;
}
graph.reset(pass->Apply(graph.release()));
}
return graph;
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
graph->reset(pass->Apply(the_graph));
return *desc.Proto();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -48,15 +48,9 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false};
};
......
......@@ -108,6 +108,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
}
#else
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
use_gpu_ = false;
#endif
Update();
......@@ -299,7 +300,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config: ", key));
"invalid key %s in IPU config: ", key));
}
switch (ipu_config_mapper_.at(key)) {
case ipu_config_code::ipu_device_num:
......@@ -335,10 +336,9 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
case ipu_config_code::ipu_enable_model_runtime_executor:
ipu_enable_model_runtime_executor_ = string2bool(value);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key));
"invalid key %s in IPU config", key));
break;
}
}
......@@ -1438,7 +1438,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const {
return trt_allow_build_at_runtime_;
}
void AnalysisConfig::Exp_DisableMixedInferOps(
void AnalysisConfig::Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}
......
......@@ -1009,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
void Exp_DisableMixedInferOps(
void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
......
......@@ -418,7 +418,7 @@ if(WITH_GPU)
analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 60)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册