未验证 提交 3ea2b661 编写于 作者: W whs 提交者: GitHub

Make PaddleSlim support PyReader (#19995)

* Make PaddleSlim support PyReader.
* Fix unittest of sensitive pruning.
* Add some assert.
上级 4b65af77
......@@ -492,7 +492,7 @@ paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'p
paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884'))
paddle.fluid.contrib.distributed_batch_reader (ArgSpec(args=['batch_reader'], varargs=None, keywords=None, defaults=None), ('document', 'b60796eb0a481484dd34e345f0eaa4d5'))
paddle.fluid.contrib.Compressor ('paddle.fluid.contrib.slim.core.compressor.Compressor', ('document', 'a5417774a94aa9ae5560a42b96527e7d'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'eval_func', 'save_eval_model', 'prune_infer_model', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, None, True, None, [], None, None, None, None)), ('document', '05119e0fa0fc07f5cf848ebf0a2cf070'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'eval_func', 'save_eval_model', 'prune_infer_model', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space', 'log_period'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, None, True, None, [], None, None, None, None, 20)), ('document', '26261076fa2140a1367cb4fbf3ac03fa'))
paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0'))
paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9'))
paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67'))
......
......@@ -20,6 +20,7 @@ from .... import profiler
from .... import scope_guard
from ....data_feeder import DataFeeder
from ....log_helper import get_logger
from ....reader import PyReader
from ..graph import *
from .config import ConfigFactory
import numpy as np
......@@ -185,10 +186,16 @@ class Context(object):
s_time = time.time()
reader = self.eval_reader
if sampled_rate:
assert (not isinstance(reader, Variable))
assert (sampled_rate > 0)
assert (self.cache_path is not None)
_logger.info('sampled_rate: {}; cached_id: {}'.format(sampled_rate,
cached_id))
reader = cached_reader(reader, sampled_rate, self.cache_path,
cached_id)
if isinstance(reader, Variable):
if isinstance(reader, Variable) or (isinstance(reader, PyReader) and
(not reader._iterable)):
reader.start()
try:
while True:
......@@ -249,7 +256,8 @@ class Compressor(object):
checkpoint_path=None,
train_optimizer=None,
distiller_optimizer=None,
search_space=None):
search_space=None,
log_period=20):
"""
Args:
place(fluid.Place): The device place where the compression job running.
......@@ -294,6 +302,7 @@ class Compressor(object):
student-net in fine-tune stage.
search_space(slim.nas.SearchSpace): The instance that defines the search space. It must inherit from the
slim.nas.SearchSpace class and override the abstract methods.
log_period(int): The period (in training batches) at which training logs are printed.
"""
assert train_feed_list is None or isinstance(
......@@ -329,6 +338,8 @@ class Compressor(object):
self.init_model = None
self.search_space = search_space
self.log_period = log_period
assert (log_period > 0)
def _add_strategy(self, strategy):
"""
......@@ -357,6 +368,7 @@ class Compressor(object):
if 'eval_epoch' in factory.compressor:
self.eval_epoch = factory.compressor['eval_epoch']
assert (self.eval_epoch > 0)
def _init_model(self, context):
"""
......@@ -414,7 +426,7 @@ class Compressor(object):
else:
strategies = pickle.load(
strategy_file, encoding='bytes')
assert (len(self.strategies) == len(strategies))
for s, s1 in zip(self.strategies, strategies):
s1.__dict__.update(s.__dict__)
......@@ -472,7 +484,9 @@ class Compressor(object):
context.optimize_graph.program).with_data_parallel(
loss_name=context.optimize_graph.out_nodes['loss'])
if isinstance(context.train_reader, Variable):
if isinstance(context.train_reader, Variable) or (
isinstance(context.train_reader,
PyReader) and (not context.train_reader._iterable)):
context.train_reader.start()
try:
while True:
......@@ -482,7 +496,7 @@ class Compressor(object):
results = executor.run(context.optimize_graph,
context.scope)
results = [float(np.mean(result)) for result in results]
if context.batch_id % 20 == 0:
if context.batch_id % self.log_period == 0:
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
context.epoch_id, context.batch_id,
context.optimize_graph.out_nodes.keys(
......@@ -502,7 +516,7 @@ class Compressor(object):
context.scope,
data=data)
results = [float(np.mean(result)) for result in results]
if context.batch_id % 20 == 0:
if context.batch_id % self.log_period == 0:
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
context.epoch_id, context.batch_id,
context.optimize_graph.out_nodes.keys(
......
......@@ -782,7 +782,6 @@ class SensitivePruneStrategy(PruneStrategy):
}
metric = None
for param in sensitivities.keys():
ratio = self.delta_rate
while ratio < 1:
......@@ -825,7 +824,7 @@ class SensitivePruneStrategy(PruneStrategy):
# restore pruned parameters
for param_name in param_backup.keys():
param_t = context.scope.find_var(param_name).get_tensor()
param_t.set(self.param_backup[param_name], context.place)
param_t.set(param_backup[param_name], context.place)
# pruned_metric = self._eval_graph(context)
......
......@@ -24,11 +24,11 @@ strategies:
target_ratio: 0.08
num_steps: 1
eval_rate: 0.5
pruned_params: 'conv6_sep_weights'
pruned_params: '_conv6_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
metric_name: 'acc_top1'
compressor:
epoch: 1
epoch: 2
checkpoint_path: './checkpoints_pruning/'
strategies:
- sensitive_pruning_strategy
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册