未验证 提交 d7efdad6 编写于 作者: Z zhouzj 提交者: GitHub

Adapt to api updates. (#1629)

* Adapt to api updates.

* fix bugs.
上级 98900b35
......@@ -399,7 +399,8 @@ def eval(predictor, val_loader, metric, rerun_flag=False):
input_names = predictor.get_input_names()
output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0])
boxes_num = predictor.get_output_handle(output_names[1])
if FLAGS.include_nms:
boxes_num = predictor.get_output_handle(output_names[1])
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
for i, _ in enumerate(input_names):
......
......@@ -139,9 +139,10 @@ def main():
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
if global_config.get('input_list') is None:
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
train_loader = reader_wrapper(train_loader, global_config['input_list'])
if 'Evaluation' in global_config.keys() and global_config[
......
......@@ -139,7 +139,8 @@ def save_cls_model(model, input_shape, save_dir, data_type):
batch_nums=1,
weight_bits=8,
activation_bits=8,
quantizable_op_type=["conv2d", "depthwise_conv2d"])
quantizable_op_type=["conv2d", "depthwise_conv2d"],
onnx_format=False)
model_file = os.path.join(quantize_model_path, 'model.pdmodel')
param_file = os.path.join(quantize_model_path, 'model.pdiparams')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册