model.py 2.4 KB
Newer Older
W
wuzewu 已提交
1 2 3
import os
import numpy as np

jm_12138's avatar
jm_12138 已提交
4
from paddle.inference import create_predictor, Config
W
wuzewu 已提交
5 6 7 8 9

__all__ = ['Model']

class Model():
    # 初始化函数
jm_12138's avatar
jm_12138 已提交
10
    def __init__(self, modelpath, use_gpu=False, use_mkldnn=True, combined=True):
W
wuzewu 已提交
11
        # 加载模型预测器
jm_12138's avatar
jm_12138 已提交
12
        self.predictor = self.load_model(modelpath, use_gpu, use_mkldnn, combined)
W
wuzewu 已提交
13 14 15 16

        # 获取模型的输入输出
        self.input_names = self.predictor.get_input_names()
        self.output_names = self.predictor.get_output_names()
jm_12138's avatar
jm_12138 已提交
17 18
        self.input_handle = self.predictor.get_input_handle(self.input_names[0])
        self.output_handle = self.predictor.get_output_handle(self.output_names[0])
W
wuzewu 已提交
19 20

    # 模型加载函数
jm_12138's avatar
jm_12138 已提交
21
    def load_model(self, modelpath, use_gpu, use_mkldnn, combined):
W
wuzewu 已提交
22 23 24
        # 对运行位置进行配置
        if use_gpu:
            try:
jm_12138's avatar
jm_12138 已提交
25 26 27
                int(os.environ.get('CUDA_VISIBLE_DEVICES'))
            except Exception:
                print('Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU.')
W
wuzewu 已提交
28
                use_gpu = False
jm_12138's avatar
jm_12138 已提交
29
                
W
wuzewu 已提交
30
        # 加载模型参数
jm_12138's avatar
jm_12138 已提交
31 32 33 34 35 36
        if combined:
            model = os.path.join(modelpath, "__model__")
            params = os.path.join(modelpath, "__params__")
            config = Config(model, params)
        else:
            config = Config(modelpath)
W
wuzewu 已提交
37 38

        # 设置参数
jm_12138's avatar
jm_12138 已提交
39 40
        if use_gpu:   
            config.enable_use_gpu(100, 0)
W
wuzewu 已提交
41 42
        else:
            config.disable_gpu()
jm_12138's avatar
jm_12138 已提交
43 44
            if use_mkldnn:
                config.enable_mkldnn()
W
wuzewu 已提交
45 46 47 48 49
        config.disable_glog_info()
        config.switch_ir_optim(True)
        config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        config.switch_specify_input_names(True)
jm_12138's avatar
jm_12138 已提交
50
        
W
wuzewu 已提交
51
        # 通过参数加载模型预测器
jm_12138's avatar
jm_12138 已提交
52 53
        predictor = create_predictor(config)
        
W
wuzewu 已提交
54 55 56 57 58 59 60 61 62 63
        # 返回预测器
        return predictor

    # 模型预测函数
    def predict(self, input_datas):
        outputs = []

        # 遍历输入数据进行预测
        for input_data in input_datas:
            inputs = input_data.copy()
jm_12138's avatar
jm_12138 已提交
64 65 66
            self.input_handle.copy_from_cpu(inputs)
            self.predictor.run()
            output = self.output_handle.copy_to_cpu()
W
wuzewu 已提交
67
            outputs.append(output)
jm_12138's avatar
jm_12138 已提交
68
        
W
wuzewu 已提交
69 70 71 72
        # 预测结果合并
        outputs = np.concatenate(outputs, 0)

        # 返回预测结果
jm_12138's avatar
jm_12138 已提交
73
        return outputs