未验证 提交 82da1f14 编写于 作者: C ceci3 提交者: GitHub

fix unittest (#1633)

上级 cb57443e
repos: repos:
- repo: https://github.com/Lucas-C/pre-commit-hooks.git - repo: https://github.com/Lucas-C/pre-commit-hooks.git
sha: v1.0.1 rev: v1.3.1
hooks: hooks:
- id: remove-crlf - id: remove-crlf
          files: .* files: .*
...@@ -9,12 +9,12 @@ repos: ...@@ -9,12 +9,12 @@ repos:
- id: remove-tabs - id: remove-tabs
files: \.(md|yml)$ files: \.(md|yml)$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git - repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 rev: v0.16.2
hooks: hooks:
- id: yapf - id: yapf
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0 rev: v4.4.0
hooks: hooks:
- id: check-added-large-files - id: check-added-large-files
- id: check-merge-conflict - id: check-merge-conflict
......
...@@ -37,25 +37,34 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -37,25 +37,34 @@ _logger = get_logger(__name__, level=logging.INFO)
############################################################################################################ ############################################################################################################
_train_config_default = { _train_config_default = {
# configs of training aware quantization with infermodel # configs of training aware quantization with infermodel
"num_epoch": 1000, # training epoch num "num_epoch":
"max_iter": -1, # max training iteration num 1000, # training epoch num
"max_iter":
-1, # max training iteration num
"save_iter_step": "save_iter_step":
1000, # save quant model checkpoint every save_iter_step iteration 1000, # save quant model checkpoint every save_iter_step iteration
"learning_rate": 0.0001, # learning rate "learning_rate":
"weight_decay": 0.0001, # weight decay 0.0001, # learning rate
"use_pact": False, # use pact quantization or not "weight_decay":
0.0001, # weight decay
"use_pact":
False, # use pact quantization or not
# quant model checkpoints save path # quant model checkpoints save path
"quant_model_ckpt_path": "./quant_model_checkpoints/", "quant_model_ckpt_path":
"./quant_model_checkpoints/",
# storage directory of teacher model + teacher model name (excluding suffix) # storage directory of teacher model + teacher model name (excluding suffix)
"teacher_model_path_prefix": None, "teacher_model_path_prefix":
None,
# storage directory of model + model name (excluding suffix) # storage directory of model + model name (excluding suffix)
"model_path_prefix": None, "model_path_prefix":
None,
""" distillation node configuration: """ distillation node configuration:
the name of the distillation supervision nodes is configured as a list, the name of the distillation supervision nodes is configured as a list,
and the teacher node and student node are arranged in pairs. and the teacher node and student node are arranged in pairs.
for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"] for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"]
""" """
"node": None "node":
None
} }
...@@ -184,9 +193,9 @@ def quant_aware_with_infermodel(executor, ...@@ -184,9 +193,9 @@ def quant_aware_with_infermodel(executor,
place, place,
quant_config, quant_config,
scope=scope, scope=scope,
act_preprocess_func=act_preprocess_func, act_preprocess_func=None,
optimizer_func=optimizer_func, optimizer_func=None,
executor=pact_executor, executor=None,
for_test=True) for_test=True)
train_program = quant_aware( train_program = quant_aware(
train_program, train_program,
...@@ -225,8 +234,7 @@ def quant_aware_with_infermodel(executor, ...@@ -225,8 +234,7 @@ def quant_aware_with_infermodel(executor,
test_callback(compiled_test_prog, test_feed_names, test_callback(compiled_test_prog, test_feed_names,
test_fetch_list, checkpoint_name) test_fetch_list, checkpoint_name)
iter_sum += 1 iter_sum += 1
if train_config["max_iter"] >= 0 and iter_sum > train_config[ if train_config["max_iter"] >= 0 and iter_sum > train_config["max_iter"]:
"max_iter"]:
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册