未验证 提交 880ad20b 编写于 作者: C Chang Xu 提交者: GitHub

Fix FullQuant Demo (#1459)

上级 7ed2aa7d
...@@ -89,11 +89,11 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev3/' - ...@@ -89,11 +89,11 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev3/' -
**验证精度** **验证精度**
根据训练log可以看到模型验证的精度,若需再次验证精度,修改配置文件```./configs/MobileNetV1/qat_dis.yaml```中所需验证模型的文件夹路径及模型和参数名称```model_dir, model_filename, params_filename```,然后使用以下命令进行验证: 根据训练log可以看到模型验证的精度,若需再次验证精度,修改配置文件```./configs/mobilenetv3_large_qat_dis.yaml```中所需验证模型的文件夹路径及模型和参数名称```model_dir, model_filename, params_filename```,然后使用以下命令进行验证:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path='./configs/eval.yaml' python eval.py --config_path='./configs/mobilenetv3_large_qat_dis.yaml'
``` ```
......
model_dir: ./MobileNetV3_large_x1_0_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
batch_size: 128
data_dir: ./ILSVRC2012_data_demo/ILSVRC2012/
img_size: 224
resize_size: 256
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
model_filename: inference.pdmodel model_filename: inference.pdmodel
params_filename: inference.pdiparams params_filename: inference.pdiparams
batch_size: 128 batch_size: 128
data_dir: ./ILSVRC2012_data_demo/ILSVRC2012/ data_dir: ./ILSVRC2012/
Distillation: Distillation:
alpha: 1.0 alpha: 1.0
......
...@@ -33,11 +33,6 @@ def argsparser(): ...@@ -33,11 +33,6 @@ def argsparser():
type=str, type=str,
default='./image_classification/configs/eval.yaml', default='./image_classification/configs/eval.yaml',
help="path of compression strategy config.") help="path of compression strategy config.")
parser.add_argument(
'--model_dir',
type=str,
default='./MobileNetV1_infer',
help='model directory')
return parser return parser
...@@ -92,6 +87,8 @@ def eval(): ...@@ -92,6 +87,8 @@ def eval():
acc_num += 1 acc_num += 1
top_5 = float(acc_num) / len(label) top_5 = float(acc_num) / len(label)
results.append([top_1, top_5]) results.append([top_1, top_5])
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
result = np.mean(np.array(results), axis=0) result = np.mean(np.array(results), axis=0)
return result[0] return result[0]
...@@ -103,8 +100,6 @@ def main(args): ...@@ -103,8 +100,6 @@ def main(args):
global data_dir global data_dir
data_dir = global_config['data_dir'] data_dir = global_config['data_dir']
if args.model_dir != global_config['model_dir']:
global_config['model_dir'] = args.model_dir
global img_size, resize_size global img_size, resize_size
img_size = int(global_config[ img_size = int(global_config[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册