# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing

from paddle.fluid import Executor
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import core
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace
from paddle.fluid.io import save_inference_model
import paddle.fluid as fluid
from paddle.fluid.core import CipherUtils
from paddle.fluid.core import CipherFactory
from paddle.fluid.core import Cipher
from ..proto import general_model_config_pb2 as model_conf
import os


def save_model(server_model_folder,
               client_config_folder,
               feed_var_dict,
               fetch_var_dict,
               main_program=None,
               encryption=False,
               key_len=128,
               encrypt_conf=None):
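    """Save a Paddle inference program as a Paddle Serving model.

    Writes the server-side model into server_model_folder together with
    matching client/server prototxt configs. With encryption=True, the
    program and its parameters are encrypted with a freshly generated
    key (length key_len) written to the file "key".
    """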
    executor = Executor(place=CPUPlace())

    feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
    target_vars = []
    target_var_names = []
    for key in sorted(fetch_var_dict.keys()):
        target_vars.append(fetch_var_dict[key])
        target_var_names.append(key)
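    # Two export paths: without encryption, save_inference_model writes the
    # program and parameters directly; with encryption, both are serialized
    # manually and written as encrypted blobs alongside the generated key file.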
    if not encryption:
        save_inference_model(
            server_model_folder,
            feed_var_names,
            target_vars,
            executor,
            main_program=main_program)
    else:
        if encrypt_conf is None:
            aes_cipher = CipherFactory.create_cipher()
        else:
            # TODO: support more encryption algorithms selected via
            # encrypt_conf; fall back to the default cipher for now so that
            # aes_cipher is always bound.
            aes_cipher = CipherFactory.create_cipher()
        key = CipherUtils.gen_key_to_file(key_len, "key")
        params = fluid.io.save_persistables(
            executor=executor, dirname=None, main_program=main_program)
        model = main_program.desc.serialize_to_string()
        if not os.path.exists(server_model_folder):
            os.makedirs(server_model_folder)
        os.chdir(server_model_folder)
        aes_cipher.encrypt_to_file(params, key, "encrypt_params")
        aes_cipher.encrypt_to_file(model, key, "encrypt_model")
        os.chdir("..")
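
    # Build the GeneralModelConfig proto that describes the feed/fetch
    # variables; the same config is later written for client and server alike.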
    config = model_conf.GeneralModelConfig()

    # Feed/fetch type codes: int64 = 0, float32 = 1, int32 = 2.
    for key in feed_var_dict:
        feed_var = model_conf.FeedVar()
        feed_var.alias_name = key
        feed_var.name = feed_var_dict[key].name
        feed_var.is_lod_tensor = feed_var_dict[key].lod_level >= 1
        if feed_var_dict[key].dtype == core.VarDesc.VarType.INT64:
            feed_var.feed_type = 0
        if feed_var_dict[key].dtype == core.VarDesc.VarType.FP32:
            feed_var.feed_type = 1
        if feed_var_dict[key].dtype == core.VarDesc.VarType.INT32:
            feed_var.feed_type = 2
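        # LoD tensors are given a single variable-length dimension; for dense
        # tensors, variable (negative) dims such as the batch size are dropped.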
        if feed_var.is_lod_tensor:
            feed_var.shape.extend([-1])
        else:
            tmp_shape = []
            for v in feed_var_dict[key].shape:
                if v >= 0:
                    tmp_shape.append(v)
            feed_var.shape.extend(tmp_shape)
        config.feed_var.extend([feed_var])

    for key in target_var_names:
        fetch_var = model_conf.FetchVar()
        fetch_var.alias_name = key
        fetch_var.name = fetch_var_dict[key].name
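        # Fetch outputs are always treated as LoD tensors on the serving side;
        # the original lod_level-based check is kept commented for reference.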
        #fetch_var.is_lod_tensor = fetch_var_dict[key].lod_level >= 1
        fetch_var.is_lod_tensor = 1
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT64:
            fetch_var.fetch_type = 0
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.FP32:
            fetch_var.fetch_type = 1
        if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT32:
            fetch_var.fetch_type = 2
        if fetch_var.is_lod_tensor:
            fetch_var.shape.extend([-1])
        else:
            tmp_shape = []
            for v in fetch_var_dict[key].shape:
                if v >= 0:
                    tmp_shape.append(v)
            fetch_var.shape.extend(tmp_shape)
        config.fetch_var.extend([fetch_var])

    cmd = "mkdir -p {}".format(client_config_folder)
117

G
guru4elephant 已提交
118
    os.system(cmd)
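
    # The config is written in both human-readable prototxt and binary
    # ("stream") protobuf form, for the client and the server respectively.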
    with open("{}/serving_client_conf.prototxt".format(client_config_folder),
              "w") as fout:
        fout.write(str(config))
    with open("{}/serving_server_conf.prototxt".format(server_model_folder),
              "w") as fout:
        fout.write(str(config))
    with open("{}/serving_client_conf.stream.prototxt".format(
            client_config_folder), "wb") as fout:
        fout.write(config.SerializeToString())
    with open("{}/serving_server_conf.stream.prototxt".format(
            server_model_folder), "wb") as fout:
        fout.write(config.SerializeToString())


def inference_model_to_serving(dirname,
                               serving_server="serving_server",
                               serving_client="serving_client",
                               model_filename=None,
                               params_filename=None,
                               encryption=False,
                               key_len=128,
                               encrypt_conf=None):
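    """Convert a model saved via fluid.io.save_inference_model into a
    Paddle Serving model plus client/server configs.

    Returns the feed and fetch variable alias names.
    """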
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    inference_program, feed_target_names, fetch_targets = \
            fluid.io.load_inference_model(dirname=dirname,
                                          executor=exe,
                                          model_filename=model_filename,
                                          params_filename=params_filename)
    feed_dict = {
        x: inference_program.global_block().var(x)
        for x in feed_target_names
    }
    fetch_dict = {x.name: x for x in fetch_targets}
    save_model(serving_server, serving_client, feed_dict, fetch_dict,
               inference_program, encryption, key_len, encrypt_conf)
    feed_names = list(feed_dict.keys())
    fetch_names = list(fetch_dict.keys())
    return feed_names, fetch_names
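

# Example usage (a minimal sketch; "resnet_model" is a hypothetical directory
# produced by fluid.io.save_inference_model, and the import path assumes this
# package is installed as paddle_serving_client):
#
#     from paddle_serving_client.io import inference_model_to_serving
#     feed_names, fetch_names = inference_model_to_serving(
#         dirname="resnet_model",
#         serving_server="serving_server",
#         serving_client="serving_client")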