Commit 01aad69b authored by barrierye

support profile

Parent 3e847867
@@ -25,6 +25,10 @@ public class PaddleServingClientExample {
         }
         Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
+        if (fetch_map == null) {
+            return false;
+        }
+
         for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
             System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
         }
@@ -56,6 +60,10 @@ public class PaddleServingClientExample {
         }
         Map<String, INDArray> fetch_map = client.predict(feed_batch, fetch);
+        if (fetch_map == null) {
+            return false;
+        }
+
         for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
             System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
         }
@@ -113,6 +121,10 @@ public class PaddleServingClientExample {
         Map<String, HashMap<String, INDArray>> fetch_map
             = client.ensemble_predict(feed_data, fetch);
+        if (fetch_map == null) {
+            return false;
+        }
+
         for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
             System.out.println("Model = " + entry.getKey());
             HashMap<String, INDArray> tt = entry.getValue();
@@ -145,14 +157,94 @@ public class PaddleServingClientExample {
             return false;
         }
-        Map<String, HashMap<String, INDArray>> fetch_map
-            = client.ensemble_predict(feed_data, fetch);
-        for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
-            System.out.println("Model = " + entry.getKey());
-            HashMap<String, INDArray> tt = entry.getValue();
-            for (Map.Entry<String, INDArray> e : tt.entrySet()) {
+        Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
+        if (fetch_map == null) {
+            return false;
+        }
+
+        for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
             System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
         }
+        return true;
+    }
+
+    boolean cube_local() {
+        long[] embedding_14 = {250644};
+        long[] embedding_2 = {890346};
+        long[] embedding_10 = {3939};
+        long[] embedding_17 = {421122};
+        long[] embedding_23 = {664215};
+        long[] embedding_6 = {704846};
+        float[] dense_input = {0.0f, 0.006633499170812604f, 0.03f, 0.0f,
+                0.145078125f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+        long[] embedding_24 = {269955};
+        long[] embedding_12 = {295309};
+        long[] embedding_7 = {437731};
+        long[] embedding_3 = {990128};
+        long[] embedding_1 = {7753};
+        long[] embedding_4 = {286835};
+        long[] embedding_8 = {27346};
+        long[] embedding_9 = {636474};
+        long[] embedding_18 = {880474};
+        long[] embedding_16 = {681378};
+        long[] embedding_22 = {410878};
+        long[] embedding_13 = {255651};
+        long[] embedding_5 = {25207};
+        long[] embedding_11 = {10891};
+        long[] embedding_20 = {238459};
+        long[] embedding_21 = {26235};
+        long[] embedding_15 = {691460};
+        long[] embedding_25 = {544187};
+        long[] embedding_19 = {537425};
+        long[] embedding_0 = {737395};
+
+        HashMap<String, INDArray> feed_data
+            = new HashMap<String, INDArray>() {{
+                put("embedding_14.tmp_0", Nd4j.createFromArray(embedding_14));
+                put("embedding_2.tmp_0", Nd4j.createFromArray(embedding_2));
+                put("embedding_10.tmp_0", Nd4j.createFromArray(embedding_10));
+                put("embedding_17.tmp_0", Nd4j.createFromArray(embedding_17));
+                put("embedding_23.tmp_0", Nd4j.createFromArray(embedding_23));
+                put("embedding_6.tmp_0", Nd4j.createFromArray(embedding_6));
+                put("dense_input", Nd4j.createFromArray(dense_input));
+                put("embedding_24.tmp_0", Nd4j.createFromArray(embedding_24));
+                put("embedding_12.tmp_0", Nd4j.createFromArray(embedding_12));
+                put("embedding_7.tmp_0", Nd4j.createFromArray(embedding_7));
+                put("embedding_3.tmp_0", Nd4j.createFromArray(embedding_3));
+                put("embedding_1.tmp_0", Nd4j.createFromArray(embedding_1));
+                put("embedding_4.tmp_0", Nd4j.createFromArray(embedding_4));
+                put("embedding_8.tmp_0", Nd4j.createFromArray(embedding_8));
+                put("embedding_9.tmp_0", Nd4j.createFromArray(embedding_9));
+                put("embedding_18.tmp_0", Nd4j.createFromArray(embedding_18));
+                put("embedding_16.tmp_0", Nd4j.createFromArray(embedding_16));
+                put("embedding_22.tmp_0", Nd4j.createFromArray(embedding_22));
+                put("embedding_13.tmp_0", Nd4j.createFromArray(embedding_13));
+                put("embedding_5.tmp_0", Nd4j.createFromArray(embedding_5));
+                put("embedding_11.tmp_0", Nd4j.createFromArray(embedding_11));
+                put("embedding_20.tmp_0", Nd4j.createFromArray(embedding_20));
+                put("embedding_21.tmp_0", Nd4j.createFromArray(embedding_21));
+                put("embedding_15.tmp_0", Nd4j.createFromArray(embedding_15));
+                put("embedding_25.tmp_0", Nd4j.createFromArray(embedding_25));
+                put("embedding_19.tmp_0", Nd4j.createFromArray(embedding_19));
+                put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
+            }};
+        List<String> fetch = Arrays.asList("prob");
+
+        Client client = new Client();
+        List<String> endpoints = Arrays.asList("localhost:9292");
+        boolean succ = client.connect(endpoints);
+        if (succ != true) {
+            System.out.println("connect failed.");
+            return false;
+        }
+
+        Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
+        if (fetch_map == null) {
+            return false;
+        }
+
+        for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
+            System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
         }
         return true;
     }
@@ -175,8 +267,11 @@ public class PaddleServingClientExample {
             succ = e.asyn_predict();
         } else if ("batch_predict".equals(arg)) {
             succ = e.batch_predict();
+        } else if ("cube_local".equals(arg)) {
+            succ = e.cube_local();
         } else {
             System.out.format("%s not match: java -cp <jar> PaddleServingClientExample <exp>.\n", arg);
+            System.out.println("<exp>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local.");
         }
     }
......
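Usage note (not part of the commit itself): going by the dispatch code and usage string above, the new example is invoked as java -cp <jar> PaddleServingClientExample cube_local, against a serving instance on localhost:9292; the jar name is not shown on this page. Per the Client changes further down, client-side profiling is switched on by exporting FLAGS_profile_client=1 before launching.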
@@ -151,7 +151,6 @@
       <artifactId>${nd4j.backend}</artifactId>
       <version>${nd4j.version}</version>
     </dependency>
   </dependencies>
   <profiles>
......
@@ -2,6 +2,8 @@ package io.paddle.serving.client;

 import java.util.*;
 import java.util.function.Function;
+import java.lang.management.ManagementFactory;
+import java.lang.management.RuntimeMXBean;

 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
@@ -17,6 +19,41 @@ import io.paddle.serving.grpc.*;
 import io.paddle.serving.configure.*;
 import io.paddle.serving.client.PredictFuture;

+class Profiler {
+    int pid_;
+    String print_head_ = null;
+    List<String> time_record_ = null;
+    boolean enable_ = false;
+
+    Profiler() {
+        RuntimeMXBean runtimeMXBean = ManagementFactory.getRuntimeMXBean();
+        pid_ = Integer.valueOf(runtimeMXBean.getName().split("@")[0]).intValue();
+        print_head_ = "\nPROFILE\tpid:" + pid_ + "\t";
+        time_record_ = new ArrayList<String>();
+        time_record_.add(print_head_);
+    }
+
+    void record(String name) {
+        if (enable_) {
+            long ctime = System.currentTimeMillis() * 1000;
+            time_record_.add(name + ":" + String.valueOf(ctime) + " ");
+        }
+    }
+
+    void printProfile() {
+        if (enable_) {
+            String profile_str = String.join("", time_record_);
+            System.out.println(profile_str);
+            time_record_ = new ArrayList<String>();
+            time_record_.add(print_head_);
+        }
+    }
+
+    void enable(boolean flag) {
+        enable_ = flag;
+    }
+}
+
 public class Client {
     private ManagedChannel channel_;
     private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
@@ -29,6 +66,7 @@ public class Client {
     private Map<String, Integer> fetchTypes_;
     private Set<String> lodTensorSet_;
     private Map<String, Integer> feedTensorLen_;
+    private Profiler profiler_;

     public Client() {
         channel_ = null;
@@ -43,9 +81,17 @@ public class Client {
         fetchTypes_ = null;
         lodTensorSet_ = null;
         feedTensorLen_ = null;
+
+        profiler_ = new Profiler();
+        boolean is_profile = false;
+        String FLAGS_profile_client = System.getenv("FLAGS_profile_client");
+        if (FLAGS_profile_client != null && FLAGS_profile_client.equals("1")) {
+            is_profile = true;
+        }
+        profiler_.enable(is_profile);
     }

-    public Boolean setRpcTimeoutMs(int rpc_timeout) throws NullPointerException {
+    public boolean setRpcTimeoutMs(int rpc_timeout) throws NullPointerException {
         if (futureStub_ == null || blockingStub_ == null) {
             throw new NullPointerException("set timeout must be set after connect.");
         }
@@ -63,7 +109,7 @@ public class Client {
         return resp.getErrCode() == 0;
     }

-    public Boolean connect(List<String> endpoints) {
+    public boolean connect(List<String> endpoints) {
         // TODO
         //String target = "ipv4:" + String.join(",", endpoints);
         String target = endpoints.get(0);
@@ -164,7 +210,7 @@ public class Client {
             INDArray flattened_list = variable.reshape(flattened_shape);
             int v_type = feedTypes_.get(name);
             NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
-            //System.out.format("name: %s, type: %d\n", name, v_type);
+            //System.out.format("[A] name: %s, type: %d\n", name, v_type);
             if (v_type == 0) { // int64
                 while (iter.hasNext()) {
                     long[] next_index = iter.next();
@@ -355,9 +401,16 @@ public class Client {
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
-        InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
         try {
+            profiler_.record("java_prepro_0");
+            InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+            profiler_.record("java_prepro_1");
+
+            profiler_.record("java_client_infer_0");
             InferenceResponse resp = blockingStub_.inference(req);
+            profiler_.record("java_client_infer_1");
+
+            profiler_.record("java_postpro_0");
             Map<String, HashMap<String, INDArray>> ensemble_result
                 = _unpackInferenceResponse(resp, fetch, need_variant_tag);
             List<Map.Entry<String, HashMap<String, INDArray>>> list
@@ -367,6 +420,9 @@ public class Client {
                 System.out.format("grpc failed: please use ensemble_predict impl.\n");
                 return null;
             }
+            profiler_.record("java_postpro_1");
+            profiler_.printProfile();
+
             return list.get(0).getValue();
         } catch (StatusRuntimeException e) {
             System.out.format("grpc failed: %s\n", e.toString());
@@ -378,11 +434,22 @@ public class Client {
             List<HashMap<String, INDArray>> feed_batch,
             Iterable<String> fetch,
             Boolean need_variant_tag) {
-        InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
         try {
+            profiler_.record("java_prepro_0");
+            InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+            profiler_.record("java_prepro_1");
+
+            profiler_.record("java_client_infer_0");
             InferenceResponse resp = blockingStub_.inference(req);
-            return _unpackInferenceResponse(
-                resp, fetch, need_variant_tag);
+            profiler_.record("java_client_infer_1");
+
+            profiler_.record("java_postpro_0");
+            Map<String, HashMap<String, INDArray>> ensemble_result
+                = _unpackInferenceResponse(resp, fetch, need_variant_tag);
+            profiler_.record("java_postpro_1");
+            profiler_.printProfile();
+
+            return ensemble_result;
         } catch (StatusRuntimeException e) {
             System.out.format("grpc failed: %s\n", e.toString());
             return null;
......
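How the pieces above fit together: Profiler.record(name) appends name:timestamp pairs (System.currentTimeMillis() * 1000, i.e. millisecond precision expressed in microsecond units) and printProfile() emits them as one PROFILE line per request; recording is a no-op unless FLAGS_profile_client=1 is set in the environment. A minimal, self-contained sketch of the same pattern (not the committed class; names mirror it):

import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;

public class ProfilerSketch {
    public static void main(String[] args) throws InterruptedException {
        // pid via the RuntimeMXBean name ("pid@host" on typical JVMs), as in Profiler()
        int pid = Integer.parseInt(
                ManagementFactory.getRuntimeMXBean().getName().split("@")[0]);
        List<String> record = new ArrayList<String>();
        record.add("\nPROFILE\tpid:" + pid + "\t");

        // record("java_client_infer_0") / record("java_client_infer_1") equivalent
        record.add("java_client_infer_0:" + System.currentTimeMillis() * 1000 + " ");
        Thread.sleep(5); // stands in for blockingStub_.inference(req)
        record.add("java_client_infer_1:" + System.currentTimeMillis() * 1000 + " ");

        // printProfile() equivalent: one line per request
        System.out.println(String.join("", record));
    }
}

With profiling on, a real request prints a line shaped roughly like PROFILE pid:4321 java_prepro_0:<t> java_prepro_1:<t> java_client_infer_0:<t> java_client_infer_1:<t> java_postpro_0:<t> java_postpro_1:<t> (tab- and space-separated; values are timestamps and will vary).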
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -43,6 +43,7 @@ message InferenceResponse {
   repeated ModelOutput outputs = 1;
   optional string tag = 2;
   required int32 err_code = 3;
+  optional string profile = 4;
 };

 message ModelOutput {
......
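The new optional profile field (field 4) gives the server a slot to return its own timing string alongside the outputs. It is not read anywhere in the Java changes shown on this page; if it were, the accessors would follow protobuf's standard Java codegen for a proto2 optional string. A hedged sketch, assuming resp is an InferenceResponse:

// Hypothetical read of the new field; hasProfile()/getProfile() are the
// accessors protobuf generates for "optional string profile = 4;".
if (resp.hasProfile()) {
    // A server-side profile string could be merged with the client-side
    // PROFILE line to build an end-to-end timeline.
    System.out.println("server profile: " + resp.getProfile());
}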
wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz
tar -zxvf text_classification_data.tar.gz
tar -zxvf imdb_model.tar.gz
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import sys
import os
import paddle
import re
import paddle.fluid.incubate.data_generator as dg
py_version = sys.version_info[0]


class IMDBDataset(dg.MultiSlotDataGenerator):
    def load_resource(self, dictfile):
        self._vocab = {}
        wid = 0
        if py_version == 2:
            with open(dictfile) as f:
                for line in f:
                    self._vocab[line.strip()] = wid
                    wid += 1
        else:
            with open(dictfile, encoding="utf-8") as f:
                for line in f:
                    self._vocab[line.strip()] = wid
                    wid += 1
        self._unk_id = len(self._vocab)
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
        self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])

    def get_words_only(self, line):
        sent = line.lower().replace("<br />", " ").strip()
        words = [x for x in self._pattern.split(sent) if x and x != " "]
        feas = [
            self._vocab[x] if x in self._vocab else self._unk_id for x in words
        ]
        return feas

    def get_words_and_label(self, line):
        send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
                                                              " ").strip()
        label = [int(line.split('|')[-1])]
        words = [x for x in self._pattern.split(send) if x and x != " "]
        feas = [
            self._vocab[x] if x in self._vocab else self._unk_id for x in words
        ]
        return feas, label

    def infer_reader(self, infer_filelist, batch, buf_size):
        def local_iter():
            for fname in infer_filelist:
                with open(fname, "r") as fin:
                    for line in fin:
                        feas, label = self.get_words_and_label(line)
                        yield feas, label

        import paddle
        batch_iter = paddle.batch(
            paddle.reader.shuffle(
                local_iter, buf_size=buf_size),
            batch_size=batch)
        return batch_iter

    def generate_sample(self, line):
        def memory_iter():
            for i in range(1000):
                yield self.return_value

        def data_iter():
            feas, label = self.get_words_and_label(line)
            yield ("words", feas), ("label", label)

        return data_iter


if __name__ == "__main__":
    imdb = IMDBDataset()
    imdb.load_resource("imdb.vocab")
    imdb.run_from_stdin()

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient
from imdb_reader import IMDBDataset
client = MultiLangClient()
# If you have more than one model, make sure that the input
# and output of more than one model are the same.
client.connect(["127.0.0.1:9393"])

# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
# can define your own data preprocessing easily.
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource('imdb.vocab')

for i in range(3):
    line = 'i am very sad | 0'
    word_ids, label = imdb_dataset.get_words_and_label(line)
    print(type(word_ids[0]))
    print(word_ids)
    feed = {"words": word_ids}
    fetch = ["prediction"]
    fetch_maps = client.predict(feed=feed, fetch=fetch)
    for model, fetch_map in fetch_maps.items():
        if model == "serving_status_code":
            continue
        print("step: {}, model: {}, res: {}".format(i, model, fetch_map))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server import OpMaker
from paddle_serving_server import OpGraphMaker
from paddle_serving_server import MultiLangServer
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
cnn_infer_op = op_maker.create(
    'general_infer', engine_name='cnn', inputs=[read_op])
bow_infer_op = op_maker.create(
    'general_infer', engine_name='bow', inputs=[read_op])
response_op = op_maker.create(
    'general_response', inputs=[cnn_infer_op, bow_infer_op])

op_graph_maker = OpGraphMaker()
op_graph_maker.add_op(read_op)
op_graph_maker.add_op(cnn_infer_op)
op_graph_maker.add_op(bow_infer_op)
op_graph_maker.add_op(response_op)

server = MultiLangServer()
server.set_op_graph(op_graph_maker.get_op_graph())
model_config = {cnn_infer_op: 'imdb_cnn_model', bow_infer_op: 'imdb_bow_model'}
server.load_model_config(model_config)
server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server()
@@ -135,6 +135,7 @@ class Client(object):
         self.rpc_timeout_ms = 20000
         from .serving_client import PredictorRes
         self.predictorres_constructor = PredictorRes
+        self.write_profile_into_fetch_map_ = False  # only for grpc impl

     def load_client_config(self, path):
         from .serving_client import PredictorClient
@@ -399,6 +400,7 @@ class MultiLangClient(object):
         self.channel_ = None
         self.stub_ = None
         self.rpc_timeout_s_ = 2
+        self.profile_ = _Profiler()

     def add_variant(self, tag, cluster, variant_weight):
         # TODO
@@ -582,6 +584,7 @@ class MultiLangClient(object):
             ret = list(multi_result_map.values())[0]
         else:
             ret = multi_result_map
         ret["serving_status_code"] = 0
         return ret if not need_variant_tag else [ret, tag]
@@ -601,18 +604,30 @@ class MultiLangClient(object):
                 need_variant_tag=False,
                 asyn=False,
                 is_python=True):
-        req = self._pack_inference_request(feed, fetch, is_python=is_python)
         if not asyn:
             try:
+                self.profile_.record('py_prepro_0')
+                req = self._pack_inference_request(
+                    feed, fetch, is_python=is_python)
+                self.profile_.record('py_prepro_1')
+
+                self.profile_.record('py_client_infer_0')
                 resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
-                return self._unpack_inference_response(
+                self.profile_.record('py_client_infer_1')
+
+                self.profile_.record('py_postpro_0')
+                ret = self._unpack_inference_response(
                     resp,
                     fetch,
                     is_python=is_python,
                     need_variant_tag=need_variant_tag)
+                self.profile_.record('py_postpro_1')
+                self.profile_.print_profile()
+                return ret
             except grpc.RpcError as e:
                 return {"serving_status_code": e.code()}
         else:
+            req = self._pack_inference_request(feed, fetch, is_python=is_python)
             call_future = self.stub_.Inference.future(
                 req, timeout=self.rpc_timeout_s_)
             return MultiLangPredictFuture(
......
@@ -540,6 +540,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         results, tag = ret
         resp.tag = tag
         resp.err_code = 0
         if not self.is_multi_model_:
             results = {'general_infer_0': results}
         for model_name, model_result in results.items():
......
@@ -587,6 +587,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         results, tag = ret
         resp.tag = tag
         resp.err_code = 0
         if not self.is_multi_model_:
             results = {'general_infer_0': results}
         for model_name, model_result in results.items():
......