提交 38ff0d94 编写于 作者: Y yangfukui

fix test error

上级 10c1c12f
...@@ -31,7 +31,7 @@ quant_config_default = { ...@@ -31,7 +31,7 @@ quant_config_default = {
'activation_quantize_type': 'abs_max', 'activation_quantize_type': 'abs_max',
# weight quantize bit num, default is 8 # weight quantize bit num, default is 8
'weight_bits': 8, 'weight_bits': 8,
# activation quantize bit num default is 8 # activation quantize bit num, default is 8
'activation_bits': 8, 'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
...@@ -115,8 +115,8 @@ def quant_aware(program, scope, place, config, for_test=False): ...@@ -115,8 +115,8 @@ def quant_aware(program, scope, place, config, for_test=False):
scope = fluid.global_scope() if not scope else scope scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict" assert isinstance(config, dict), "config must be dict"
assert 'weight_quant_type' in config.keys(), 'weight_quant_type must be configured' assert 'weight_quantize_type' in config.keys(), 'weight_quantize_type must be configured'
assert 'activation_quant_type' in config.keys(), 'activation_quant_type must be configured' assert 'activation_quantize_type' in config.keys(), 'activation_quantize_type must be configured'
config = _parse_configs(config) config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
...@@ -125,8 +125,8 @@ def quant_aware(program, scope, place, config, for_test=False): ...@@ -125,8 +125,8 @@ def quant_aware(program, scope, place, config, for_test=False):
scope=scope, place=place, scope=scope, place=place,
weight_bits=config['weight_bits'], weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'], activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quant_type'], activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quant_type'], weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'], window_size=config['window_size'],
moving_rate=config['moving_rate'], moving_rate=config['moving_rate'],
skip_pattern=''#not_quant_pattern skip_pattern=''#not_quant_pattern
...@@ -156,8 +156,8 @@ def quant_post(program, scope, place, config): ...@@ -156,8 +156,8 @@ def quant_post(program, scope, place, config):
scope = fluid.global_scope() if not scope else scope scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict" assert isinstance(config, dict), "config must be dict"
assert 'weight_quant_type' in config.keys(), 'weight_quant_type must be configured' assert 'weight_quantize_type' in config.keys(), 'weight_quantize_type must be configured'
assert 'activation_quant_type' in config.keys(), 'activation_quant_type must be configured' assert 'activation_quantize_type' in config.keys(), 'activation_quantize_type must be configured'
config = _parse_configs(config) config = _parse_configs(config)
...@@ -165,8 +165,8 @@ def quant_post(program, scope, place, config): ...@@ -165,8 +165,8 @@ def quant_post(program, scope, place, config):
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=scope, place=place, scope=scope, place=place,
activation_quantize_type=config['activation_quant_type'], activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quant_type']) weight_quantize_type=config['weight_quantize_type'])
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
...@@ -195,7 +195,7 @@ def convert(program, scope, place, config, save_int8=False): ...@@ -195,7 +195,7 @@ def convert(program, scope, place, config, save_int8=False):
freeze_pass = QuantizationFreezePass( freeze_pass = QuantizationFreezePass(
scope=scope, scope=scope,
place=place, place=place,
weight_quantize_type=config['weight_quant_type']) weight_quantize_type=config['weight_quantize_type'])
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program() freezed_program = test_graph.to_program()
freezed_program_int8 = None freezed_program_int8 = None
......
...@@ -252,15 +252,7 @@ def train(args): ...@@ -252,15 +252,7 @@ def train(args):
scope = fluid.global_scope() scope = fluid.global_scope()
exe = fluid.Executor(place) exe = fluid.Executor(place)
############################################################################################################
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
############################################################################################################
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test = False)
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
# load checkpoint todo # load checkpoint todo
...@@ -294,6 +286,16 @@ def train(args): ...@@ -294,6 +286,16 @@ def train(args):
train_fetch_list = [train_cost.name, train_acc1.name, train_acc5.name, global_lr.name] train_fetch_list = [train_cost.name, train_acc1.name, train_acc5.name, global_lr.name]
test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name] test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
############################################################################################################
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
############################################################################################################
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test = False)
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
......
...@@ -4,7 +4,7 @@ source activate py27_paddle1.6 ...@@ -4,7 +4,7 @@ source activate py27_paddle1.6
#MobileNet v1: #MobileNet v1:
python quanter_test.py \ python quanter_test.py \
--model=MobileNet \ --model=MobileNet \
--pretrained_fp32_model='../../pretrain/MobileNetV1_pretrained/' \ --pretrained_fp32_model='../../../../pretrain/MobileNetV1_pretrained/' \
--use_gpu=True \ --use_gpu=True \
--data_dir='/home/ssd8/wsz/tianfei01/traindata/imagenet/' \ --data_dir='/home/ssd8/wsz/tianfei01/traindata/imagenet/' \
--batch_size=256 \ --batch_size=256 \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册