提交 f0500791 编写于 作者: R root

fix running err

上级 ed7e33e7
...@@ -72,13 +72,13 @@ def _parse_configs(user_config): ...@@ -72,13 +72,13 @@ def _parse_configs(user_config):
assert isinstance(configs['weight_bits'], int), \ assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value." "weight_bits must be int value."
assert isinstance(configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \ assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
"weight_bits should be between 1 and 16." "weight_bits should be between 1 and 16."
assert isinstance(configs['activation_bits'], int), \ assert isinstance(configs['activation_bits'], int), \
"activation_bits must be int value." "activation_bits must be int value."
assert isinstance(configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \ assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16." "activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \ assert isinstance(configs['not_quant_pattern'], list), \
...@@ -90,7 +90,7 @@ def _parse_configs(user_config): ...@@ -90,7 +90,7 @@ def _parse_configs(user_config):
assert isinstance(configs['dtype'], str), \ assert isinstance(configs['dtype'], str), \
"dtype must be a str." "dtype must be a str."
assert isinstance(configs['dtype'] in VALID_DTYPES), \ assert (configs['dtype'] in VALID_DTYPES), \
"dtype can only be " + " ".join(VALID_DTYPES) "dtype can only be " + " ".join(VALID_DTYPES)
assert isinstance(configs['window_size'], int), \ assert isinstance(configs['window_size'], int), \
...@@ -140,7 +140,7 @@ def quant_aware(program, scope, place, config, for_test=False): ...@@ -140,7 +140,7 @@ def quant_aware(program, scope, place, config, for_test=False):
window_size=config['window_size'], window_size=config['window_size'],
moving_rate=config['moving_rate'], moving_rate=config['moving_rate'],
quantizable_op_type=config['quantize_op_types'], quantizable_op_type=config['quantize_op_types'],
skip_pattern=''#not_quant_pattern skip_pattern=config['not_quant_pattern']
) )
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
......
...@@ -206,7 +206,7 @@ def train(args): ...@@ -206,7 +206,7 @@ def train(args):
# activation quantize bit num, default is 8 # activation quantize bit num, default is 8
'activation_bits': 8, 'activation_bits': 8,
# op of name_scope in not_quant_pattern list, will not quantized # op of name_scope in not_quant_pattern list, will not quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant_dd'],
# op of types in quantize_op_types, will quantized # op of types in quantize_op_types, will quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, default is 'int8' # data type after quantization, default is 'int8'
...@@ -292,14 +292,16 @@ def train(args): ...@@ -292,14 +292,16 @@ def train(args):
# According to the weight and activation quantization type, the graph will be added # According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators. # some fake quantize operators and fake dequantize operators.
############################################################################################################ ############################################################################################################
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test = False)
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
test_prog = quant.quanter.quant_aware(test_prog, scope, place, quant_config, for_test=True)
train_prog = quant.quanter.quant_aware(train_prog, scope, place, quant_config, for_test=False)
train_prog = train_prog.with_data_parallel(loss_name=train_cost.name, build_strategy=build_strategy)
#train_prog_binary = train_prog_binary.with_data_parallel(loss_name=train_cost.name)
params = models.__dict__[args.model]().params params = models.__dict__[args.model]().params
for pass_id in range(params["num_epochs"]): for pass_id in range(params["num_epochs"]):
......
#!/usr/bin/env bash #!/usr/bin/env bash
source activate py27_paddle1.6 source /home/wsz/anaconda2/bin/activate py27_paddle1.6
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#MobileNet v1: #MobileNet v1:
python quanter_test.py \ nohup python quanter_test.py \
--model=MobileNet \ --model=MobileNet \
--pretrained_fp32_model='../../../../pretrain/MobileNetV1_pretrained/' \ --pretrained_fp32_model='../../../../pretrain/MobileNetV1_pretrained/' \
--use_gpu=True \ --use_gpu=True \
--data_dir='/home/ssd8/wsz/tianfei01/traindata/imagenet/' \ --data_dir='/home/ssd8/wsz/tianfei01/traindata/imagenet/' \
--batch_size=256 \ --batch_size=2048 \
--total_images=1281167 \ --total_images=1281167 \
--class_dim=1000 \ --class_dim=1000 \
--image_shape=3,224,224 \ --image_shape=3,224,224 \
...@@ -16,7 +16,7 @@ python quanter_test.py \ ...@@ -16,7 +16,7 @@ python quanter_test.py \
--num_epochs=20 \ --num_epochs=20 \
--lr=0.0001 \ --lr=0.0001 \
--act_quant_type=abs_max \ --act_quant_type=abs_max \
--wt_quant_type=abs_max --wt_quant_type=abs_max 2>&1 &
#ResNet50: #ResNet50:
#python quanter_test.py \ #python quanter_test.py \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册