未验证 提交 d00373ae 编写于 作者: B Bai Yifan 提交者: GitHub

fix pact demo path issue (#403)

上级 a769b284
......@@ -8,8 +8,9 @@ import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path.append(os.path.dirname("__file__"))
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from paddleslim.common import get_logger
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
......@@ -133,8 +134,8 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
if args.use_pact:
image.stop_gradient = False
......@@ -196,8 +197,7 @@ def compress(args):
if args.pretrained_model:
def if_exist(var):
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
return os.path.exists(os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
......@@ -230,10 +230,9 @@ def compress(args):
acc_top5_ns.append(np.mean(acc_top5_n))
batch_id += 1
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
format(epoch,
np.mean(np.array(acc_top1_ns)),
np.mean(np.array(acc_top5_ns))))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def train(epoch, compiled_train_prog):
......@@ -259,8 +258,8 @@ def compress(args):
threshold = {}
for var in val_program.list_vars():
if 'pact' in var.name:
array = np.array(fluid.global_scope().find_var(
var.name).get_tensor())
array = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
threshold[var.name] = array[0]
print(threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册