提交 f0500791 编写于 作者: R root

Fix runtime error

上级 ed7e33e7
......@@ -72,13 +72,13 @@ def _parse_configs(user_config):
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value."
assert isinstance(configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
"weight_bits should be between 1 and 16."
assert isinstance(configs['activation_bits'], int), \
"activation_bits must be int value."
assert isinstance(configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \
......@@ -90,7 +90,7 @@ def _parse_configs(user_config):
assert isinstance(configs['dtype'], str), \
"dtype must be a str."
assert isinstance(configs['dtype'] in VALID_DTYPES), \
assert (configs['dtype'] in VALID_DTYPES), \
"dtype can only be " + " ".join(VALID_DTYPES)
assert isinstance(configs['window_size'], int), \
......@@ -140,7 +140,7 @@ def quant_aware(program, scope, place, config, for_test=False):
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=config['quantize_op_types'],
skip_pattern=''#not_quant_pattern
skip_pattern=config['not_quant_pattern']
)
transform_pass.apply(main_graph)
......
......@@ -206,7 +206,7 @@ def train(args):
# activation quantize bit num, default is 8
'activation_bits': 8,
# op of name_scope in not_quant_pattern list, will not quantized
'not_quant_pattern': ['skip_quant'],
'not_quant_pattern': ['skip_quant_dd'],
# op of types in quantize_op_types, will quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, default is 'int8'
......@@ -292,14 +292,16 @@ def train(args):
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
############################################################################################################
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test = False)
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test=False)
train_prog = train_prog.with_data_parallel(loss_name=train_cost.name, build_strategy=build_strategy)
#train_prog_binary = train_prog_binary.with_data_parallel(loss_name=train_cost.name)
params = models.__dict__[args.model]().params
for pass_id in range(params["num_epochs"]):
......
#!/usr/bin/env bash
source activate py27_paddle1.6
source /home/wsz/anaconda2/bin/activate py27_paddle1.6
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#MobileNet v1:
python quanter_test.py \
nohup python quanter_test.py \
--model=MobileNet \
--pretrained_fp32_model='../../../../pretrain/MobileNetV1_pretrained/' \
--use_gpu=True \
--data_dir='/home/ssd8/wsz/tianfei01/traindata/imagenet/' \
--batch_size=256 \
--batch_size=2048 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
......@@ -16,7 +16,7 @@ python quanter_test.py \
--num_epochs=20 \
--lr=0.0001 \
--act_quant_type=abs_max \
--wt_quant_type=abs_max
--wt_quant_type=abs_max 2>&1 &
#ResNet50:
#python quanter_test.py \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册