Unverified commit bdb3e376 authored by W whs, committed by GitHub

[PaddleSlim] Enhance compressor API in PaddleSlim (#19894)


1. Support a customized eval function as an alternative to an eval program.
2. Fix loading checkpoints in the quantization strategy.
3. Support saving the eval model when saving a checkpoint.
4. Fix the decoder used when loading a context in PaddleSlim.
5. Fix restoring the uniform prune strategy from a checkpoint.
6. Support saving the eval model and infer model during training.
7. Add unit tests for saving the eval model, saving the infer model, and restoring uniform pruning from a checkpoint.
8. Fix pruning of the depthwise_conv_grad op by updating its 'groups' attribute.
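As a rough usage sketch of how the new arguments fit together (distilled from the unit test added in this commit; place, train/eval readers, feed/fetch lists, image, out, and optimizer are assumed from a standard training setup):

    com_pass = Compressor(
        place,
        fluid.global_scope(),
        fluid.default_main_program(),
        train_reader=train_reader,
        train_feed_list=train_feed_list,
        train_fetch_list=train_fetch_list,
        eval_program=val_program,
        eval_feed_list=eval_feed_list,
        eval_fetch_list=eval_fetch_list,
        eval_func={"score": eval_func},                # new: per-key eval callbacks
        save_eval_model=True,                          # new: save the eval model with checkpoints
        prune_infer_model=[[image.name], [out.name]],  # new: prune and save an inference model
        train_optimizer=optimizer)
    com_pass.config('./configs/compress.yaml')
    com_pass.run()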
Parent cedc0477
......@@ -491,7 +491,7 @@ paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'p
paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884'))
paddle.fluid.contrib.distributed_batch_reader (ArgSpec(args=['batch_reader'], varargs=None, keywords=None, defaults=None), ('document', 'b60796eb0a481484dd34e345f0eaa4d5'))
paddle.fluid.contrib.Compressor ('paddle.fluid.contrib.slim.core.compressor.Compressor', ('document', 'a5417774a94aa9ae5560a42b96527e7d'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], None, None, None, None)), ('document', 'c195b3bba26169cff9439e8c467557c0'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'eval_func', 'save_eval_model', 'prune_infer_model', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, None, True, None, [], None, None, None, None)), ('document', '05119e0fa0fc07f5cf848ebf0a2cf070'))
paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0'))
paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9'))
paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67'))
......
......@@ -139,7 +139,7 @@ class Context(object):
"""
Load the context from file.
"""
with open(file_name) as context_file:
with open(file_name, 'rb') as context_file:
if sys.version_info < (3, 0):
data = pickle.load(context_file)
else:
......@@ -242,6 +242,9 @@ class Compressor(object):
eval_reader=None,
eval_feed_list=None,
eval_fetch_list=None,
eval_func=None,
save_eval_model=True,
prune_infer_model=None,
teacher_programs=[],
checkpoint_path=None,
train_optimizer=None,
......@@ -260,13 +263,28 @@ class Compressor(object):
The key is a user-defined, human-readable name.
The value is the name of a Variable.
eval_program(Program): The program used for evaluation.
eval_reader: The data reader used for evaluation.
eval_reader: The data reader used for evaluation. It can be None if eval_func is not None.
eval_feed_list(dict): A dict indicating the input variables of the evaluation program.
The key is a user-defined, human-readable name.
The value is the name of a Variable.
It can be None if eval_func is not None.
eval_fetch_list(dict): A dict indicating the output variables of the evaluation program.
The key is a user-defined, human-readable name.
The value is the name of a Variable.
eval_func(dict|function): Callback functions used to evaluate the compressed model.
If eval_func is a dict, each key is a user-defined name and the value is
a callback function. The scores returned by the callback functions
can be referenced in the config file by the keys of eval_func.
Each callback function takes as arguments the compressed eval_program
and the scope that stores the compressed parameters.
Default: None.
save_eval_model(bool): Whether to save the eval model when saving checkpoints. Default: True.
prune_infer_model(tuple|list): If prune_infer_model is not None, the compressor prunes
the eval program into an inference program according to the inputs and outputs
defined in prune_infer_model. prune_infer_model[0] is a list of input
variables' names and prune_infer_model[1] is a list of output variables'
names. If prune_infer_model is None, no inference model is saved.
Default: None.
teacher_programs: The teacher graphs used in distillation strategies.
train_optimizer: The optimizer used to append backward ops and
optimization ops into train_graph.
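For reference, a minimal eval_func callback could look like the following sketch (mirroring the new unit test in this commit; image, label, acc_top1 and val_reader are assumed to come from the surrounding model definition):

    import numpy as np
    import paddle.fluid as fluid

    def eval_func(program, scope):
        # The compressor calls this with the compressed eval program and the
        # scope holding the compressed parameters; it must return one score.
        exe = fluid.Executor(fluid.CPUPlace())
        feeder = fluid.DataFeeder(
            feed_list=[image.name, label.name],
            place=fluid.CPUPlace(),
            program=program)
        results = []
        for data in val_reader():
            result = exe.run(program=program,
                             scope=scope,
                             fetch_list=[acc_top1.name],
                             feed=feeder.feed(data))
            results.append(np.array(result))
        return np.mean(results)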
......@@ -294,6 +312,10 @@ class Compressor(object):
eval_program, in_nodes=eval_feed_list, out_nodes=eval_fetch_list)
self.train_reader = train_reader
self.eval_reader = eval_reader
self.eval_func = eval_func
self.save_eval_model = save_eval_model
self.prune_infer_model = prune_infer_model
self.teacher_graphs = []
for teacher in teacher_programs:
self.teacher_graphs.append(GraphWrapper(teacher))
......@@ -393,6 +415,9 @@ class Compressor(object):
strategies = pickle.load(
strategy_file, encoding='bytes')
for s, s1 in zip(self.strategies, strategies):
s1.__dict__.update(s.__dict__)
for strategy in strategies:
strategy.restore_from_checkpoint(context)
......@@ -401,10 +426,6 @@ class Compressor(object):
with scope_guard(context.scope):
context.optimize_graph.load_persistables(model_path,
exe)
context.optimize_graph.update_param_shape(context.scope)
context.optimize_graph.update_groups_of_conv()
context.eval_graph.update_param_shape(context.scope)
context.eval_graph.update_groups_of_conv()
_logger.info("Loaded params from: {}".format(model_path))
return context, strategies
......@@ -416,6 +437,7 @@ class Compressor(object):
checkpoint_path = os.path.join(self.checkpoint_path,
str(context.epoch_id))
model_path = os.path.join(checkpoint_path, 'model')
eval_model_path = os.path.join(checkpoint_path, 'eval_model')
context_path = os.path.join(checkpoint_path, 'context')
strategy_path = os.path.join(checkpoint_path, 'strategies')
if not os.path.isdir(model_path):
......@@ -423,6 +445,15 @@ class Compressor(object):
exe = SlimGraphExecutor(context.place)
with scope_guard(context.scope):
context.optimize_graph.save_persistables(model_path, exe)
if self.save_eval_model:
context.eval_graph.save_model(eval_model_path, exe)
if self.prune_infer_model:
context.eval_graph.save_infer_model(
eval_model_path,
exe,
self.prune_infer_model,
program_only=self.save_eval_model)
context.to_file(context_path)
with open(strategy_path, 'wb') as strategy_file:
pickle.dump(self.strategies, strategy_file)
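With save_eval_model=True and prune_infer_model set, the checkpoint directory for epoch 0 then contains the following entries (a sketch matching the assertions in the new unit test; the checkpoint root is whatever checkpoint_path the config names):

    import os

    checkpoint = './checkpoints/0'
    for entry in ['model',                        # persistables of the optimize graph
                  'eval_model/__model__',         # eval program (save_eval_model)
                  'eval_model/__model__.infer',   # pruned inference program (prune_infer_model)
                  'eval_model/__params__',        # parameters saved with the eval model
                  'context',                      # serialized training context
                  'strategies']:                  # pickled strategy objects
        assert os.path.exists(os.path.join(checkpoint, entry))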
......@@ -485,11 +516,19 @@ class Compressor(object):
"""
Running evaluation.
"""
results, names = context.run_eval_graph()
for name, result in zip(names, results):
if name not in context.eval_results:
context.eval_results[name] = []
context.eval_results[name].append(result)
if self.eval_func is not None:
for key in self.eval_func:
func = self.eval_func[key]
if key not in context.eval_results:
context.eval_results[key] = []
context.eval_results[key].append(
func(self.eval_graph.program, self.scope))
else:
results, names = context.run_eval_graph()
for name, result in zip(names, results):
if name not in context.eval_results:
context.eval_results[name] = []
context.eval_results[name].append(result)
def run(self):
"""
......
......@@ -211,6 +211,7 @@ class GraphWrapper(object):
self.persistables[var.name] = var
self.compiled_graph = None
in_nodes = [] if in_nodes is None else in_nodes
out_nodes = [] if out_nodes is None else out_nodes
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
self._attrs = OrderedDict()
......@@ -471,6 +472,54 @@ class GraphWrapper(object):
return flops
def save_model(self, path, exe):
"""
Save the network and parameters into files that can be loaded by the load_inference_model API.
Args:
path(str): The path to save the persistables.
exe(framework.Executor): The executor used to save the persistables.
"""
out_vars = [
self.var(var_name)._var for var_name in self.out_nodes.values()
]
in_vars = list(self.in_nodes.values())
assert (len(in_vars) > 0)
assert (len(out_vars) > 0)
io.save_inference_model(
path,
in_vars,
out_vars,
exe.exe,
model_filename="__model__",
params_filename="__params__",
main_program=self.program.clone(),
export_for_deployment=True)
def save_infer_model(self, path, exe, in_out, program_only=False):
"""
Save the network and parameters into files that can be loaded by the load_inference_model API.
Args:
path(str): The path to save the persistables.
exe(framework.Executor): The executor used to save the persistables.
in_out(tuple|list): in_out[0] is a list of input nodes' names
and in_out[1] is a list of output nodes' names.
program_only(bool): Whether to save only the program, without the parameters.
"""
out_vars = [self.var(var_name)._var for var_name in in_out[1]]
in_vars = list(in_out[0])
assert (len(in_vars) > 0)
assert (len(out_vars) > 0)
io.save_inference_model(
path,
in_vars,
out_vars,
exe.exe,
model_filename="__model__.infer",
params_filename="__params__",
program_only=program_only,
main_program=self.program.clone(),
export_for_deployment=True)
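A model saved by save_infer_model can be reloaded with the standard inference API, e.g. (a sketch; the directory assumes the checkpoint layout used elsewhere in this commit):

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    # model_filename/params_filename must match what save_infer_model wrote.
    program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname='./checkpoints/0/eval_model',
        executor=exe,
        model_filename='__model__.infer',
        params_filename='__params__')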
def save_persistables(self, path, exe):
"""
Save all the persistable variables into files.
......@@ -527,5 +576,6 @@ class GraphWrapper(object):
def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d':
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
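Note: for depthwise convolutions the 'groups' attribute has to stay consistent with the channel dimension of the (possibly pruned) filter, so the attribute is now refreshed on depthwise_conv2d_grad as well as on the forward op; this is the fix referred to in item 8 of the commit message.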
......@@ -635,31 +635,35 @@ class UniformPruneStrategy(PruneStrategy):
_logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios]))
return pruned_params, ratios
def on_epoch_begin(self, context):
if context.epoch_id == self.start_epoch:
params, ratios = self._get_best_ratios(context)
def restore_from_checkpoint(self, context):
self._prune(context, self.params, self.ratios)
self._prune_parameters(context.optimize_graph, context.scope,
params, ratios, context.place)
def _prune(self, context, params, ratios):
self._prune_parameters(context.optimize_graph, context.scope, params,
ratios, context.place)
model_size = context.eval_graph.numel_params()
flops = context.eval_graph.flops()
_logger.debug('\n################################')
_logger.debug('# pruning eval graph #')
_logger.debug('################################\n')
self._prune_graph(context.eval_graph, context.optimize_graph)
context.optimize_graph.update_groups_of_conv()
context.eval_graph.update_groups_of_conv()
model_size = context.eval_graph.numel_params()
flops = context.eval_graph.flops()
_logger.debug('\n################################')
_logger.debug('# pruning eval graph #')
_logger.debug('################################\n')
self._prune_graph(context.eval_graph, context.optimize_graph)
context.optimize_graph.update_groups_of_conv()
context.eval_graph.update_groups_of_conv()
_logger.info(
'------------------finish pruning--------------------------------'
)
_logger.info('Pruned size: {:.2f}'.format(1 - (float(
context.eval_graph.numel_params()) / model_size)))
_logger.info('Pruned flops: {:.2f}'.format(1 - (float(
context.eval_graph.flops()) / flops)))
# metric = self._eval_graph(context)
# _logger.info('Metric after pruning: {:.2f}'.format(metric))
_logger.info(
'------------------finish pruning--------------------------------')
_logger.info('Pruned size: {:.2f}'.format(1 - (float(
context.eval_graph.numel_params()) / model_size)))
_logger.info('Pruned flops: {:.2f}'.format(1 - (float(
context.eval_graph.flops()) / flops)))
def on_epoch_begin(self, context):
if context.epoch_id == self.start_epoch:
params, ratios = self._get_best_ratios(context)
self.params = params
self.ratios = ratios
self._prune(context, params, ratios)
_logger.info(
'------------------UniformPruneStrategy.on_epoch_begin finish--------------------------------'
)
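Note: on_epoch_begin now stores the chosen params and ratios on the strategy object. Because the strategies are pickled into each checkpoint (see _save_checkpoint above), restore_from_checkpoint can simply replay _prune with the saved values when training resumes, instead of recomputing the ratios.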
......
......@@ -17,7 +17,7 @@ import sys
import numpy as np
from .... import Executor
from .... import io
from .... import core
from .... import core, scope_guard
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph, Variable, Program
......@@ -199,15 +199,16 @@ class QuantizationStrategy(Strategy):
# save float model
if self.float_model_save_path:
executor = Executor(context.place)
io.save_inference_model(
self.float_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
with scope_guard(context.scope):
io.save_inference_model(
self.float_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
# save int8 model
if self.int8_model_save_path:
......@@ -216,15 +217,17 @@ class QuantizationStrategy(Strategy):
convert_int8_pass.apply(test_ir_graph)
executor = Executor(context.place)
io.save_inference_model(
self.int8_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
with scope_guard(context.scope):
io.save_inference_model(
self.int8_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
# save mobile model
if self.mobile_model_save_path:
......@@ -237,13 +240,14 @@ class QuantizationStrategy(Strategy):
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_ir_graph)
executor = Executor(context.place)
io.save_inference_model(
self.mobile_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
with scope_guard(context.scope):
io.save_inference_model(
self.mobile_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
_logger.info('Finish QuantizationStrategy::on_epoch_end')
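Note: wrapping each io.save_inference_model call in scope_guard(context.scope) makes the executor read the persistables from the compression context's scope rather than the default global scope, which is presumably why the previously saved quantized models could miss the trained weights.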
version: 1.0
compressor:
epoch: 1
checkpoint_path: './checkpoints/'
......@@ -28,7 +28,7 @@ strategies:
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
metric_name: 'acc_top1'
compressor:
epoch: 2
epoch: 1
checkpoint_path: './checkpoints_pruning/'
strategies:
- sensitive_pruning_strategy
version: 1.0
pruners:
pruner_1:
class: 'StructurePruner'
pruning_axis:
'*': 0
criterions:
'*': 'l1_norm'
strategies:
uniform_pruning_strategy:
class: 'UniformPruneStrategy'
pruner: 'pruner_1'
start_epoch: 0
target_ratio: 0.5
pruned_params: 'conv.*'
metric_name: 'acc_top1'
compressor:
epoch: 2
checkpoint_path: './checkpoints_uniform_restore_tmp/'
strategies:
- uniform_pruning_strategy
version: 1.0
pruners:
pruner_1:
class: 'StructurePruner'
pruning_axis:
'*': 0
criterions:
'*': 'l1_norm'
strategies:
uniform_pruning_strategy:
class: 'UniformPruneStrategy'
pruner: 'pruner_1'
start_epoch: 0
target_ratio: 0.5
pruned_params: 'conv.*'
metric_name: 'acc_top1'
compressor:
epoch: 1
checkpoint_path: './checkpoints_uniform_restore/'
strategies:
- uniform_pruning_strategy
version: 1.0
pruners:
pruner_1:
class: 'StructurePruner'
pruning_axis:
'*': 0
criterions:
'*': 'l1_norm'
strategies:
uniform_pruning_strategy:
class: 'UniformPruneStrategy'
pruner: 'pruner_1'
start_epoch: 0
target_ratio: 0.5
pruned_params: 'conv.*'
metric_name: 'acc_top1'
compressor:
epoch: 2
checkpoint_path: './checkpoints_uniform_restore/'
strategies:
- uniform_pruning_strategy
#start_epoch(int): The epoch to insert quantization operators. default: 0
#
#end_epoch(int): The epoch to save inference model. default: 0
#
#float_model_save_path(str): The path to save model with float weights.
# None means it doesn't save float model. default: None.
#
#mobile_model_save_path(str): The path to save model for paddle-mobile execution.
# None means it doesn't save mobile model. default: None.
#
#int8_model_save_path(str): The path to save model with int8_t weight.
# None means it doesn't save int8 model. default: None.
#
#activation_bits(int): quantization bit number for activation. default: 8.
#
#weight_bits(int): quantization bit number for weights. The bias is not quantized.
# default: 8.
#
#activation_quantize_type(str): quantization type for activations;
#   'abs_max', 'range_abs_max' and 'moving_average_abs_max' are supported.
#   In 'abs_max' mode, the quantization scale is calculated
#   dynamically at each step in both training and testing. In
#   'range_abs_max' mode, a static quantization scale is calculated
#   during training and used in inference.
#
#save_in_nodes(list<str>): A list of input variable names used to prune the graph
#   for saving the inference model.
#
#save_out_nodes(list<str>): A list of output variable names used to prune the graph
#   for saving the inference model.
version: 1.0
strategies:
quantization_strategy:
class: 'QuantizationStrategy'
start_epoch: 0
end_epoch: 0
float_model_save_path: './output/float'
mobile_model_save_path: './output/mobile'
int8_model_save_path: './output/int8'
weight_bits: 8
activation_bits: 8
weight_quantize_type: 'abs_max'
activation_quantize_type: 'abs_max'
save_in_nodes: ['image']
save_out_nodes: ['quan.tmp_2']
compressor:
epoch: 2
checkpoint_path: './checkpoints_quan/'
strategies:
- quantization_strategy
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
class TestCompressor(unittest.TestCase):
def test_eval_func(self):
class_dim = 10
image_shape = [1, 28, 28]
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = fluid.layers.fc(input=image, size=class_dim)
out = fluid.layers.softmax(out)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
val_program = fluid.default_main_program().clone(for_test=False)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
train_feed_list = [('img', image.name), ('label', label.name)]
train_fetch_list = [('loss', avg_cost.name)]
eval_feed_list = [('img', image.name), ('label', label.name)]
eval_fetch_list = [('acc_top1', acc_top1.name)]
def eval_func(program, scope):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(
feed_list=[image.name, label.name],
place=place,
program=program)
results = []
for data in val_reader():
result = exe.run(program=program,
scope=scope,
fetch_list=[acc_top1.name],
feed=feeder.feed(data))
results.append(np.array(result))
result = np.mean(results)
return result
com_pass = Compressor(
place,
fluid.global_scope(),
fluid.default_main_program(),
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
eval_program=val_program,
eval_feed_list=eval_feed_list,
eval_fetch_list=eval_fetch_list,
eval_func={"score": eval_func},
prune_infer_model=[[image.name], [out.name]],
train_optimizer=optimizer)
com_pass.config('./configs/compress.yaml')
com_pass.run()
self.assertTrue('score' in com_pass.context.eval_results)
self.assertTrue(float(com_pass.context.eval_results['score'][0]) > 0.9)
self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__model__"))
self.assertTrue(
os.path.exists("./checkpoints/0/eval_model/__model__.infer"))
self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__params__"))
if __name__ == '__main__':
unittest.main()
......@@ -15,6 +15,7 @@
import paddle
import unittest
import paddle.fluid as fluid
import numpy as np
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
......@@ -84,6 +85,80 @@ class TestFilterPruning(unittest.TestCase):
abs((com_pass.context.eval_results['acc_top1'][-1] - 0.969) / 0.969)
< 0.02)
def test_uniform_restore_from_checkpoint(self):
np.random.seed(0)
self.uniform_restore_from_checkpoint(
"./filter_pruning/uniform_restore_0.yaml")
acc_0 = self.uniform_restore_from_checkpoint(
"./filter_pruning/uniform_restore_1.yaml")
np.random.seed(0)
acc_1 = self.uniform_restore_from_checkpoint(
"./filter_pruning/uniform_restore.yaml")
self.assertTrue(abs((acc_0 - acc_1) / acc_1) < 0.001)
def uniform_restore_from_checkpoint(self, config_file):
class_dim = 10
image_shape = [1, 28, 28]
train_program = fluid.Program()
startup_program = fluid.Program()
train_program.random_seed = 10
startup_program.random_seed = 10
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
out = fluid.layers.conv2d(image, 4, 1)
out = fluid.layers.fc(out, size=class_dim)
out = fluid.layers.softmax(out)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
val_program = train_program.clone(for_test=False)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
place = fluid.CPUPlace()
scope = fluid.Scope()
exe = fluid.Executor(place)
exe.run(startup_program, scope=scope)
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
val_feed_list = [('img', image.name), ('label', label.name)]
val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
acc_top5.name)]
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
train_feed_list = [('img', image.name), ('label', label.name)]
train_fetch_list = [('loss', avg_cost.name)]
com_pass = Compressor(
place,
scope,
train_program,
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
eval_program=val_program,
eval_reader=val_reader,
eval_feed_list=val_feed_list,
eval_fetch_list=val_fetch_list,
train_optimizer=optimizer)
com_pass.config(config_file)
eval_graph = com_pass.run()
return com_pass.context.eval_results['acc_top1'][-1]
if __name__ == '__main__':
unittest.main()
......@@ -26,30 +26,44 @@ class TestQuantizationStrategy(unittest.TestCase):
"""
def test_compression(self):
self.quan("./quantization/compress.yaml")
self.quan("./quantization/compress_1.yaml")
def quan(self, config_file):
if not fluid.core.is_compiled_with_cuda():
return
class_dim = 10
image_shape = [1, 28, 28]
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = MobileNet(name='quan').net(input=image, class_dim=class_dim)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=False)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
out = MobileNet(name='quan').net(input=image,
class_dim=class_dim)
print("out: {}".format(out.name))
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
val_program = train_program.clone(for_test=False)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
scope = fluid.Scope()
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(startup_program, scope=scope)
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
......@@ -64,8 +78,8 @@ class TestQuantizationStrategy(unittest.TestCase):
com_pass = Compressor(
place,
fluid.global_scope(),
fluid.default_main_program(),
scope,
train_program,
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
......@@ -74,7 +88,7 @@ class TestQuantizationStrategy(unittest.TestCase):
eval_feed_list=val_feed_list,
eval_fetch_list=val_fetch_list,
train_optimizer=optimizer)
com_pass.config('./quantization/compress.yaml')
com_pass.config(config_file)
eval_graph = com_pass.run()
......