未验证 提交 cdf5f6fb 编写于 作者: G GaoWei8 提交者: GitHub

Add an inference interface to disable FC padding (#22097)

* Add an interface of disabling FC padding
* fix bert regression
* polish fc padding interface
* recover pass function
* fix argument error
* fix mkldnn error
上级 68c76793
......@@ -90,17 +90,17 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
desc.SetAttr("activation_type", activation_type);
  // Pad FC weights (when both dims are multiples of 128) to improve MKL GEMM performance.
auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
auto* weight_data = weight->data<float>();
auto weight_dims = weight->dims();
int weight_num = product(weight_dims);
int w_h = weight_dims[0];
int w_w = weight_dims[1];
bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
bool use_fc_padding =
Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
if (!use_gpu && use_fc_padding) {
auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
auto* weight_data = weight->data<float>();
auto weight_dims = weight->dims();
int weight_num = product(weight_dims);
int w_h = weight_dims[0];
int w_w = weight_dims[1];
if (w_h % 128 == 0 && w_w % 128 == 0) {
auto* weight_data_tmp = new float[weight_num];
for (int i = 0; i < w_h; i++) {
......
......@@ -172,6 +172,7 @@ struct Argument {
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
......
......@@ -142,6 +142,14 @@ void IRPassManager::CreatePasses(Argument *argument,
disable_logs_ = argument->disable_logs();
if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu()));
bool fc_mkldnn_pass = 0;
for (const std::string &pass_n : passes) {
if (pass_n == "fc_mkldnn_pass") {
fc_mkldnn_pass = 1;
}
}
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
}
pre_pass = pass_name;
......@@ -150,47 +158,12 @@ void IRPassManager::CreatePasses(Argument *argument,
}
}
// Returns true iff a pass whose Type() equals pass_type has been registered.
// (std::any_of over an empty pass list is trivially false, matching the
// original early-return behavior.)
bool IRPassManager::HasPass(const std::string &pass_type) {
  return std::any_of(passes_.begin(), passes_.end(),
                     [&pass_type](const std::unique_ptr<Pass> &pass) {
                       return pass->Type() == pass_type;
                     });
}
// Looks up the registered pass whose Type() equals pass_type and returns a
// reference to its owning pointer. Enforces that the pass list is non-empty
// and that the requested pass was actually added earlier.
std::unique_ptr<Pass> &IRPassManager::GetPass(const std::string &pass_type) {
  PADDLE_ENFORCE_EQ(passes_.empty(), false,
                    platform::errors::PreconditionNotMet(
                        "The list of passes cannot be empty."));
  const auto type_matches = [&pass_type](const std::unique_ptr<Pass> &pass) {
    return pass->Type() == pass_type;
  };
  auto found = std::find_if(passes_.begin(), passes_.end(), type_matches);
  PADDLE_ENFORCE_NE(found, passes_.end(),
                    platform::errors::PermissionDenied(
                        "You cannot get pass which was not added earlier."));
  return *found;
}
// Some passes depend on each other. This method serves for exchanging
// information between them: here, fc_fuse_pass is told whether to pad FC
// weights. Padding is skipped when fc_mkldnn_pass is present, since the
// MKL-DNN FC kernel does not use the padded layout.
void IRPassManager::UpdatePasses() {
  if (!HasPass("fc_fuse_pass")) return;
  const bool use_fc_padding = !HasPass("fc_mkldnn_pass");
  // Pass::Set takes ownership of the heap-allocated flag.
  GetPass("fc_fuse_pass")
      ->Set<bool>("use_fc_padding", new bool(use_fc_padding));
}
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE_NOT_NULL(graph.get(), platform::errors::PreconditionNotMet(
"Graph cannot be NULL."));
UpdatePasses();
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
......
......@@ -54,9 +54,6 @@ class IRPassManager final {
private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
bool HasPass(const std::string &pass_type);
std::unique_ptr<Pass> &GetPass(const std::string &pass_type);
void UpdatePasses();
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_;
......
......@@ -82,6 +82,12 @@ void AnalysisConfig::DisableGpu() {
Update();
}
// Disable padding of FC weights in fc_fuse_pass; read back via use_fc_padding().
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;
  // NOTE(review): Update() presumably reconciles dependent options, mirroring
  // the call at the end of DisableGpu() above — confirm against its definition.
  Update();
}
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;
......@@ -94,6 +100,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
prog_file_ = std::move(other.prog_file_);
params_file_ = std::move(other.params_file_);
CP_MEMBER(use_fc_padding_);
// GPU related.
CP_MEMBER(use_gpu_);
CP_MEMBER(use_cudnn_);
......@@ -354,6 +361,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << params_file_;
ss << use_gpu_;
ss << use_fc_padding_;
ss << device_id_;
ss << memory_pool_init_size_mb_;
......
......@@ -381,6 +381,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu());
argument_.SetUseFcPadding(config_.use_fc_padding());
argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
......
......@@ -77,6 +77,14 @@ struct AnalysisConfig {
*/
const std::string& params_file() const { return params_file_; }
// Padding related.
/** Turn off padding of FC weights.
 */
void DisableFCPadding();
/** A bool state telling whether FC weight padding is enabled.
 */
bool use_fc_padding() const { return use_fc_padding_; }
// GPU related.
/**
......@@ -293,6 +301,9 @@ struct AnalysisConfig {
bool use_cudnn_{false};
// Padding related
bool use_fc_padding_{true};
// TensorRT related.
bool use_tensorrt_{false};
// For workspace_size, refer it from here:
......
......@@ -145,7 +145,10 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
return true;
}
void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
// Builds the common AnalysisConfig for this test: loads the model from the
// --infer_model flag and turns off FC weight padding.
void SetConfig(AnalysisConfig *config) {
  config->SetModel(FLAGS_infer_model);
  // Exercises the DisableFCPadding() interface introduced by this change.
  config->DisableFCPadding();
}
void profile(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig config;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册