提交 3cb494f3 编写于 作者: J jiangjiajun

add scope for model

上级 1c1b6052
# 预测部署-paddlex.deploy
使用Paddle Inference进行高性能的Python预测部署。更多关于Paddle Inference信息请参考[Paddle Inference文档](https://paddle-inference.readthedocs.io/en/latest/#)
## Predictor类
```
paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_trt=False, use_glog=False, memory_optimize=True)
```
> **参数**
> > * **model_dir**: 训练过程中保存的模型路径, 注意需要使用导出的inference模型
> > * **use_gpu**: 是否使用GPU进行预测
> > * **gpu_id**: 使用的GPU序列号
> > * **use_mkl**: 是否使用mkldnn加速库
> > * **use_trt**: 是否使用TensorRT预测引擎
> > * **use_glog**: 是否打印中间日志
> > * **memory_optimize**: 是否优化内存使用
> > ### 示例
> >
> > ```
> > import paddlex
> >
> > model = paddlex.deploy.Predictor(model_dir, use_gpu=True)
> > result = model.predict(image_file)
> > ```
### predict 接口
> ```
> predict(image, topk=1)
> ```
> **参数
* **image(str|np.ndarray)**: 待预测的图片路径或np.ndarray,若为后者需注意为BGR格式
* **topk(int)**: 图像分类时使用的参数,表示预测前topk个可能的分类
...@@ -10,4 +10,3 @@ API接口说明 ...@@ -10,4 +10,3 @@ API接口说明
slim.md slim.md
load_model.md load_model.md
visualize.md visualize.md
deploy.md
...@@ -73,6 +73,7 @@ class BaseAPI: ...@@ -73,6 +73,7 @@ class BaseAPI:
self.status = 'Normal' self.status = 'Normal'
# 已完成迭代轮数,为恢复训练时的起始轮数 # 已完成迭代轮数,为恢复训练时的起始轮数
self.completed_epochs = 0 self.completed_epochs = 0
self.scope = fluid.global_scope()
def _get_single_card_bs(self, batch_size): def _get_single_card_bs(self, batch_size):
if batch_size % len(self.places) == 0: if batch_size % len(self.places) == 0:
...@@ -84,6 +85,10 @@ class BaseAPI: ...@@ -84,6 +85,10 @@ class BaseAPI:
'place'])) 'place']))
def build_program(self): def build_program(self):
if hasattr(paddlex, 'model_built') and paddlex.model_built:
logging.error(
"Function model.train() only can be called once in your code.")
paddlex.model_built = True
# 构建训练网络 # 构建训练网络
self.train_inputs, self.train_outputs = self.build_net(mode='train') self.train_inputs, self.train_outputs = self.build_net(mode='train')
self.train_prog = fluid.default_main_program() self.train_prog = fluid.default_main_program()
...@@ -155,7 +160,7 @@ class BaseAPI: ...@@ -155,7 +160,7 @@ class BaseAPI:
outputs=self.test_outputs, outputs=self.test_outputs,
batch_size=batch_size, batch_size=batch_size,
batch_nums=batch_num, batch_nums=batch_num,
scope=None, scope=self.scope,
algo='KL', algo='KL',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False, is_full_quantize=False,
...@@ -244,8 +249,8 @@ class BaseAPI: ...@@ -244,8 +249,8 @@ class BaseAPI:
logging.info( logging.info(
"Load pretrain weights from {}.".format(pretrain_weights), "Load pretrain weights from {}.".format(pretrain_weights),
use_color=True) use_color=True)
paddlex.utils.utils.load_pretrain_weights( paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
self.exe, self.train_prog, pretrain_weights, fuse_bn) pretrain_weights, fuse_bn)
# 进行裁剪 # 进行裁剪
if sensitivities_file is not None: if sensitivities_file is not None:
import paddleslim import paddleslim
...@@ -349,9 +354,7 @@ class BaseAPI: ...@@ -349,9 +354,7 @@ class BaseAPI:
logging.info("Model saved in {}.".format(save_dir)) logging.info("Model saved in {}.".format(save_dir))
def export_inference_model(self, save_dir): def export_inference_model(self, save_dir):
test_input_names = [ test_input_names = [var.name for var in list(self.test_inputs.values())]
var.name for var in list(self.test_inputs.values())
]
test_outputs = list(self.test_outputs.values()) test_outputs = list(self.test_outputs.values())
if self.__class__.__name__ == 'MaskRCNN': if self.__class__.__name__ == 'MaskRCNN':
from paddlex.utils.save import save_mask_inference_model from paddlex.utils.save import save_mask_inference_model
...@@ -388,8 +391,7 @@ class BaseAPI: ...@@ -388,8 +391,7 @@ class BaseAPI:
# 模型保存成功的标志 # 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close() open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format( logging.info("Model for inference deploy saved in {}.".format(save_dir))
save_dir))
def train_loop(self, def train_loop(self,
num_epochs, num_epochs,
...@@ -513,13 +515,11 @@ class BaseAPI: ...@@ -513,13 +515,11 @@ class BaseAPI:
eta = ((num_epochs - i) * total_num_steps - step - 1 eta = ((num_epochs - i) * total_num_steps - step - 1
) * avg_step_time ) * avg_step_time
if time_eval_one_epoch is not None: if time_eval_one_epoch is not None:
eval_eta = ( eval_eta = (total_eval_times - i // save_interval_epochs
total_eval_times - i // save_interval_epochs ) * time_eval_one_epoch
) * time_eval_one_epoch
else: else:
eval_eta = ( eval_eta = (total_eval_times - i // save_interval_epochs
total_eval_times - i // save_interval_epochs ) * total_num_steps_eval * avg_step_time
) * total_num_steps_eval * avg_step_time
eta_str = seconds_to_hms(eta + eval_eta) eta_str = seconds_to_hms(eta + eval_eta)
logging.info( logging.info(
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -227,9 +227,10 @@ class BaseClassifier(BaseAPI): ...@@ -227,9 +227,10 @@ class BaseClassifier(BaseAPI):
true_labels = list() true_labels = list()
pred_scores = list() pred_scores = list()
if not hasattr(self, 'parallel_test_prog'): if not hasattr(self, 'parallel_test_prog'):
self.parallel_test_prog = fluid.CompiledProgram( with fluid.scope_guard(self.scope):
self.test_prog).with_data_parallel( self.parallel_test_prog = fluid.CompiledProgram(
share_vars_from=self.parallel_train_prog) self.test_prog).with_data_parallel(
share_vars_from=self.parallel_train_prog)
batch_size_each_gpu = self._get_single_card_bs(batch_size) batch_size_each_gpu = self._get_single_card_bs(batch_size)
logging.info("Start to evaluating(total_samples={}, total_steps={})...". logging.info("Start to evaluating(total_samples={}, total_steps={})...".
format(eval_dataset.num_samples, total_steps)) format(eval_dataset.num_samples, total_steps))
...@@ -242,9 +243,11 @@ class BaseClassifier(BaseAPI): ...@@ -242,9 +243,11 @@ class BaseClassifier(BaseAPI):
num_pad_samples = batch_size - num_samples num_pad_samples = batch_size - num_samples
pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1)) pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
images = np.concatenate([images, pad_images]) images = np.concatenate([images, pad_images])
outputs = self.exe.run(self.parallel_test_prog, with fluid.scope_guard(self.scope):
feed={'image': images}, outputs = self.exe.run(
fetch_list=list(self.test_outputs.values())) self.parallel_test_prog,
feed={'image': images},
fetch_list=list(self.test_outputs.values()))
outputs = [outputs[0][:num_samples]] outputs = [outputs[0][:num_samples]]
true_labels.extend(labels) true_labels.extend(labels)
pred_scores.extend(outputs[0].tolist()) pred_scores.extend(outputs[0].tolist())
...@@ -286,10 +289,11 @@ class BaseClassifier(BaseAPI): ...@@ -286,10 +289,11 @@ class BaseClassifier(BaseAPI):
self.arrange_transforms( self.arrange_transforms(
transforms=self.test_transforms, mode='test') transforms=self.test_transforms, mode='test')
im = self.test_transforms(img_file) im = self.test_transforms(img_file)
result = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed={'image': im}, result = self.exe.run(self.test_prog,
fetch_list=list(self.test_outputs.values()), feed={'image': im},
use_program_cache=True) fetch_list=list(self.test_outputs.values()),
use_program_cache=True)
pred_label = np.argsort(result[0][0])[::-1][:true_topk] pred_label = np.argsort(result[0][0])[::-1][:true_topk]
res = [{ res = [{
'category_id': l, 'category_id': l,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -317,19 +317,18 @@ class DeepLabv3p(BaseAPI): ...@@ -317,19 +317,18 @@ class DeepLabv3p(BaseAPI):
tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details), tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
包含关键字:'confusion_matrix',表示评估的混淆矩阵。 包含关键字:'confusion_matrix',表示评估的混淆矩阵。
""" """
self.arrange_transforms( self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
transforms=eval_dataset.transforms, mode='eval')
total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size) total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
conf_mat = ConfusionMatrix(self.num_classes, streaming=True) conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
data_generator = eval_dataset.generator( data_generator = eval_dataset.generator(
batch_size=batch_size, drop_last=False) batch_size=batch_size, drop_last=False)
if not hasattr(self, 'parallel_test_prog'): if not hasattr(self, 'parallel_test_prog'):
self.parallel_test_prog = fluid.CompiledProgram( with fluid.scope_guard(self.scope):
self.test_prog).with_data_parallel( self.parallel_test_prog = fluid.CompiledProgram(
share_vars_from=self.parallel_train_prog) self.test_prog).with_data_parallel(
logging.info( share_vars_from=self.parallel_train_prog)
"Start to evaluating(total_samples={}, total_steps={})...".format( logging.info("Start to evaluating(total_samples={}, total_steps={})...".
eval_dataset.num_samples, total_steps)) format(eval_dataset.num_samples, total_steps))
for step, data in tqdm.tqdm( for step, data in tqdm.tqdm(
enumerate(data_generator()), total=total_steps): enumerate(data_generator()), total=total_steps):
images = np.array([d[0] for d in data]) images = np.array([d[0] for d in data])
...@@ -350,10 +349,12 @@ class DeepLabv3p(BaseAPI): ...@@ -350,10 +349,12 @@ class DeepLabv3p(BaseAPI):
pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1)) pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
images = np.concatenate([images, pad_images]) images = np.concatenate([images, pad_images])
feed_data = {'image': images} feed_data = {'image': images}
outputs = self.exe.run(self.parallel_test_prog, with fluid.scope_guard(self.scope):
feed=feed_data, outputs = self.exe.run(
fetch_list=list(self.test_outputs.values()), self.parallel_test_prog,
return_numpy=True) feed=feed_data,
fetch_list=list(self.test_outputs.values()),
return_numpy=True)
pred = outputs[0] pred = outputs[0]
if num_samples < batch_size: if num_samples < batch_size:
pred = pred[0:num_samples] pred = pred[0:num_samples]
...@@ -399,10 +400,11 @@ class DeepLabv3p(BaseAPI): ...@@ -399,10 +400,11 @@ class DeepLabv3p(BaseAPI):
transforms=self.test_transforms, mode='test') transforms=self.test_transforms, mode='test')
im, im_info = self.test_transforms(im_file) im, im_info = self.test_transforms(im_file)
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
result = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed={'image': im}, result = self.exe.run(self.test_prog,
fetch_list=list(self.test_outputs.values()), feed={'image': im},
use_program_cache=True) fetch_list=list(self.test_outputs.values()),
use_program_cache=True)
pred = result[0] pred = result[0]
pred = np.squeeze(pred).astype('uint8') pred = np.squeeze(pred).astype('uint8')
logit = result[1] logit = result[1]
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -325,10 +325,12 @@ class FasterRCNN(BaseAPI): ...@@ -325,10 +325,12 @@ class FasterRCNN(BaseAPI):
'im_info': im_infos, 'im_info': im_infos,
'im_shape': im_shapes, 'im_shape': im_shapes,
} }
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed=[feed_data], outputs = self.exe.run(
fetch_list=list(self.test_outputs.values()), self.test_prog,
return_numpy=False) feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = { res = {
'bbox': (np.array(outputs[0]), 'bbox': (np.array(outputs[0]),
outputs[0].recursive_sequence_lengths()) outputs[0].recursive_sequence_lengths())
...@@ -388,15 +390,16 @@ class FasterRCNN(BaseAPI): ...@@ -388,15 +390,16 @@ class FasterRCNN(BaseAPI):
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
im_resize_info = np.expand_dims(im_resize_info, axis=0) im_resize_info = np.expand_dims(im_resize_info, axis=0)
im_shape = np.expand_dims(im_shape, axis=0) im_shape = np.expand_dims(im_shape, axis=0)
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed={ outputs = self.exe.run(self.test_prog,
'image': im, feed={
'im_info': im_resize_info, 'image': im,
'im_shape': im_shape 'im_info': im_resize_info,
}, 'im_shape': im_shape
fetch_list=list(self.test_outputs.values()), },
return_numpy=False, fetch_list=list(self.test_outputs.values()),
use_program_cache=True) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
...@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging ...@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging
def load_model(model_dir, fixed_input_shape=None): def load_model(model_dir, fixed_input_shape=None):
model_scope = fluid.Scope()
if not osp.exists(osp.join(model_dir, "model.yml")): if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's not model.yml in {}".format(model_dir)) raise Exception("There's not model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f: with open(osp.join(model_dir, "model.yml")) as f:
...@@ -51,38 +52,40 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -51,38 +52,40 @@ def load_model(model_dir, fixed_input_shape=None):
format(fixed_input_shape)) format(fixed_input_shape))
model.fixed_input_shape = fixed_input_shape model.fixed_input_shape = fixed_input_shape
if status == "Normal" or \ with fluid.scope_guard(model_scope):
status == "Prune" or status == "fluid.save": if status == "Normal" or \
startup_prog = fluid.Program() status == "Prune" or status == "fluid.save":
model.test_prog = fluid.Program() startup_prog = fluid.Program()
with fluid.program_guard(model.test_prog, startup_prog): model.test_prog = fluid.Program()
with fluid.unique_name.guard(): with fluid.program_guard(model.test_prog, startup_prog):
model.test_inputs, model.test_outputs = model.build_net( with fluid.unique_name.guard():
mode='test') model.test_inputs, model.test_outputs = model.build_net(
model.test_prog = model.test_prog.clone(for_test=True) mode='test')
model.exe.run(startup_prog) model.test_prog = model.test_prog.clone(for_test=True)
if status == "Prune": model.exe.run(startup_prog)
from .slim.prune import update_program if status == "Prune":
model.test_prog = update_program(model.test_prog, model_dir, from .slim.prune import update_program
model.places[0]) model.test_prog = update_program(model.test_prog, model_dir,
import pickle model.places[0])
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f: import pickle
load_dict = pickle.load(f) with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
fluid.io.set_program_state(model.test_prog, load_dict) load_dict = pickle.load(f)
fluid.io.set_program_state(model.test_prog, load_dict)
elif status == "Infer" or \
status == "Quant" or status == "fluid.save_inference_model": elif status == "Infer" or \
[prog, input_names, outputs] = fluid.io.load_inference_model( status == "Quant" or status == "fluid.save_inference_model":
model_dir, model.exe, params_filename='__params__') [prog, input_names, outputs] = fluid.io.load_inference_model(
model.test_prog = prog model_dir, model.exe, params_filename='__params__')
test_outputs_info = info['_ModelInputsOutputs']['test_outputs'] model.test_prog = prog
model.test_inputs = OrderedDict() test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
model.test_outputs = OrderedDict() model.test_inputs = OrderedDict()
for name in input_names: model.test_outputs = OrderedDict()
model.test_inputs[name] = model.test_prog.global_block().var(name) for name in input_names:
for i, out in enumerate(outputs): model.test_inputs[name] = model.test_prog.global_block().var(
var_desc = test_outputs_info[i] name)
model.test_outputs[var_desc[0]] = out for i, out in enumerate(outputs):
var_desc = test_outputs_info[i]
model.test_outputs[var_desc[0]] = out
if 'Transforms' in info: if 'Transforms' in info:
transforms_mode = info.get('TransformsMode', 'RGB') transforms_mode = info.get('TransformsMode', 'RGB')
# 固定模型的输入shape # 固定模型的输入shape
...@@ -107,6 +110,7 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -107,6 +110,7 @@ def load_model(model_dir, fixed_input_shape=None):
model.__dict__[k] = v model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(info['Model'])) logging.info("Model[{}] loaded.".format(info['Model']))
model.scope = model_scope
model.trainable = False model.trainable = False
model.status = status model.status = status
return model return model
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -286,10 +286,12 @@ class MaskRCNN(FasterRCNN): ...@@ -286,10 +286,12 @@ class MaskRCNN(FasterRCNN):
'im_info': im_infos, 'im_info': im_infos,
'im_shape': im_shapes, 'im_shape': im_shapes,
} }
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed=[feed_data], outputs = self.exe.run(
fetch_list=list(self.test_outputs.values()), self.test_prog,
return_numpy=False) feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = { res = {
'bbox': (np.array(outputs[0]), 'bbox': (np.array(outputs[0]),
outputs[0].recursive_sequence_lengths()), outputs[0].recursive_sequence_lengths()),
...@@ -356,15 +358,16 @@ class MaskRCNN(FasterRCNN): ...@@ -356,15 +358,16 @@ class MaskRCNN(FasterRCNN):
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
im_resize_info = np.expand_dims(im_resize_info, axis=0) im_resize_info = np.expand_dims(im_resize_info, axis=0)
im_shape = np.expand_dims(im_shape, axis=0) im_shape = np.expand_dims(im_shape, axis=0)
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed={ outputs = self.exe.run(self.test_prog,
'image': im, feed={
'im_info': im_resize_info, 'image': im,
'im_shape': im_shape 'im_info': im_resize_info,
}, 'im_shape': im_shape
fetch_list=list(self.test_outputs.values()), },
return_numpy=False, fetch_list=list(self.test_outputs.values()),
use_program_cache=True) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
...@@ -85,13 +85,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -85,13 +85,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._support_quantize_op_type = \ self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type + list(set(QuantizationTransformPass._supported_quantizable_op_type +
AddQuantDequantPass._supported_quantizable_op_type)) AddQuantDequantPass._supported_quantizable_op_type))
# Check inputs # Check inputs
assert executor is not None, "The executor cannot be None." assert executor is not None, "The executor cannot be None."
assert batch_size > 0, "The batch_size should be greater than 0." assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \ assert algo in self._support_algo_type, \
"The algo should be KL, abs_max or min_max." "The algo should be KL, abs_max or min_max."
self._executor = executor self._executor = executor
self._dataset = dataset self._dataset = dataset
self._batch_size = batch_size self._batch_size = batch_size
...@@ -154,20 +154,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -154,20 +154,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
logging.info("Start to run batch!") logging.info("Start to run batch!")
for data in self._data_loader(): for data in self._data_loader():
start = time.time() start = time.time()
self._executor.run( with fluid.scope_guard(self._scope):
program=self._program, self._executor.run(program=self._program,
feed=data, feed=data,
fetch_list=self._fetch_list, fetch_list=self._fetch_list,
return_numpy=False) return_numpy=False)
if self._algo == "KL": if self._algo == "KL":
self._sample_data(batch_id) self._sample_data(batch_id)
else: else:
self._sample_threshold() self._sample_threshold()
end = time.time() end = time.time()
logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format( logging.debug(
str(batch_id + 1), '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
str(batch_ct), str(batch_id + 1), str(batch_ct), str(end - start)))
str(end-start)))
batch_id += 1 batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums: if self._batch_nums and batch_id >= self._batch_nums:
break break
...@@ -194,15 +193,16 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -194,15 +193,16 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
Returns: Returns:
None None
''' '''
feed_vars_names = [var.name for var in self._feed_list] with fluid.scope_guard(self._scope):
fluid.io.save_inference_model( feed_vars_names = [var.name for var in self._feed_list]
dirname=save_model_path, fluid.io.save_inference_model(
feeded_var_names=feed_vars_names, dirname=save_model_path,
target_vars=self._fetch_list, feeded_var_names=feed_vars_names,
executor=self._executor, target_vars=self._fetch_list,
params_filename='__params__', executor=self._executor,
main_program=self._program) params_filename='__params__',
main_program=self._program)
def _load_model_data(self): def _load_model_data(self):
''' '''
Set data loader. Set data loader.
...@@ -212,7 +212,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -212,7 +212,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._data_loader = fluid.io.DataLoader.from_generator( self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_list_generator( self._data_loader.set_sample_list_generator(
self._dataset.generator(self._batch_size, drop_last=True), self._dataset.generator(
self._batch_size, drop_last=True),
places=self._place) places=self._place)
def _calculate_kl_threshold(self): def _calculate_kl_threshold(self):
...@@ -235,10 +236,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -235,10 +236,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
weight_threshold.append(abs_max_value) weight_threshold.append(abs_max_value)
self._quantized_var_kl_threshold[var_name] = weight_threshold self._quantized_var_kl_threshold[var_name] = weight_threshold
end = time.time() end = time.time()
logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format( logging.debug(
str(ct), '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
str(len(self._quantized_weight_var_name)), format(
str(end-start))) str(ct),
str(len(self._quantized_weight_var_name)), str(end -
start)))
ct += 1 ct += 1
ct = 1 ct = 1
...@@ -257,10 +260,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -257,10 +260,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._quantized_var_kl_threshold[var_name] = \ self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data)) self._get_kl_scaling_factor(np.abs(sampling_data))
end = time.time() end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format( logging.debug(
str(ct), '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
str(len(self._quantized_act_var_name)), format(
str(end-start))) str(ct),
str(len(self._quantized_act_var_name)),
str(end - start)))
ct += 1 ct += 1
else: else:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
...@@ -270,10 +275,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -270,10 +275,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._quantized_var_kl_threshold[var_name] = \ self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name])) self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
end = time.time() end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format( logging.debug(
str(ct), '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
str(len(self._quantized_act_var_name)), format(
str(end-start))) str(ct),
str(len(self._quantized_act_var_name)),
str(end - start)))
ct += 1 ct += 1
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -313,10 +313,12 @@ class YOLOv3(BaseAPI): ...@@ -313,10 +313,12 @@ class YOLOv3(BaseAPI):
images = np.array([d[0] for d in data]) images = np.array([d[0] for d in data])
im_sizes = np.array([d[1] for d in data]) im_sizes = np.array([d[1] for d in data])
feed_data = {'image': images, 'im_size': im_sizes} feed_data = {'image': images, 'im_size': im_sizes}
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed=[feed_data], outputs = self.exe.run(
fetch_list=list(self.test_outputs.values()), self.test_prog,
return_numpy=False) feed=[feed_data],
fetch_list=list(self.test_outputs.values()),
return_numpy=False)
res = { res = {
'bbox': (np.array(outputs[0]), 'bbox': (np.array(outputs[0]),
outputs[0].recursive_sequence_lengths()) outputs[0].recursive_sequence_lengths())
...@@ -366,12 +368,13 @@ class YOLOv3(BaseAPI): ...@@ -366,12 +368,13 @@ class YOLOv3(BaseAPI):
im, im_size = self.test_transforms(img_file) im, im_size = self.test_transforms(img_file)
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
im_size = np.expand_dims(im_size, axis=0) im_size = np.expand_dims(im_size, axis=0)
outputs = self.exe.run(self.test_prog, with fluid.scope_guard(self.scope):
feed={'image': im, outputs = self.exe.run(self.test_prog,
'im_size': im_size}, feed={'image': im,
fetch_list=list(self.test_outputs.values()), 'im_size': im_size},
return_numpy=False, fetch_list=list(self.test_outputs.values()),
use_program_cache=True) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册