未验证 提交 c4ee247a 编写于 作者: Z zhouzj 提交者: GitHub

Fix bug of loading saved model and tests. (#1744)

* fix bug of loading saved model.

* fix distill demo.

* fix flops.py.

* fix tests.
上级 f60c6a04
...@@ -97,8 +97,8 @@ def compress(args): ...@@ -97,8 +97,8 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data)) raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")] image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model, assert args.model in model_list, "{} is not in lists: {}".format(
model_list) args.model, model_list)
student_program = paddle.static.Program() student_program = paddle.static.Program()
s_startup = paddle.static.Program() s_startup = paddle.static.Program()
places = paddle.static.cuda_places( places = paddle.static.cuda_places(
...@@ -202,7 +202,7 @@ def compress(args): ...@@ -202,7 +202,7 @@ def compress(args):
_logger.info( _logger.info(
"train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}". "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
format(epoch_id, step_id, format(epoch_id, step_id,
lr.get_lr(), loss_1[0], loss_2[0], loss_3[0])) lr.get_lr(), loss_1, loss_2, loss_3))
lr.step() lr.step()
val_acc1s = [] val_acc1s = []
val_acc5s = [] val_acc5s = []
...@@ -216,8 +216,7 @@ def compress(args): ...@@ -216,8 +216,7 @@ def compress(args):
if step_id % args.log_period == 0: if step_id % args.log_period == 0:
_logger.info( _logger.info(
"valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}". "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
format(epoch_id, step_id, val_loss[0], val_acc1[0], format(epoch_id, step_id, val_loss, val_acc1, val_acc5))
val_acc5[0]))
if args.save_inference: if args.save_inference:
paddle.static.save_inference_model( paddle.static.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), [image], [out], os.path.join("./saved_models", str(epoch_id)), [image], [out],
......
...@@ -84,7 +84,8 @@ def _graph_flops(graph, only_conv=True, detail=False): ...@@ -84,7 +84,8 @@ def _graph_flops(graph, only_conv=True, detail=False):
output_shape = op.outputs("Out")[0].shape() output_shape = op.outputs("Out")[0].shape()
_, c_out, h_out, w_out = output_shape _, c_out, h_out, w_out = output_shape
k_size = op.attr("ksize") k_size = op.attr("ksize")
flops += h_out * w_out * c_out * (k_size[0]**2) if op.attr('pooling_type') == 'avg':
flops += (h_out * w_out * c_out * (k_size[0]**2) * 2)
elif op.type() in ['mul', 'matmul', 'matmul_v2']: elif op.type() in ['mul', 'matmul', 'matmul_v2']:
x_shape = list(op.inputs("X")[0].shape()) x_shape = list(op.inputs("X")[0].shape())
...@@ -101,7 +102,11 @@ def _graph_flops(graph, only_conv=True, detail=False): ...@@ -101,7 +102,11 @@ def _graph_flops(graph, only_conv=True, detail=False):
input_shape = list(op.inputs("X")[0].shape()) input_shape = list(op.inputs("X")[0].shape())
if input_shape[0] == -1: if input_shape[0] == -1:
input_shape[0] = 1 input_shape[0] = 1
flops += np.product(input_shape) if op.type() == 'batch_norm':
op_flops = np.product(input_shape) * 2
else:
op_flops = np.product(input_shape)
flops += op_flops
if detail: if detail:
return flops, params2flops return flops, params2flops
......
...@@ -570,7 +570,7 @@ class AutoCompression: ...@@ -570,7 +570,7 @@ class AutoCompression:
tmp_base_name = "_".join([prefix, str(os.getppid()), s_datetime]) tmp_base_name = "_".join([prefix, str(os.getppid()), s_datetime])
tmp_dir = os.path.join(base_dir, tmp_base_name) tmp_dir = os.path.join(base_dir, tmp_base_name)
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir return tmp_dir
def compress(self): def compress(self):
...@@ -609,10 +609,17 @@ class AutoCompression: ...@@ -609,10 +609,17 @@ class AutoCompression:
shutil.rmtree(self.tmp_dir) shutil.rmtree(self.tmp_dir)
if self.eval_function is not None and self.final_metric < 0.0: if self.eval_function is not None and self.final_metric < 0.0:
model_filename = None
if self.model_filename is None:
model_filename = "model.pdmodel"
elif self.model_filename.endswith(".pdmodel"):
model_filename = self.model_filename
else:
model_filename = self.model_filename + '.pdmodel'
[inference_program, feed_target_names, fetch_targets]= load_inference_model( \ [inference_program, feed_target_names, fetch_targets]= load_inference_model( \
final_model_path, \ final_model_path, \
model_filename=self.model_filename, params_filename=self.params_filename, model_filename=model_filename, executor=self._exe)
executor=self._exe)
self.final_metric = self.eval_function( self.final_metric = self.eval_function(
self._exe, inference_program, feed_target_names, self._exe, inference_program, feed_target_names,
fetch_targets) fetch_targets)
......
...@@ -73,9 +73,8 @@ class TestFLOPsCase2(unittest.TestCase): ...@@ -73,9 +73,8 @@ class TestFLOPsCase2(unittest.TestCase):
def add_cases(suite): def add_cases(suite):
suite.addTest( suite.addTest(
TestFlops( TestFlops(net=paddle.vision.models.mobilenet_v1, gt=12920832.0))
net=paddle.vision.models.mobilenet_v1, gt=11792896.0)) suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=86112768.0))
suite.addTest(TestFlops(net=paddle.vision.models.resnet50, gt=83872768.0))
suite.addTest(TestFLOPsCase1()) suite.addTest(TestFLOPsCase1())
suite.addTest(TestFLOPsCase2()) suite.addTest(TestFLOPsCase2())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册