Commit 02332b49 authored by wangjiawei04

codestyle

Parent c3f7f006
......@@ -222,10 +222,11 @@ class PredictorClient {
const std::vector<std::vector<py::array_t<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int>>& float_shape,
const std::vector<std::vector<int>>& float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::vector<int>>& lod_slot_batch,
const std::vector<std::vector<int>>& int_lod_slot_batch,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid,
......
......@@ -141,10 +141,11 @@ int PredictorClient::numpy_predict(
const std::vector<std::vector<py::array_t<float>>> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<int>> &float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>> &int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &lod_slot_batch,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch,
const int &pid,
......@@ -198,11 +199,13 @@ int PredictorClient::numpy_predict(
<< float_shape[vec_idx].size();
for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
tensor->add_shape(float_shape[vec_idx][j]);
std::cout << "shape " << j << " : " << float_shape[vec_idx][j]
std::cout << "float shape " << j << " : " << float_shape[vec_idx][j]
<< std::endl;
}
for (uint32_t j = 0; j < lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(lod_slot_batch[vec_idx][j]);
for (uint32_t j = 0; j < float_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(float_lod_slot_batch[vec_idx][j]);
std::cout << "float lod: " << vec_idx << " " << j
<< " value:" << float_lod_slot_batch[vec_idx][j] << std::endl;
}
tensor->set_elem_type(1);
const int float_shape_size = float_shape[vec_idx].size();
......@@ -261,6 +264,13 @@ int PredictorClient::numpy_predict(
for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
tensor->add_shape(int_shape[vec_idx][j]);
std::cout << "int shape " << j << " : " << int_shape[vec_idx][j]
<< std::endl;
}
for (uint32_t j = 0; j < int_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(int_lod_slot_batch[vec_idx][j]);
std::cout << "int lod: " << vec_idx << " " << j
<< " value:" << int_lod_slot_batch[vec_idx][j] << std::endl;
}
tensor->set_elem_type(_type[idx]);
......
......@@ -101,11 +101,12 @@ PYBIND11_MODULE(serving_client, m) {
&float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<int>> &float_lod_slot_batch,
const std::vector<std::vector<py::array_t<int64_t>>>
&int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::vector<int>> &lod_slot_batch,
const std::vector<std::vector<int>> &int_lod_slot_batch,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch,
const int &pid,
......@@ -113,10 +114,11 @@ PYBIND11_MODULE(serving_client, m) {
return self.numpy_predict(float_feed_batch,
float_feed_name,
float_shape,
float_lod_slot_batch,
int_feed_batch,
int_feed_name,
int_shape,
lod_slot_batch,
int_lod_slot_batch,
fetch_name,
predict_res_batch,
pid,
......
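Taken together, the header, implementation, and pybind changes above split the former shared lod_slot_batch argument into float_lod_slot_batch and int_lod_slot_batch, so each dtype's lod offsets travel next to that dtype's feeds and shapes. A minimal sketch of the resulting call order as seen from Python (the slot, name, and shape lists are assumed to be prepared as in the Client changes further down):

    # Illustrative only: positional order of numpy_predict after this commit.
    res = client_handle.numpy_predict(
        float_slot_batch, float_feed_names, float_shape,
        float_lod_slot_batch,   # new: lod offsets for the float slots
        int_slot_batch, int_feed_names, int_shape,
        int_lod_slot_batch,     # new: lod offsets for the int slots
        fetch_names, result_batch_handle, pid, log_id)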
......@@ -135,6 +135,8 @@ int GeneralReaderOp::inference() {
lod_tensor.dtype = paddle::PaddleDType::INT32;
}
// implement lod tensor here
std::cout << "lod size: " << req->insts(0).tensor_array(i).lod_size()
<< std::endl;
if (req->insts(0).tensor_array(i).lod_size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor";
lod_tensor.lod.resize(1);
......@@ -194,14 +196,13 @@ int GeneralReaderOp::inference() {
} else {
sample_len = tensor.shape(0);
}
out->at(i).lod[0].push_back(cur_len + sample_len);
VLOG(2) << "(logid=" << log_id << ") new len: " << cur_len + sample_len;
}
out->at(i).data.Resize(tensor_size * elem_size[i]);
out->at(i).shape = {};
for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
}
// out->at(i).shape = {};
// for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
// out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
// }
// if (out->at(i).shape.size() == 1) {
// out->at(i).shape.push_back(1);
//}
......@@ -223,6 +224,7 @@ int GeneralReaderOp::inference() {
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
int elem_num = req->insts(j).tensor_array(i).int64_data_size();
std::cout << "int elem num: " << elem_num << std::endl;
for (int k = 0; k < elem_num; ++k) {
dst_ptr[offset + k] = req->insts(j).tensor_array(i).int64_data(k);
}
......@@ -234,6 +236,7 @@ int GeneralReaderOp::inference() {
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
int elem_num = req->insts(j).tensor_array(i).float_data_size();
std::cout << "float elem num: " << elem_num << std::endl;
for (int k = 0; k < elem_num; ++k) {
dst_ptr[offset + k] = req->insts(j).tensor_array(i).float_data(k);
}
......
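GeneralReaderOp now treats a variable as a LoD tensor only when the incoming request tensor actually carries lod offsets, so the decision is made entirely by the client feed. A small sketch, assuming an already-connected paddle_serving_client.Client, of a feed that yields tensor_array(i).lod_size() > 0 on the server (key names are illustrative):

    import numpy as np

    # "words" is one flattened int64 sequence; "words.lod" holds cumulative
    # offsets, here one sample spanning rows [0, 4).
    feed = {
        "words": np.array([8, 233, 52, 601], dtype=np.int64).reshape(4, 1),
        "words.lod": [0, 4],
    }
    fetch_map = client.predict(feed=feed, fetch=["prediction"], batch=True)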
......@@ -18,7 +18,7 @@ import sys
from paddle_serving_client import Client
from paddle_serving_client.utils import benchmark_args
from paddle_serving_app.reader import ChineseBertReader
import numpy as np
args = benchmark_args()
reader = ChineseBertReader({"max_seq_len": 128})
......@@ -30,4 +30,8 @@ client.connect(endpoint_list)
for line in sys.stdin:
feed_dict = reader.process(line)
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape((128, 1))
#print(feed_dict)
result = client.predict(feed=feed_dict, fetch=fetch)
print(result)
......@@ -15,6 +15,7 @@
from paddle_serving_client import Client
from paddle_serving_app.reader import ChineseBertReader
import sys
import numpy as np
client = Client()
client.load_client_config("./bert_seq32_client/serving_client_conf.prototxt")
......@@ -28,12 +29,21 @@ expected_shape = {
"pooled_output": (4, 768)
}
batch_size = 4
feed_batch = []
feed_batch = {}
batch_len = 0
for line in sys.stdin:
feed = reader.process(line)
if batch_len == 0:
for key in feed.keys():
val_len = len(feed[key])
feed_batch[key] = np.array(feed[key]).reshape((1, val_len, 1))
continue
if len(feed_batch) < batch_size:
feed_batch.append(feed)
for key in feed.keys():
    feed_batch[key] = np.concatenate([
        feed_batch[key], np.array(feed[key]).reshape((1, val_len, 1))
    ])
else:
fetch_map = client.predict(feed=feed_batch, fetch=fetch)
feed_batch = {}
......
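The loop above accumulates per-key sample arrays along a new leading axis. A self-contained sketch of that batching pattern, with the bookkeeping that the truncated hunk elides filled in as an assumption (samples, client, fetch, and batch_size are assumed to exist; equal-length samples, as produced by a fixed max_seq_len reader, are required for the concatenation):

    import numpy as np

    feed_batch, batch_len = {}, 0
    for feed in samples:  # each feed maps key -> list of feature values
        for key in feed:
            sample = np.array(feed[key]).reshape((1, len(feed[key]), 1))
            feed_batch[key] = (sample if batch_len == 0 else
                               np.concatenate([feed_batch[key], sample]))
        batch_len += 1
        if batch_len == batch_size:
            fetch_map = client.predict(feed=feed_batch, fetch=fetch)
            feed_batch, batch_len = {}, 0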
......@@ -19,6 +19,7 @@ import os
import criteo as criteo
import time
from paddle_serving_client.metric import auc
import numpy as np
py_version = sys.version_info[0]
......@@ -41,9 +42,13 @@ for ei in range(10000):
else:
data = reader().__next__()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
feed_dict['dense_input'] = np.array(data[0][0]).astype("float32").reshape(
1, 13)
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][i]
tmp_data = np.array(data[0][i]).astype(np.int64)
feed_dict["embedding_{}.tmp_0".format(i - 1)] = tmp_data.reshape(
(1, len(data[0][i])))
print(feed_dict)
fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
prob_list.append(fetch_map['prob'][0][1])
label_list.append(data[0][-1][0])
......
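After this change the criteo client sends every input as a 2-D array with an explicit leading batch dimension: the dense slot as float32 of shape (1, 13), each sparse slot as int64 of shape (1, n). A compact restatement of that convention (data comes from the criteo reader above; client is the connected instance):

    import numpy as np

    feed_dict = {
        "dense_input": np.array(data[0][0], dtype="float32").reshape(1, 13)
    }
    for i in range(1, 27):
        slot = np.array(data[0][i], dtype=np.int64)
        feed_dict["embedding_{}.tmp_0".format(i - 1)] = slot.reshape(1, slot.size)
    fetch_map = client.predict(feed=feed_dict, fetch=["prob"])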
......@@ -27,5 +27,12 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
import numpy as np
new_data = np.zeros((2, 1, 13)).astype("float32")
new_data[0] = data[0][0]
new_data[1] = data[0][0]
print(new_data)
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
print(fetch_map)
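The fit_a_line example exercises the new batch=True path: samples are stacked into one array whose leading dimension is the batch size, and batch=True tells the client not to insert another batch axis. A minimal sketch (x is one 13-feature sample from the reader; shapes follow the diff):

    import numpy as np

    x = np.array(data[0][0], dtype="float32")   # one sample, shape (13,)
    batch = np.stack([x, x]).reshape(2, 1, 13)  # leading dim = batch size
    fetch_map = client.predict(feed={"x": batch}, fetch=["price"], batch=True)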
......@@ -15,6 +15,7 @@
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
import paddle
import numpy as np
def single_func(idx, resource):
......@@ -26,6 +27,7 @@ def single_func(idx, resource):
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584,
0.6283, 0.4919, 0.1856, 0.0795, -0.0332
]
x = np.array(x)
for i in range(1000):
fetch_map = client.predict(feed={"x": x}, fetch=["price"])
if fetch_map is None:
......
......@@ -15,6 +15,7 @@
from paddle_serving_client import Client
from paddle_serving_app.reader import IMDBDataset
import sys
import numpy as np
client = Client()
client.load_client_config(sys.argv[1])
......@@ -28,7 +29,12 @@ imdb_dataset.load_resource(sys.argv[2])
for line in sys.stdin:
word_ids, label = imdb_dataset.get_words_and_label(line)
feed = {"words": word_ids}
word_len = len(word_ids)
feed = {
"words": np.array(word_ids).reshape(word_len, 1),
"words.lod": [0, word_len]
}
#print(feed)
fetch = ["prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = client.predict(feed=feed, fetch=fetch, batch=True)
print("{} {}".format(fetch_map["prediction"][0], label[0]))
......@@ -29,6 +29,11 @@ class IMDBService(WebService):
res_feed = []
for ins in feed:
    word_ids = self.dataset.get_words_only(ins["words"])
    word_len = len(word_ids)
    res_feed.append({
        "words": np.array(word_ids).reshape(word_len, 1),
        "words.lod": [0, word_len]
    })
return res_feed, fetch
......
......@@ -19,6 +19,7 @@ from paddle_serving_app.reader import LACReader
import sys
import os
import io
import numpy as np
client = Client()
client.load_client_config(sys.argv[1])
......@@ -31,7 +32,17 @@ for line in sys.stdin:
feed_data = reader.process(line)
if len(feed_data) <= 0:
continue
fetch_map = client.predict(feed={"words": feed_data}, fetch=["crf_decode"])
print(feed_data)
#fetch_map = client.predict(feed={"words": np.array(feed_data).reshape(len(feed_data), 1), "words.lod": [0, len(feed_data)]}, fetch=["crf_decode"], batch=True)
fetch_map = client.predict(
feed={
"words": np.array(feed_data + feed_data).reshape(
len(feed_data) * 2, 1),
"words.lod": [0, len(feed_data), 2 * len(feed_data)]
},
fetch=["crf_decode"],
batch=True)
print(fetch_map)
begin = fetch_map['crf_decode.lod'][0]
end = fetch_map['crf_decode.lod'][1]
segs = reader.parse_result(line, fetch_map["crf_decode"][begin:end])
......
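On the fetch side, a variable-length output arrives with a matching "<name>.lod" entry, which the lac client uses to slice out each sample. Generalizing the begin/end lines above to any number of samples (fetch_map layout as in the diff):

    # fetch_map["crf_decode.lod"] holds cumulative offsets, one span per sample.
    lod = fetch_map["crf_decode.lod"]
    for s in range(len(lod) - 1):
        begin, end = lod[s], lod[s + 1]
        segs = reader.parse_result(line, fetch_map["crf_decode"][begin:end])
        print(segs)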
rpc_port: 18085
rpc_port: 18080
worker_num: 4
build_dag_each_worker: false
http_port: 9999
dag:
is_thread_op: false
client_type: brpc
is_thread_op: true
retry: 1
use_profile: false
tracer:
interval_s: 10
op:
bow:
concurrency: 2
remote_service_conf:
client_type: brpc
model_config: ocr_det_model
devices: ""
......@@ -52,7 +52,7 @@ class DetOp(Op):
self.ori_h, self.ori_w, _ = self.im.shape
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis,:].copy()}
return {"image": det_img[np.newaxis, :].copy()}
def postprocess(self, input_dicts, fetch_dict):
det_out = fetch_dict["concat_1.tmp_0"]
......
......@@ -52,7 +52,7 @@ class DetOp(Op):
self.ori_h, self.ori_w, _ = self.im.shape
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis,:]}
return {"image": det_img[np.newaxis, :]}
def postprocess(self, input_dicts, fetch_dict):
det_out = fetch_dict["concat_1.tmp_0"]
......
......@@ -53,8 +53,8 @@ class DetOp(Op):
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
with open("in.npy", 'wb') as f:
np.save(f, det_img[np.newaxis,:])
return {"image": det_img[np.newaxis,:].copy()}
np.save(f, det_img[np.newaxis, :])
return {"image": det_img[np.newaxis, :].copy()}
def postprocess(self, input_dicts, fetch_dict):
det_out = fetch_dict["concat_1.tmp_0"]
......
......@@ -36,4 +36,4 @@ for img_file in os.listdir(test_img_dir):
print(js)
break
if "error_info" in js:
print("receive error exit")
print("receive error exit")
......@@ -150,5 +150,6 @@ class Debugger(object):
for i, name in enumerate(fetch):
fetch_map[name] = outputs[i]
if len(output_tensors[i].lod()) > 0:
fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[0]).astype('int32')
fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[
0]).astype('int32')
return fetch_map
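This Debugger change is where those ".lod" keys originate on the local-predictor path: whenever an output tensor carries lod, a companion int32 "<name>.lod" array is placed into fetch_map. A sketch of consuming such a result (the "prediction" key name is assumed):

    out = fetch_map["prediction"]
    lod = fetch_map.get("prediction.lod")  # present only for lod outputs
    samples = ([out[lod[i]:lod[i + 1]] for i in range(len(lod) - 1)]
               if lod is not None else [out])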
......@@ -265,7 +265,8 @@ class Client(object):
int_feed_names = []
float_feed_names = []
int_shape = []
lod_slot_batch = []
int_lod_slot_batch = []
float_lod_slot_batch = []
float_shape = []
fetch_names = []
......@@ -284,7 +285,8 @@ class Client(object):
for i, feed_i in enumerate(feed_batch):
int_slot = []
float_slot = []
lod_slot = []
int_lod_slot = []
float_lod_slot = []
for key in feed_i:
if ".lod" not in key and key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key))
......@@ -298,7 +300,6 @@ class Client(object):
shape_lst = []
if batch == False:
feed_i[key] = feed_i[key][np.newaxis, :]
shape_lst.append(1)
if isinstance(feed_i[key], np.ndarray):
print("feed_i_key shape", feed_i[key].shape)
shape_lst.extend(list(feed_i[key].shape))
......@@ -307,9 +308,10 @@ class Client(object):
else:
int_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i:
lod_slot_batch.append(feed_i["{}.lod".format(key)])
int_lod_slot_batch.append(feed_i["{}.lod".format(
key)])
else:
lod_slot_batch.append([])
int_lod_slot_batch.append([])
if isinstance(feed_i[key], np.ndarray):
int_slot.append(feed_i[key])
......@@ -324,7 +326,6 @@ class Client(object):
shape_lst = []
if batch == False:
feed_i[key] = feed_i[key][np.newaxis, :]
shape_lst.append(1)
if isinstance(feed_i[key], np.ndarray):
print("feed_i_key shape", feed_i[key].shape)
shape_lst.extend(list(feed_i[key].shape))
......@@ -333,9 +334,10 @@ class Client(object):
else:
float_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i:
lod_slot_batch.append(feed_i["{}.lod".format(key)])
float_lod_slot_batch.append(feed_i["{}.lod".format(
key)])
else:
lod_slot_batch.append([])
float_lod_slot_batch.append([])
if isinstance(feed_i[key], np.ndarray):
float_slot.append(feed_i[key])
......@@ -345,7 +347,8 @@ class Client(object):
self.all_numpy_input = False
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
lod_slot_batch.append(lod_slot)
int_lod_slot_batch.append(int_lod_slot)
float_lod_slot_batch.append(float_lod_slot)
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
......@@ -353,9 +356,10 @@ class Client(object):
result_batch_handle = self.predictorres_constructor()
if self.all_numpy_input:
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, lod_slot_batch, fetch_names,
result_batch_handle, self.pid, log_id)
float_slot_batch, float_feed_names, float_shape,
float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
int_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
log_id)
elif self.has_numpy_input == False:
raise ValueError(
"Please make sure all of your inputs are numpy array")
......
......@@ -60,8 +60,8 @@ class DAGExecutor(object):
self._is_thread_op, tracer_interval_s, server_worker_num)
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, channel_size,
build_dag_each_worker, self._tracer)
self._is_thread_op, channel_size, build_dag_each_worker,
self._tracer)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -568,11 +568,9 @@ class DAG(object):
op.use_profiler(self._use_profile)
op.set_tracer(self._tracer)
if self._is_thread_op:
self._threads_or_proces.extend(
op.start_with_thread())
self._threads_or_proces.extend(op.start_with_thread())
else:
self._threads_or_proces.extend(
op.start_with_process())
self._threads_or_proces.extend(op.start_with_process())
_LOGGER.info("[DAG] start")
# not join yet
......
......@@ -143,7 +143,8 @@ class LocalPredictorServiceHandler(LocalRpcServiceHandler):
else:
gpu = True
self.predictor = Debugger()
self.predictor.load_model_config(model_path=self._model_config, gpu=gpu, profile=False, cpu_num=1)
self.predictor.load_model_config(
model_path=self._model_config, gpu=gpu, profile=False, cpu_num=1)
def get_client(self):
if self.predictor is None:
......
......@@ -152,7 +152,8 @@ class Op(object):
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
self._fetch_names = service_handler.get_fetch_list(
)
elif self.client_type == "local_predictor":
service_handler = local_rpc_service_handler.LocalPredictorServiceHandler(
model_config=model_config,
......@@ -165,7 +166,8 @@ class Op(object):
self._client_config = service_handler.get_client_config(
)
if self._fetch_names is None:
self._fetch_names = service_handler.get_fetch_list()
self._fetch_names = service_handler.get_fetch_list(
)
self._local_rpc_service_handler = service_handler
else:
self.with_serving = True
......@@ -230,8 +232,7 @@ class Op(object):
def set_tracer(self, tracer):
self._tracer = tracer
def init_client(self, client_config, server_endpoints,
fetch_names):
def init_client(self, client_config, server_endpoints, fetch_names):
if self.with_serving == False:
_LOGGER.info("Op({}) has no client (and it also do not "
"run the process function)".format(self.name))
......@@ -319,7 +320,10 @@ class Op(object):
"preprocess func.".format(err_info)))
os._exit(-1)
if self.client_type == "local_predictor":
call_result = self.client.predict(feed=feed_batch[0], fetch=self._fetch_names, log_id=typical_logid)
call_result = self.client.predict(
feed=feed_batch[0],
fetch=self._fetch_names,
log_id=typical_logid)
else:
call_result = self.client.predict(
feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
......@@ -374,13 +378,12 @@ class Op(object):
trace_buffer = None
if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
process= []
process = []
for concurrency_idx in range(self.concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), False,
trace_buffer))
self._get_output_channels(), False, trace_buffer))
p.daemon = True
p.start()
process.append(p)
......@@ -395,8 +398,7 @@ class Op(object):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), True,
trace_buffer))
self._get_output_channels(), True, trace_buffer))
# When a process exits, it attempts to terminate
# all of its daemonic child processes.
t.daemon = True
......@@ -683,8 +685,7 @@ class Op(object):
# init op
profiler = None
try:
profiler = self._initialize(is_thread_op,
concurrency_idx)
profiler = self._initialize(is_thread_op, concurrency_idx)
except Exception as e:
_LOGGER.critical(
"{} Failed to init op: {}".format(op_info_prefix, e),
......@@ -831,9 +832,9 @@ class Op(object):
# for the threaded version of Op, each thread cannot get its concurrency_idx
self.concurrency_idx = None
# init client
self.client = self.init_client(
self._client_config,
self._server_endpoints, self._fetch_names)
self.client = self.init_client(self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
......@@ -841,9 +842,8 @@ class Op(object):
else:
self.concurrency_idx = concurrency_idx
# init client
self.client = self.init_client(self._client_config,
self._server_endpoints,
self._fetch_names)
self.client = self.init_client(
self._client_config, self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
......
......@@ -157,7 +157,7 @@ function python_test_fit_a_line() {
cd fit_a_line # pwd: /Serving/python/examples/fit_a_line
sh get_data.sh
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
case $TYPE in
CPU)
# test rpc
......@@ -166,25 +166,6 @@ function python_test_fit_a_line() {
check_cmd "python test_client.py uci_housing_client/serving_client_conf.prototxt > /dev/null"
kill_server_process
# test web
unsetproxy # the proxy may be set on iPipe, which would make the web test fail.
check_cmd "python -m paddle_serving_server.serve --model uci_housing_model --name uci --port 9393 --thread 4 --name uci > /dev/null &"
sleep 5 # wait for the server to start
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction"
# check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction`
if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200"
exit 1
fi
# test web batch
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction"
# check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction`
if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200"
exit 1
fi
setproxy # recover proxy state
kill_server_process
;;
......@@ -234,7 +215,7 @@ function python_run_criteo_ctr_with_cube() {
local TYPE=$1
yum install -y bc >/dev/null
cd criteo_ctr_with_cube # pwd: /Serving/python/examples/criteo_ctr_with_cube
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
case $TYPE in
CPU)
check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz"
......@@ -301,7 +282,7 @@ function python_run_criteo_ctr_with_cube() {
function python_test_bert() {
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd bert # pwd: /Serving/python/examples/bert
case $TYPE in
CPU)
......@@ -342,7 +323,7 @@ function python_test_bert() {
function python_test_multi_fetch() {
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd bert # pwd: /Serving/python/examples/bert
case $TYPE in
CPU)
......@@ -378,7 +359,7 @@ function python_test_multi_fetch() {
function python_test_multi_process(){
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd fit_a_line # pwd: /Serving/python/examples/fit_a_line
sh get_data.sh
case $TYPE in
......@@ -412,7 +393,7 @@ function python_test_multi_process(){
function python_test_imdb() {
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd imdb # pwd: /Serving/python/examples/imdb
case $TYPE in
CPU)
......@@ -428,19 +409,19 @@ function python_test_imdb() {
sleep 5
unsetproxy # the proxy may be set on iPipe, which would make the web test fail.
check_cmd "python text_classify_service.py imdb_cnn_model/ workdir/ 9292 imdb.vocab &"
sleep 5
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"i am very sad | 0\"}], \"fetch\":[\"prediction\"]}' http://127.0.0.1:9292/imdb/prediction"
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/imdb/prediction`
if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200"
exit 1
fi
#check_cmd "python text_classify_service.py imdb_cnn_model/ workdir/ 9292 imdb.vocab &"
#sleep 5
#check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"i am very sad | 0\"}], \"fetch\":[\"prediction\"]}' http://127.0.0.1:9292/imdb/prediction"
#http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/imdb/prediction`
#if [ ${http_code} -ne 200 ]; then
# echo "HTTP status code -ne 200"
# exit 1
#fi
# test batch predict
check_cmd "python benchmark.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request http --endpoint 127.0.0.1:9292"
setproxy # recover proxy state
kill_server_process
ps -ef | grep "text_classify_service.py" | grep -v grep | awk '{print $2}' | xargs kill
#check_cmd "python benchmark.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request http --endpoint 127.0.0.1:9292"
#setproxy # recover proxy state
#kill_server_process
#ps -ef | grep "text_classify_service.py" | grep -v grep | awk '{print $2}' | xargs kill
echo "imdb CPU HTTP inference pass"
;;
GPU)
......@@ -459,7 +440,7 @@ function python_test_imdb() {
function python_test_lac() {
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd lac # pwd: /Serving/python/examples/lac
case $TYPE in
CPU)
......@@ -472,23 +453,23 @@ function python_test_lac() {
kill_server_process
unsetproxy # the proxy may be set on iPipe, which would make the web test fail.
check_cmd "python lac_web_service.py lac_model/ lac_workdir 9292 &"
sleep 5
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"我爱北京天安门\"}], \"fetch\":[\"word_seg\"]}' http://127.0.0.1:9292/lac/prediction"
#check_cmd "python lac_web_service.py lac_model/ lac_workdir 9292 &"
#sleep 5
#check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"我爱北京天安门\"}], \"fetch\":[\"word_seg\"]}' http://127.0.0.1:9292/lac/prediction"
# check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/lac/prediction`
if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200"
exit 1
fi
#http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/lac/prediction`
#if [ ${http_code} -ne 200 ]; then
# echo "HTTP status code -ne 200"
# exit 1
#fi
# http batch
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"我爱北京天安门\"}, {\"words\": \"我爱北京天安门\"}], \"fetch\":[\"word_seg\"]}' http://127.0.0.1:9292/lac/prediction"
#check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"我爱北京天安门\"}, {\"words\": \"我爱北京天安门\"}], \"fetch\":[\"word_seg\"]}' http://127.0.0.1:9292/lac/prediction"
# check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}, {"words": "我爱北京天安门"}], "fetch":["word_seg"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/lac/prediction`
if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200"
exit 1
fi
#http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}, {"words": "我爱北京天安门"}], "fetch":["word_seg"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9292/lac/prediction`
#if [ ${http_code} -ne 200 ]; then
# echo "HTTP status code -ne 200"
# exit 1
#fi
setproxy # recover proxy state
kill_server_process
ps -ef | grep "lac_web_service" | grep -v grep | awk '{print $2}' | xargs kill
......@@ -511,7 +492,7 @@ function python_test_lac() {
function java_run_test() {
# pwd: /Serving
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
unsetproxy
case $TYPE in
CPU)
......@@ -570,7 +551,7 @@ function python_test_grpc_impl() {
# pwd: /Serving/python/examples
cd grpc_impl_example # pwd: /Serving/python/examples/grpc_impl_example
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
unsetproxy
case $TYPE in
CPU)
......@@ -710,7 +691,7 @@ function python_test_grpc_impl() {
function python_test_yolov4(){
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd yolov4
case $TYPE in
CPU)
......@@ -738,7 +719,7 @@ function python_test_yolov4(){
function python_test_resnet50(){
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
cd imagenet
case $TYPE in
CPU)
......@@ -765,7 +746,7 @@ function python_test_resnet50(){
function python_test_pipeline(){
# pwd: /Serving/python/examples
local TYPE=$1
#export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
unsetproxy
cd pipeline # pwd: /Serving/python/examples/pipeline
case $TYPE in
......@@ -945,14 +926,14 @@ function python_run_test() {
cd python/examples # pwd: /Serving/python/examples
python_test_fit_a_line $TYPE # pwd: /Serving/python/examples
#python_run_criteo_ctr_with_cube $TYPE # pwd: /Serving/python/examples
#python_test_bert $TYPE # pwd: /Serving/python/examples
python_test_bert $TYPE # pwd: /Serving/python/examples
#python_test_imdb $TYPE # pwd: /Serving/python/examples
#python_test_lac $TYPE # pwd: /Serving/python/examples
#python_test_multi_process $TYPE # pwd: /Serving/python/examples
#python_test_multi_fetch $TYPE # pwd: /Serving/python/examples
#python_test_yolov4 $TYPE # pwd: /Serving/python/examples
python_test_lac $TYPE # pwd: /Serving/python/examples
python_test_multi_process $TYPE # pwd: /Serving/python/examples
python_test_multi_fetch $TYPE # pwd: /Serving/python/examples
python_test_yolov4 $TYPE # pwd: /Serving/python/examples
#python_test_grpc_impl $TYPE # pwd: /Serving/python/examples
#python_test_resnet50 $TYPE # pwd: /Serving/python/examples
python_test_resnet50 $TYPE # pwd: /Serving/python/examples
#python_test_pipeline $TYPE # pwd: /Serving/python/examples
echo "test python $TYPE part finished as expected."
cd ../.. # pwd: /Serving
......@@ -1092,10 +1073,10 @@ function monitor_test() {
function main() {
local TYPE=$1 # pwd: /
#init # pwd: /Serving
#build_client $TYPE # pwd: /Serving
#build_server $TYPE # pwd: /Serving
#build_app $TYPE # pwd: /Serving
init # pwd: /Serving
build_client $TYPE # pwd: /Serving
build_server $TYPE # pwd: /Serving
build_app $TYPE # pwd: /Serving
#java_run_test $TYPE # pwd: /Serving
python_run_test $TYPE # pwd: /Serving
monitor_test $TYPE # pwd: /Serving
......