未验证 提交 9a918144 编写于 作者: W whs 提交者: GitHub

Fix config and wrapper of dataloader (#1192)

* Fix config and wrapper of dataloader

* Fix wrapper of dataloader
1. Skip wrapper when input is format like [{'x': data1, 'y': data2}]
2. Fix scripts in nlp demo of ACT
上级 8cc39f55
......@@ -241,9 +241,10 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
all_config = load_config(args.config_path)
if train_config is not None:
train_config.optimizer_builder[
if "TrainConfig" in all_config:
all_config["TrainConfig"]["optimizer_builder"][
'apply_decay_param_fun'] = apply_decay_param_fun
train_dataloader, eval_dataloader = reader()
......@@ -255,11 +256,10 @@ if __name__ == '__main__':
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
config=args.config_path,
config=all_config,
train_dataloader=train_dataloader,
eval_callback=eval_function if compress_config is None or
'HyperParameterOptimization' not in compress_config else
eval_dataloader,
eval_callback=eval_function
if 'HyperParameterOptimization' not in all_config else eval_dataloader,
eval_dataloader=eval_dataloader)
ac.compress()
export FLAGS_cudnn_deterministic=True
python run.py \
--model_type='ppminilm' \
--model_dir='./afqmc/' \
--model_filename='inference.pdmodel' \
--params_filename='inference.pdiparams' \
--model_dir='./all_original_models/AFQMC' \
--model_filename='infer.pdmodel' \
--params_filename='infer.pdiparams' \
--dataset='clue' \
--save_dir='./save_afqmc_pruned/' \
--batch_size=16 \
......
......@@ -158,7 +158,7 @@ class AutoCompression:
self.model_type = self._get_model_type(self._exe, model_dir,
model_filename, params_filename)
if train_config is not None and train_config.use_fleet:
if self.train_config is not None and self.train_config.use_fleet:
fleet.init(is_collective=True)
if with_variable_shape(
......@@ -191,7 +191,8 @@ class AutoCompression:
self.strategy_config)
self.train_config = self._get_final_train_config(
train_config, self._strategy, self.model_type)
self.train_config, self._strategy, self.model_type)
_logger.info(f"Selected strategies: {self._strategy}")
def _get_final_train_config(self, train_config, strategy_config,
model_type):
......@@ -546,7 +547,12 @@ class AutoCompression:
return tmp_dir
def compress(self):
assert len(self._strategy) > 0
self.tmp_dir = self.create_tmp_dir(self.final_dir)
strategy = None
config = None
train_config = None
strategy_idx = None
for strategy_idx, (
strategy, config, train_config
) in enumerate(zip(self._strategy, self._config, self.train_config)):
......
......@@ -40,7 +40,6 @@ SUPPORTED_CONFIG = [
"UnstructurePrune",
"TransformerPrune",
"ASPPrune",
"TrainConfig",
]
TRAIN_CONFIG_NAME = "TrainConfig"
......
......@@ -21,16 +21,21 @@ def get_feed_vars(model_dir, model_filename, params_filename):
return feed_target_names
def _valid_format(data):
is_dict = isinstance(data, dict)
list_with_one_dict = isinstance(
data, list) and len(data) == 1 and isinstance(data[0], dict)
return is_dict or list_with_one_dict
def wrap_dataloader(dataloader, names):
"""Create a wrapper of dataloader if the data returned by the dataloader is not a dict.
And the names will be the keys of dict returned by the wrapper.
"""
if dataloader is None:
return dataloader
assert isinstance(dataloader, paddle.io.DataLoader)
assert len(dataloader) > 0
data = next(dataloader())
if isinstance(data, dict):
if _valid_format(data):
return dataloader
if isinstance(data, Iterable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册