Unverified commit 9a918144, authored by whs, committed by GitHub

Fix config and wrapper of dataloader (#1192)

* Fix config and wrapper of dataloader

* Fix wrapper of dataloader
1. Skip the wrapper when the input batch is already in a format like [{'x': data1, 'y': data2}]
2. Fix scripts in the NLP demo of ACT
Parent 8cc39f55
@@ -241,9 +241,10 @@ if __name__ == '__main__':
     args = parser.parse_args()
     print_arguments(args)
     paddle.enable_static()
-    if train_config is not None:
-        train_config.optimizer_builder[
+    all_config = load_config(args.config_path)
+    if "TrainConfig" in all_config:
+        all_config["TrainConfig"]["optimizer_builder"][
             'apply_decay_param_fun'] = apply_decay_param_fun
     train_dataloader, eval_dataloader = reader()
@@ -255,11 +256,10 @@ if __name__ == '__main__':
         model_filename=args.model_filename,
         params_filename=args.params_filename,
         save_dir=args.save_dir,
-        config=args.config_path,
+        config=all_config,
         train_dataloader=train_dataloader,
-        eval_callback=eval_function if compress_config is None or
-        'HyperParameterOptimization' not in compress_config else
-        eval_dataloader,
+        eval_callback=eval_function
+        if 'HyperParameterOptimization' not in all_config else eval_dataloader,
         eval_dataloader=eval_dataloader)
     ac.compress()
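With this change the demo loads the YAML file once through load_config and keys both the optimizer patch and the eval_callback choice off the resulting dict. Below is a minimal sketch of that selection logic only; `all_config` stands in for the dict returned by load_config(args.config_path), the strings stand in for the real callback and dataloader objects, and the section names in the example dict are illustrative.

# Sketch only: mirrors the eval_callback choice in the updated demo script.
def choose_eval_callback(all_config, eval_function, eval_dataloader):
    # When a HyperParameterOptimization section is present, the dataloader is
    # handed to AutoCompression instead of the metric callback.
    if 'HyperParameterOptimization' in all_config:
        return eval_dataloader
    return eval_function

all_config = {'TrainConfig': {'optimizer_builder': {}}, 'QuantPost': {}}
print(choose_eval_callback(all_config, 'eval_function', 'eval_dataloader'))
# -> 'eval_function'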
 export FLAGS_cudnn_deterministic=True
 python run.py \
     --model_type='ppminilm' \
-    --model_dir='./afqmc/' \
-    --model_filename='inference.pdmodel' \
-    --params_filename='inference.pdiparams' \
+    --model_dir='./all_original_models/AFQMC' \
+    --model_filename='infer.pdmodel' \
+    --params_filename='infer.pdiparams' \
     --dataset='clue' \
     --save_dir='./save_afqmc_pruned/' \
     --batch_size=16 \
...
@@ -158,7 +158,7 @@ class AutoCompression:
         self.model_type = self._get_model_type(self._exe, model_dir,
                                                 model_filename, params_filename)
-        if train_config is not None and train_config.use_fleet:
+        if self.train_config is not None and self.train_config.use_fleet:
             fleet.init(is_collective=True)
         if with_variable_shape(
@@ -191,7 +191,8 @@ class AutoCompression:
             self.strategy_config)
         self.train_config = self._get_final_train_config(
-            train_config, self._strategy, self.model_type)
+            self.train_config, self._strategy, self.model_type)
+        _logger.info(f"Selected strategies: {self._strategy}")

     def _get_final_train_config(self, train_config, strategy_config,
                                 model_type):
@@ -546,7 +547,12 @@ class AutoCompression:
         return tmp_dir

     def compress(self):
+        assert len(self._strategy) > 0
         self.tmp_dir = self.create_tmp_dir(self.final_dir)
+        strategy = None
+        config = None
+        train_config = None
+        strategy_idx = None
         for strategy_idx, (
                 strategy, config, train_config
         ) in enumerate(zip(self._strategy, self._config, self.train_config)):
...
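The compress() changes above are defensive: the new assert rejects an empty strategy list up front, and the loop variables are pre-initialized so that code after the loop never touches an undefined name. A small standalone illustration of the failure mode being guarded against; the variable names mirror the ones in compress(), everything else here is made up.

# Illustration only: with an empty zip the loop body never executes, so without
# the pre-initialization any later use of these names would raise NameError.
strategies, configs, train_configs = [], [], []

strategy = config = train_config = strategy_idx = None
for strategy_idx, (strategy, config, train_config) in enumerate(
        zip(strategies, configs, train_configs)):
    pass  # per-strategy compression would happen here

print(strategy_idx, strategy)  # None None (instead of a NameError)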
@@ -40,7 +40,6 @@ SUPPORTED_CONFIG = [
     "UnstructurePrune",
     "TransformerPrune",
     "ASPPrune",
-    "TrainConfig",
 ]
 TRAIN_CONFIG_NAME = "TrainConfig"
...
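Dropping "TrainConfig" from SUPPORTED_CONFIG means it is no longer treated as a compression-strategy section; it is still recognized by name via TRAIN_CONFIG_NAME. The sketch below shows how a loaded config dict could be split along those lines; the helper name split_config and the shortened SUPPORTED_CONFIG list are illustrative, not the library code.

# Sketch only: the real SUPPORTED_CONFIG list contains more strategy sections
# than the three visible in this diff.
SUPPORTED_CONFIG = ["UnstructurePrune", "TransformerPrune", "ASPPrune"]
TRAIN_CONFIG_NAME = "TrainConfig"

def split_config(all_config):
    # TrainConfig is pulled out by name; only strategy sections are matched
    # against SUPPORTED_CONFIG.
    train_config = all_config.get(TRAIN_CONFIG_NAME)
    strategy_config = {k: v for k, v in all_config.items() if k in SUPPORTED_CONFIG}
    return strategy_config, train_config

strategy, train = split_config({"TrainConfig": {"epochs": 1}, "ASPPrune": {}})
print(strategy, train)  # {'ASPPrune': {}} {'epochs': 1}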
@@ -21,16 +21,21 @@ def get_feed_vars(model_dir, model_filename, params_filename):
     return feed_target_names

+def _valid_format(data):
+    is_dict = isinstance(data, dict)
+    list_with_one_dict = isinstance(
+        data, list) and len(data) == 1 and isinstance(data[0], dict)
+    return is_dict or list_with_one_dict

 def wrap_dataloader(dataloader, names):
     """Create a wrapper of dataloader if the data returned by the dataloader is not a dict.
     And the names will be the keys of dict returned by the wrapper.
     """
     if dataloader is None:
         return dataloader
+    assert isinstance(dataloader, paddle.io.DataLoader)
+    assert len(dataloader) > 0
     data = next(dataloader())
-    if isinstance(data, dict):
+    if _valid_format(data):
         return dataloader
     if isinstance(data, Iterable):
...
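The new _valid_format check is what lets wrap_dataloader skip wrapping for batches that are already keyed by name, which is the case described in the commit message. A quick standalone check of the accepted formats; the helper body is copied from the diff, the example batches are made up.

# Sketch only: shows which batches pass through unchanged and which would
# still be zipped with `names` into a dict by the wrapper.
def _valid_format(data):
    is_dict = isinstance(data, dict)
    list_with_one_dict = (isinstance(data, list) and len(data) == 1
                          and isinstance(data[0], dict))
    return is_dict or list_with_one_dict

print(_valid_format({'x': [1], 'y': [2]}))    # True  - already a dict
print(_valid_format([{'x': [1], 'y': [2]}]))  # True  - the new [{...}] case
print(_valid_format(([1], [2])))              # False - wrapper still builds the dict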