#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid import Executor
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import core
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace
from paddle.fluid.io import save_persistables
from ..proto import general_model_config_pb2 as model_conf
import os

def save_model(server_model_folder,
               client_config_folder,
               feed_var_dict,
               fetch_var_dict,
               main_program=None):
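    """Save a trained fluid program together with its serving configuration.

    The persistable variables of ``main_program`` are written to
    ``server_model_folder``, and a ``GeneralModelConfig`` prototxt describing
    the feed/fetch variables is written for both the client and the server.

    Args:
        server_model_folder: directory that receives the model parameters.
        client_config_folder: directory that receives the client prototxt.
        feed_var_dict: dict mapping an alias name to a fluid feed Variable.
        fetch_var_dict: dict mapping an alias name to a fluid fetch Variable.
        main_program: the Program (or CompiledProgram) to export; when omitted,
            the default main program is used.
    """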
    if main_program is None:
        main_program = default_main_program()
    elif isinstance(main_program, CompiledProgram):
        # Unwrap a CompiledProgram to the underlying Program it was built from.
        main_program = main_program._program
        if main_program is None:
            raise TypeError("main_program should be a Program instance or None")
    if not isinstance(main_program, Program):
        raise TypeError("main_program should be a Program instance or None")

    executor = Executor(place=CPUPlace())

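    # Persist every persistable variable (i.e. the trained parameters) of
    # main_program into server_model_folder.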
    save_persistables(executor, server_model_folder,
                      main_program)

    config = model_conf.GeneralModelConfig()

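    # Describe every feed variable: feed_type 0 marks integer input
    # (INT32/INT64) and feed_type 1 marks FP32 input. LoD tensors get a single
    # variable-length dimension (-1); dense tensors keep only their
    # non-negative shape entries, so unknown dimensions such as the batch
    # dimension are dropped.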
    for key in feed_var_dict:
        feed_var = model_conf.FeedVar()
        feed_var.alias_name = key
        feed_var.name = feed_var_dict[key].name
        feed_var.is_lod_tensor = feed_var_dict[key].lod_level == 1
        if feed_var_dict[key].dtype == core.VarDesc.VarType.INT32 or \
           feed_var_dict[key].dtype == core.VarDesc.VarType.INT64:
            feed_var.feed_type = 0
        if feed_var_dict[key].dtype == core.VarDesc.VarType.FP32:
            feed_var.feed_type = 1
        if feed_var.is_lod_tensor:
            feed_var.shape.extend([-1])
        else:
            tmp_shape = []
            for v in feed_var_dict[key].shape:
                if v >= 0:
                    tmp_shape.append(v)
            feed_var.shape.extend(tmp_shape)
        config.feed_var.extend([feed_var])

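    # Describe every fetch variable with its alias, real name and full shape.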
    for key in fetch_var_dict:
        fetch_var = model_conf.FetchVar()
        fetch_var.alias_name = key
        fetch_var.name = fetch_var_dict[key].name
        fetch_var.shape.extend(fetch_var_dict[key].shape)
        config.fetch_var.extend([fetch_var])

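    # The same GeneralModelConfig is serialized twice, once as the client-side
    # prototxt and once as the server-side prototxt.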
    cmd = "mkdir -p {}".format(client_config_folder)
    os.system(cmd)
    with open("{}/serving_client_conf.prototxt", "w") as fout:
        fout.write(str(config))
    with open("{}/serving_server_conf.prototxt", "w") as fout:
        fout.write(str(config))
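
# Minimal usage sketch (illustrative only; `words` and `prediction` are
# placeholder variables from a trained fluid network, not part of this module):
#
#   import paddle.fluid as fluid
#   ... build and train a network with feed `words` and fetch `prediction` ...
#   save_model("serving_server_model", "serving_client_conf",
#              {"words": words}, {"prediction": prediction},
#              fluid.default_main_program())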