未验证 提交 49b18a5c 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1147 from HexToString/fix_grpc_bug

fix a grpc bug
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import google.protobuf.text_format
from .proto import general_model_config_pb2 as m_config
......@@ -9,6 +23,7 @@ sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path_list, is_multi_model, endpoints):
......@@ -31,7 +46,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
......@@ -57,7 +72,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
......@@ -86,11 +101,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:# int64
if v_type == 0: # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:# float32
elif v_type == 1: # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2:# int32
elif v_type == 2: # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
......@@ -99,7 +114,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:# int32
elif v_type == 2: # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
......@@ -155,7 +170,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
# This porcess and Inference process cannot be operate at the same time.
# For performance reasons, do not add thread lock temporarily.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_list, self.endpoints_, timeout_ms)
self._init_bclient(self.model_config_path_list, self.endpoints_,
timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
......@@ -176,4 +192,4 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
#dict should be added when graphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str_list[:] = self.model_config_path_list
return resp
\ No newline at end of file
return resp
......@@ -101,7 +101,7 @@ class WebService(object):
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
if client_config_path == None:
self.client_config_path = self.server_config_dir_paths
self.client_config_path = file_path_list
def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册