未验证 提交 0bf8a1dd 编写于 作者: C Chang Xu 提交者: GitHub

Fix distiller (#1049)

* fix distiller

* fix distiller

* fix distiller

* demo imagenet

* demo imagenet
上级 7901ff90
...@@ -25,9 +25,10 @@ add_arg('save_dir', str, None, "directory to save ...@@ -25,9 +25,10 @@ add_arg('save_dir', str, None, "directory to save
add_arg('devices', str, 'gpu', "which device used to compress.") add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.") add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.") add_arg('config_path', str, None, "path of compression strategy config.")
# yapf: enable add_arg('data_dir', str, None, "path of dataset")
# yapf: enable
def reader_wrapper(reader): def reader_wrapper(reader):
def gen(): def gen():
for i, data in enumerate(reader()): for i, data in enumerate(reader()):
...@@ -38,7 +39,8 @@ def reader_wrapper(reader): ...@@ -38,7 +39,8 @@ def reader_wrapper(reader):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = paddle.batch(reader.val(), batch_size=1)
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
image = paddle.static.data( image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32') name='x', shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
...@@ -47,7 +49,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -47,7 +49,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
for batch_id, data in enumerate(val_reader()): for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc # top1_acc, top5_acc
if len(test_feed_names) == 1: if len(test_feed_names) == 1:
# eval "infer model", which input is image, output is classification probability
image = data[0][0].reshape((1, 3, 224, 224)) image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data] label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program, pred = exe.run(compiled_test_program,
...@@ -76,6 +77,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -76,6 +77,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list) fetch_list=test_fetch_list)
result = [np.mean(r) for r in result] result = [np.mean(r) for r in result]
results.append(result) results.append(result)
if batch_id % 5000 == 0:
print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0) result = np.mean(np.array(results), axis=0)
return result[0] return result[0]
...@@ -85,8 +88,10 @@ if __name__ == '__main__': ...@@ -85,8 +88,10 @@ if __name__ == '__main__':
print_arguments(args) print_arguments(args)
paddle.enable_static() paddle.enable_static()
compress_config, train_config = load_config(args.config_path) compress_config, train_config = load_config(args.config_path)
data_dir = args.data_dir
train_reader = paddle.batch(reader.train(), batch_size=64) train_reader = paddle.batch(
reader.train(data_dir=data_dir), batch_size=args.batch_size)
train_dataloader = reader_wrapper(train_reader) train_dataloader = reader_wrapper(train_reader)
ac = AutoCompression( ac = AutoCompression(
......
...@@ -5,4 +5,5 @@ python3.7 demo_imagenet.py \ ...@@ -5,4 +5,5 @@ python3.7 demo_imagenet.py \
--save_dir='./save_qat_mbv2/' \ --save_dir='./save_qat_mbv2/' \
--devices='cpu' \ --devices='cpu' \
--batch_size=2 \ --batch_size=2 \
--data_dir='data/ILSVRC2012/' \
--config_path='./configs/CV/mbv2_ptq_hpo.yaml' --config_path='./configs/CV/mbv2_ptq_hpo.yaml'
...@@ -187,6 +187,7 @@ def test(data_dir=DATA_DIR): ...@@ -187,6 +187,7 @@ def test(data_dir=DATA_DIR):
class ImageNetDataset(Dataset): class ImageNetDataset(Dataset):
def __init__(self, data_dir=DATA_DIR, mode='train'): def __init__(self, data_dir=DATA_DIR, mode='train'):
super(ImageNetDataset, self).__init__() super(ImageNetDataset, self).__init__()
self.data_dir = data_dir
train_file_list = os.path.join(data_dir, 'train_list.txt') train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt') val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt') test_file_list = os.path.join(data_dir, 'test_list.txt')
...@@ -204,7 +205,7 @@ class ImageNetDataset(Dataset): ...@@ -204,7 +205,7 @@ class ImageNetDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sample = self.data[index] sample = self.data[index]
data_path = os.path.join(DATA_DIR, sample[0]) data_path = os.path.join(self.data_dir, sample[0])
if self.mode == 'train': if self.mode == 'train':
data, label = process_image( data, label = process_image(
[data_path, sample[1]], [data_path, sample[1]],
......
...@@ -163,6 +163,8 @@ class AutoCompression: ...@@ -163,6 +163,8 @@ class AutoCompression:
self._exe, self._places, config_dict, train_program_info, self._exe, self._places, config_dict, train_program_info,
self._strategy) self._strategy)
if self.train_config.use_fleet: if self.train_config.use_fleet:
dist_strategy = _prepare_fleet_strategy(self.train_config) dist_strategy = _prepare_fleet_strategy(self.train_config)
else: else:
...@@ -188,6 +190,8 @@ class AutoCompression: ...@@ -188,6 +190,8 @@ class AutoCompression:
self._exe.run(train_program_info.startup_program) self._exe.run(train_program_info.startup_program)
if (not self.train_config.use_fleet if (not self.train_config.use_fleet
) and self.train_config.amp_config is not None: ) and self.train_config.amp_config is not None:
if hasattr(self.train_config.amp_config, 'use_pure_fp16' if hasattr(self.train_config.amp_config, 'use_pure_fp16'
......
...@@ -96,13 +96,18 @@ def _load_program_and_merge(executor, ...@@ -96,13 +96,18 @@ def _load_program_and_merge(executor,
params_filename, params_filename,
teacher_idx=None, teacher_idx=None,
feed_target_names=None): feed_target_names=None):
scope = paddle.static.global_scope()
new_scope = paddle.static.Scope()
try: try:
with paddle.static.scope_guard(new_scope):
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \ [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=model_dir, \ dirname=model_dir, \
model_filename=model_filename, \ model_filename=model_filename, \
params_filename=params_filename, \ params_filename=params_filename, \
executor=executor) executor=executor)
except: except:
with paddle.static.scope_guard(new_scope):
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \ [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=model_dir, \ path_prefix=model_dir, \
executor=executor) executor=executor)
...@@ -130,6 +135,7 @@ def _load_program_and_merge(executor, ...@@ -130,6 +135,7 @@ def _load_program_and_merge(executor,
train_program, train_program,
data_name_map, data_name_map,
place, place,
teacher_scope=new_scope,
name_prefix=teacher_name_prefix, name_prefix=teacher_name_prefix,
merge_feed=config.get('merge_feed') or True) merge_feed=config.get('merge_feed') or True)
if teacher_idx == None or teacher_idx == 1: if teacher_idx == None or teacher_idx == 1:
...@@ -280,7 +286,6 @@ def build_quant_program(executor, place, config, train_program_info, ...@@ -280,7 +286,6 @@ def build_quant_program(executor, place, config, train_program_info,
assert isinstance(config, dict), "quant config must be dict" assert isinstance(config, dict), "quant config must be dict"
default_config = _quant_config_default default_config = _quant_config_default
default_config.update(config) default_config.update(config)
print(default_config)
config = _parse_configs(default_config) config = _parse_configs(default_config)
use_pact = config["use_pact"] use_pact = config["use_pact"]
......
...@@ -22,6 +22,7 @@ def merge(teacher_program, ...@@ -22,6 +22,7 @@ def merge(teacher_program,
data_name_map, data_name_map,
place, place,
scope=None, scope=None,
teacher_scope=None,
name_prefix='teacher_', name_prefix='teacher_',
merge_feed=True): merge_feed=True):
"""Merge teacher program into student program and add a uniform prefix to the """Merge teacher program into student program and add a uniform prefix to the
...@@ -48,6 +49,8 @@ def merge(teacher_program, ...@@ -48,6 +49,8 @@ def merge(teacher_program,
""" """
if scope == None: if scope == None:
scope = paddle.static.global_scope() scope = paddle.static.global_scope()
if teacher_scope == None:
teacher_scope = scope
teacher_program = teacher_program.clone(for_test=True) teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
skip_rename = False skip_rename = False
...@@ -60,7 +63,7 @@ def merge(teacher_program, ...@@ -60,7 +63,7 @@ def merge(teacher_program,
new_name = name_prefix + teacher_var.name new_name = name_prefix + teacher_var.name
if not skip_rename: if not skip_rename:
# scope var rename # scope var rename
old_var = scope.var(teacher_var.name).get_tensor() old_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor() renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place) renamed_var.set(np.array(old_var), place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册