Unverified · Commit 0bf8a1dd authored by Chang Xu and committed by GitHub

Fix distiller (#1049)

* fix distiller

* fix distiller

* fix distiller

* demo imagenet

* demo imagenet
Parent 7901ff90
@@ -25,9 +25,10 @@ add_arg('save_dir', str, None, "directory to save
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
-# yapf: enable
+add_arg('data_dir', str, None, "path of dataset")
+# yapf: enable
def reader_wrapper(reader):
def gen():
for i, data in enumerate(reader()):
@@ -38,7 +39,8 @@ def reader_wrapper(reader):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
-val_reader = paddle.batch(reader.val(), batch_size=1)
+val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
@@ -47,7 +49,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
# eval "infer model", which input is image, output is classification probability
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program,
@@ -76,6 +77,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
+if batch_id % 5000 == 0:
+    print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0)
return result[0]
@@ -85,8 +88,10 @@ if __name__ == '__main__':
print_arguments(args)
paddle.enable_static()
compress_config, train_config = load_config(args.config_path)
+data_dir = args.data_dir
-train_reader = paddle.batch(reader.train(), batch_size=64)
+train_reader = paddle.batch(
+    reader.train(data_dir=data_dir), batch_size=args.batch_size)
train_dataloader = reader_wrapper(train_reader)
ac = AutoCompression(
......
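Taken together, the demo_imagenet.py hunks above thread the dataset path from the new CLI flag into both the training reader and the evaluation reader. A condensed sketch of the resulting flow follows; it is a sketch only: `reader` is the demo's ImageNet reader module and `args` comes from its argument parser, and the feed dict produced inside gen() is an assumption based on the 'x' placeholder used in eval_function.

```python
# Sketch: how --data_dir reaches the training data in the demo.
# Assumed external names: `reader` (the demo's ImageNet reader module)
# and `args` (result of the demo's parse_args()).
import numpy as np
import paddle


def reader_wrapper(reader):
    def gen():
        for i, data in enumerate(reader()):
            imgs = np.array([item[0] for item in data]).astype('float32')
            yield {"x": imgs}  # feed name assumed to match the 'x' placeholder

    return gen


data_dir = args.data_dir                 # new CLI argument
train_reader = paddle.batch(
    reader.train(data_dir=data_dir),     # reader now receives the dataset path
    batch_size=args.batch_size)          # previously hard-coded to 64
train_dataloader = reader_wrapper(train_reader)
```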
@@ -5,4 +5,5 @@ python3.7 demo_imagenet.py \
--save_dir='./save_qat_mbv2/' \
--devices='cpu' \
--batch_size=2 \
+--data_dir='data/ILSVRC2012/' \
--config_path='./configs/CV/mbv2_ptq_hpo.yaml'
@@ -187,6 +187,7 @@ def test(data_dir=DATA_DIR):
class ImageNetDataset(Dataset):
def __init__(self, data_dir=DATA_DIR, mode='train'):
super(ImageNetDataset, self).__init__()
+self.data_dir = data_dir
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
@@ -204,7 +205,7 @@ class ImageNetDataset(Dataset):
def __getitem__(self, index):
sample = self.data[index]
-data_path = os.path.join(DATA_DIR, sample[0])
+data_path = os.path.join(self.data_dir, sample[0])
if self.mode == 'train':
data, label = process_image(
[data_path, sample[1]],
......
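With the two reader hunks above, each ImageNetDataset instance resolves samples against the data_dir it was constructed with rather than the module-level DATA_DIR constant. A minimal usage sketch: the class comes from the demo's ImageNet reader, while the path, the 'val' mode, and the shape of the returned sample are assumptions based on the file lists read in __init__ and the process_image call in __getitem__.

```python
# Usage sketch: samples are now joined against self.data_dir, so a
# non-default dataset location is honoured per instance.
train_ds = ImageNetDataset(data_dir='data/ILSVRC2012/', mode='train')
val_ds = ImageNetDataset(data_dir='data/ILSVRC2012/', mode='val')

# __getitem__ resolves the file via os.path.join(self.data_dir, sample[0]);
# the (image, label) unpacking below is the expected shape, not shown in the hunk.
img, label = train_ds[0]
```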
@@ -163,6 +163,8 @@ class AutoCompression:
self._exe, self._places, config_dict, train_program_info,
self._strategy)
if self.train_config.use_fleet:
dist_strategy = _prepare_fleet_strategy(self.train_config)
else:
@@ -188,6 +190,8 @@
self._exe.run(train_program_info.startup_program)
if (not self.train_config.use_fleet
) and self.train_config.amp_config is not None:
if hasattr(self.train_config.amp_config, 'use_pure_fp16'
......
@@ -96,14 +96,19 @@ def _load_program_and_merge(executor,
params_filename,
teacher_idx=None,
feed_target_names=None):
scope = paddle.static.global_scope()
+new_scope = paddle.static.Scope()
try:
-[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
+with paddle.static.scope_guard(new_scope):
+    [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=model_dir, \
model_filename=model_filename, \
params_filename=params_filename, \
executor=executor)
except:
-[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
+with paddle.static.scope_guard(new_scope):
+    [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=model_dir, \
executor=executor)
@@ -130,6 +135,7 @@ def _load_program_and_merge(executor,
train_program,
data_name_map,
place,
+teacher_scope=new_scope,
name_prefix=teacher_name_prefix,
merge_feed=config.get('merge_feed') or True)
if teacher_idx == None or teacher_idx == 1:
@@ -280,7 +286,6 @@ def build_quant_program(executor, place, config, train_program_info,
assert isinstance(config, dict), "quant config must be dict"
default_config = _quant_config_default
default_config.update(config)
-print(default_config)
config = _parse_configs(default_config)
use_pact = config["use_pact"]
......
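The _load_program_and_merge hunks carry the actual distiller fix: the teacher inference model is loaded into a fresh Scope so its persistable variables cannot clash with same-named student variables in the global scope, and that scope is then handed to merge(). A condensed sketch of the pattern, assuming model_dir, executor, train_program, data_name_map, place, and teacher_name_prefix are already in hand as in the surrounding function; the paddle.fluid.io fallback and error handling are omitted.

```python
import paddle

# Load the teacher into a private scope so its parameters do not collide
# with (or silently overwrite) the student's variables in the global scope.
new_scope = paddle.static.Scope()
with paddle.static.scope_guard(new_scope):
    [teacher_program, teacher_feed_target_names, teacher_fetch_targets] = \
        paddle.static.load_inference_model(
            path_prefix=model_dir, executor=executor)

# merge() then reads the teacher weights out of new_scope and registers the
# prefixed copies next to the student's variables (see the merge hunks below).
merge(
    teacher_program,
    train_program,
    data_name_map,
    place,
    teacher_scope=new_scope,
    name_prefix=teacher_name_prefix)
```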
@@ -22,6 +22,7 @@ def merge(teacher_program,
data_name_map,
place,
scope=None,
+teacher_scope=None,
name_prefix='teacher_',
merge_feed=True):
"""Merge teacher program into student program and add a uniform prefix to the
@@ -48,6 +49,8 @@
"""
if scope == None:
scope = paddle.static.global_scope()
+if teacher_scope == None:
+    teacher_scope = scope
teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars():
skip_rename = False
@@ -60,7 +63,7 @@
new_name = name_prefix + teacher_var.name
if not skip_rename:
# scope var rename
-old_var = scope.var(teacher_var.name).get_tensor()
+old_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place)
......
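For completeness, a usage sketch of the updated merge() entry point. It assumes this is the public paddleslim.dist.merge, that the teacher program and its private scope were prepared as in the loader sketch above, and that data_name_map maps teacher input names to student input names; the map contents and the place are illustrative.

```python
import paddle
from paddleslim.dist import merge

# Assumed mapping: teacher input variable name -> student input variable name.
data_name_map = {'x': 'x'}

merge(
    teacher_program,
    student_program,
    data_name_map,
    paddle.CPUPlace(),
    teacher_scope=teacher_scope,   # new in this commit; defaults to `scope`
    name_prefix='teacher_')
```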