diff --git a/README.md b/README.md index 9d1ec854ba67d220a481816cda5eeebf2bc89739..17730e2a071facf7c939cb7fb686596b2b752aa6 100644 --- a/README.md +++ b/README.md @@ -264,8 +264,8 @@ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"url": "https://pa ### About Efficiency - [How to profile Paddle Serving latency?](python/examples/util) -- [How to optimize performance?(Chinese)](doc/MULTI_SERVICE_ON_ONE_GPU_CN.md) -- [Deploy multi-services on one GPU(Chinese)](doc/PERFORMANCE_OPTIM_CN.md) +- [How to optimize performance?(Chinese)](doc/PERFORMANCE_OPTIM_CN.md) +- [Deploy multi-services on one GPU(Chinese)](doc/MULTI_SERVICE_ON_ONE_GPU_CN.md) - [CPU Benchmarks(Chinese)](doc/BENCHMARKING.md) - [GPU Benchmarks(Chinese)](doc/GPU_BENCHMARKING.md) diff --git a/README_CN.md b/README_CN.md index 0c30ef0cffea7d2940c544c55b641255108908fd..3302d4850e8255e8d2d6460c201892fd6035b260 100644 --- a/README_CN.md +++ b/README_CN.md @@ -270,8 +270,8 @@ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"url": "https://pa ### 关于Paddle Serving性能 - [如何测试Paddle Serving性能?](python/examples/util/) -- [如何优化性能?](doc/MULTI_SERVICE_ON_ONE_GPU_CN.md) -- [在一张GPU上启动多个预测服务](doc/PERFORMANCE_OPTIM_CN.md) +- [如何优化性能?](doc/PERFORMANCE_OPTIM_CN.md) +- [在一张GPU上启动多个预测服务](doc/MULTI_SERVICE_ON_ONE_GPU_CN.md) - [CPU版Benchmarks](doc/BENCHMARKING.md) - [GPU版Benchmarks](doc/GPU_BENCHMARKING.md) diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp index 8695da2591a30725d5b2390ad287f9ceae40052b..7d48949b22d0ace289ab3b9214f092819f5476e0 100644 --- a/core/general-server/op/general_reader_op.cpp +++ b/core/general-server/op/general_reader_op.cpp @@ -131,7 +131,7 @@ int GeneralReaderOp::inference() { lod_tensor.dtype = paddle::PaddleDType::FLOAT32; } - if (req->insts(0).tensor_array(i).shape(0) == -1) { + if (model_config->_is_lod_feed[i]) { lod_tensor.lod.resize(1); lod_tensor.lod[0].push_back(0); VLOG(2) << "var[" << i << "] is lod_tensor"; @@ -153,6 +153,7 @@ int GeneralReaderOp::inference() { // specify the memory needed for output tensor_vector for (int i = 0; i < var_num; ++i) { if (out->at(i).lod.size() == 1) { + int tensor_size = 0; for (int j = 0; j < batch_size; ++j) { const Tensor &tensor = req->insts(j).tensor_array(i); int data_len = 0; @@ -162,15 +163,28 @@ int GeneralReaderOp::inference() { data_len = tensor.float_data_size(); } VLOG(2) << "tensor size for var[" << i << "]: " << data_len; + tensor_size += data_len; int cur_len = out->at(i).lod[0].back(); VLOG(2) << "current len: " << cur_len; - out->at(i).lod[0].push_back(cur_len + data_len); - VLOG(2) << "new len: " << cur_len + data_len; + int sample_len = 0; + if (tensor.shape_size() == 1) { + sample_len = data_len; + } else { + sample_len = tensor.shape(0); + } + out->at(i).lod[0].push_back(cur_len + sample_len); + VLOG(2) << "new len: " << cur_len + sample_len; + } + out->at(i).data.Resize(tensor_size * elem_size[i]); + out->at(i).shape = {out->at(i).lod[0].back()}; + for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) { + out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j)); + } + if (out->at(i).shape.size() == 1) { + out->at(i).shape.push_back(1); } - out->at(i).data.Resize(out->at(i).lod[0].back() * elem_size[i]); - out->at(i).shape = {out->at(i).lod[0].back(), 1}; VLOG(2) << "var[" << i << "] is lod_tensor and len=" << out->at(i).lod[0].back(); } else { diff --git a/core/general-server/op/general_response_op.cpp b/core/general-server/op/general_response_op.cpp index 4d853f88eef88716c498b2b95c1498f1abdeb3d0..126accfd0a406f420f57eef4e04268e9081c744f 100644 --- a/core/general-server/op/general_response_op.cpp +++ b/core/general-server/op/general_response_op.cpp @@ -15,8 +15,10 @@ #include "core/general-server/op/general_response_op.h" #include #include +#include #include #include +#include #include "core/general-server/op/general_infer_helper.h" #include "core/predictor/framework/infer.h" #include "core/predictor/framework/memory.h" @@ -86,37 +88,51 @@ int GeneralResponseOp::inference() { // To get the order of model return values output->set_engine_name(pre_name); FetchInst *fetch_inst = output->add_insts(); + + std::map fetch_index_map; + for (int i = 0; i < in->size(); ++i) { + VLOG(2) << "index " << i << " var " << in->at(i).name; + fetch_index_map.insert(std::pair(in->at(i).name, i)); + } + for (auto &idx : fetch_index) { Tensor *tensor = fetch_inst->add_tensor_array(); tensor->set_elem_type(1); + int true_idx = fetch_index_map[model_config->_fetch_name[idx]]; if (model_config->_is_lod_fetch[idx]) { - VLOG(2) << "out[" << idx << "] is lod_tensor"; - for (int k = 0; k < in->at(idx).shape.size(); ++k) { + VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx] + << " is lod_tensor"; + for (int k = 0; k < in->at(true_idx).shape.size(); ++k) { VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k]; - tensor->add_shape(in->at(idx).shape[k]); + tensor->add_shape(in->at(true_idx).shape[k]); } } else { - VLOG(2) << "out[" << idx << "] is tensor"; - for (int k = 0; k < in->at(idx).shape.size(); ++k) { - VLOG(2) << "shape[" << k << "]: " << in->at(idx).shape[k]; - tensor->add_shape(in->at(idx).shape[k]); + VLOG(2) << "out[" << idx << "] " << model_config->_fetch_name[idx] + << " is tensor"; + for (int k = 0; k < in->at(true_idx).shape.size(); ++k) { + VLOG(2) << "shape[" << k << "]: " << in->at(true_idx).shape[k]; + tensor->add_shape(in->at(true_idx).shape[k]); } } } int var_idx = 0; for (auto &idx : fetch_index) { + int true_idx = fetch_index_map[model_config->_fetch_name[idx]]; int cap = 1; - for (int j = 0; j < in->at(idx).shape.size(); ++j) { - cap *= in->at(idx).shape[j]; + for (int j = 0; j < in->at(true_idx).shape.size(); ++j) { + cap *= in->at(true_idx).shape[j]; } - if (in->at(idx).dtype == paddle::PaddleDType::INT64) { - int64_t *data_ptr = static_cast(in->at(idx).data.data()); + if (in->at(true_idx).dtype == paddle::PaddleDType::INT64) { + VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx] + << "]."; + int64_t *data_ptr = + static_cast(in->at(true_idx).data.data()); if (model_config->_is_lod_fetch[idx]) { FetchInst *fetch_p = output->mutable_insts(0); - for (int j = 0; j < in->at(idx).lod[0].size(); ++j) { + for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) { fetch_p->mutable_tensor_array(var_idx)->add_lod( - in->at(idx).lod[0][j]); + in->at(true_idx).lod[0][j]); } for (int j = 0; j < cap; ++j) { fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]); @@ -127,14 +143,17 @@ int GeneralResponseOp::inference() { fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]); } } + VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready"; var_idx++; - } else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) { - float *data_ptr = static_cast(in->at(idx).data.data()); + } else if (in->at(true_idx).dtype == paddle::PaddleDType::FLOAT32) { + VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx] + << "]."; + float *data_ptr = static_cast(in->at(true_idx).data.data()); if (model_config->_is_lod_fetch[idx]) { FetchInst *fetch_p = output->mutable_insts(0); - for (int j = 0; j < in->at(idx).lod[0].size(); ++j) { + for (int j = 0; j < in->at(true_idx).lod[0].size(); ++j) { fetch_p->mutable_tensor_array(var_idx)->add_lod( - in->at(idx).lod[0][j]); + in->at(true_idx).lod[0][j]); } for (int j = 0; j < cap; ++j) { fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]); @@ -145,6 +164,7 @@ int GeneralResponseOp::inference() { fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]); } } + VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready"; var_idx++; } } diff --git a/doc/BERT_10_MINS_CN.md b/doc/BERT_10_MINS_CN.md index 17592000f016f1f1e939e8f3dc6dab6e05f35fe7..b7a5180da1bae2dafc431251f2b98c8a2041856a 100644 --- a/doc/BERT_10_MINS_CN.md +++ b/doc/BERT_10_MINS_CN.md @@ -13,10 +13,10 @@ import paddlehub as hub model_name = "bert_chinese_L-12_H-768_A-12" module = hub.Module(model_name) inputs, outputs, program = module.context(trainable=True, max_seq_len=20) -feed_keys = ["input_ids", "position_ids", "segment_ids", "input_mask", "pooled_output", "sequence_output"] +feed_keys = ["input_ids", "position_ids", "segment_ids", "input_mask"] fetch_keys = ["pooled_output", "sequence_output"] feed_dict = dict(zip(feed_keys, [inputs[x] for x in feed_keys])) -fetch_dict = dict(zip(fetch_keys, [outputs[x]] for x in fetch_keys)) +fetch_dict = dict(zip(fetch_keys, [outputs[x] for x in fetch_keys])) import paddle_serving_client.io as serving_io serving_io.save_model("bert_seq20_model", "bert_seq20_client", feed_dict, fetch_dict, program) diff --git a/doc/SAVE.md b/doc/SAVE.md index 3f7f97e12e1e309ff0933e150ea7bcd23298b60e..4fcdfa438574fac7de21c963f5bb173c69261210 100644 --- a/doc/SAVE.md +++ b/doc/SAVE.md @@ -10,8 +10,9 @@ serving_io.save_model("imdb_model", "imdb_client_conf", {"words": data}, {"prediction": prediction}, fluid.default_main_program()) ``` -`imdb_model` is the server side model with serving configurations. `imdb_client_conf` is the client rpc configurations. Serving has a -dictionary for `Feed` and `Fetch` variables for client to assign. In the example, `{"words": data}` is the feed dict that specify the input of saved inference model. `{"prediction": prediction}` is the fetch dic that specify the output of saved inference model. An alias name can be defined for feed and fetch variables. An example of how to use alias name +`imdb_model` is the server side model with serving configurations. `imdb_client_conf` is the client rpc configurations. + +Serving has a dictionary for `Feed` and `Fetch` variables for client to assign. In the example, `{"words": data}` is the feed dict that specify the input of saved inference model. `{"prediction": prediction}` is the fetch dic that specify the output of saved inference model. An alias name can be defined for feed and fetch variables. An example of how to use alias name is as follows: ``` python from paddle_serving_client import Client @@ -35,10 +36,14 @@ for line in sys.stdin: If you have saved model files using Paddle's `save_inference_model` API, you can use Paddle Serving's` inference_model_to_serving` API to convert it into a model file that can be used for Paddle Serving. ``` import paddle_serving_client.io as serving_io -serving_io.inference_model_to_serving(dirname, model_filename=None, params_filename=None, serving_server="serving_server", serving_client="serving_client") +serving_io.inference_model_to_serving(dirname, serving_server="serving_server", serving_client="serving_client", model_filename=None, params_filename=None ) ``` dirname (str) - Path of saved model files. Program file and parameter files are saved in this directory. -model_filename (str, optional) - The name of file to load the inference program. If it is None, the default filename __model__ will be used. Default: None. -paras_filename (str, optional) - The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None. + serving_server (str, optional) - The path of model files and configuration files for server. Default: "serving_server". + serving_client (str, optional) - The path of configuration files for client. Default: "serving_client". + +model_filename (str, optional) - The name of file to load the inference program. If it is None, the default filename `__model__` will be used. Default: None. + +paras_filename (str, optional) - The name of file to load all parameters. It is only used for the case that all parameters were saved in a single binary file. If parameters were saved in separate files, set it as None. Default: None. diff --git a/doc/SAVE_CN.md b/doc/SAVE_CN.md index fc75cd8d015a6d6f42a08f29e4035db20f450d91..3ca715c024a38b6fdce5c973844e7d023eebffcc 100644 --- a/doc/SAVE_CN.md +++ b/doc/SAVE_CN.md @@ -11,7 +11,9 @@ serving_io.save_model("imdb_model", "imdb_client_conf", {"words": data}, {"prediction": prediction}, fluid.default_main_program()) ``` -imdb_model是具有服务配置的服务器端模型。 imdb_client_conf是客户端rpc配置。 Serving有一个 提供给用户存放Feed和Fetch变量信息的字典。 在示例中,`{words”:data}` 是用于指定已保存推理模型输入的提要字典。`{"prediction":projection}`是指定保存的推理模型输出的字典。可以为feed和fetch变量定义一个别名。 如何使用别名的例子 示例如下: +imdb_model是具有服务配置的服务器端模型。 imdb_client_conf是客户端rpc配置。 + +Serving有一个提供给用户存放Feed和Fetch变量信息的字典。 在示例中,`{"words":data}` 是用于指定已保存推理模型输入的提要字典。`{"prediction":projection}`是指定保存的推理模型输出的字典。可以为feed和fetch变量定义一个别名。 如何使用别名的例子 示例如下: ``` python from paddle_serving_client import Client @@ -35,10 +37,14 @@ for line in sys.stdin: 如果已使用Paddle 的`save_inference_model`接口保存出预测要使用的模型,则可以通过Paddle Serving的`inference_model_to_serving`接口转换成可用于Paddle Serving的模型文件。 ``` import paddle_serving_client.io as serving_io -serving_io.inference_model_to_serving(dirname, model_filename=None, params_filename=None, serving_server="serving_server", serving_client="serving_client") +serving_io.inference_model_to_serving(dirname, serving_server="serving_server", serving_client="serving_client", model_filename=None, params_filename=None) ``` dirname (str) – 需要转换的模型文件存储路径,Program结构文件和参数文件均保存在此目录。 -model_filename (str,可选) – 存储需要转换的模型Inference Program结构的文件名称。如果设置为None,则使用 __model__ 作为默认的文件名。默认值为None。 + +serving_server (str, 可选) - 转换后的模型文件和配置文件的存储路径。默认值为serving_server。 + +serving_client (str, 可选) - 转换后的客户端配置文件存储路径。默认值为serving_client。 + +model_filename (str,可选) – 存储需要转换的模型Inference Program结构的文件名称。如果设置为None,则使用 `__model__` 作为默认的文件名。默认值为None。 + params_filename (str,可选) – 存储需要转换的模型所有参数的文件名称。当且仅当所有模型参数被保存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None。默认值为None。 -serving_server (str, 可选) - 转换后的模型文件和配置文件的存储路径。默认值为"serving_server"。 -serving_client (str, 可选) - 转换后的客户端配置文件存储路径。默认值为"serving_client"。 diff --git a/python/examples/bert/bert_web_service.py b/python/examples/bert/bert_web_service.py index 8db64e5eb792a7365ed739bbfb05bf38fd8a0da1..6a5830ea179b033f9f761010d8cf9213d9b1e40b 100644 --- a/python/examples/bert/bert_web_service.py +++ b/python/examples/bert/bert_web_service.py @@ -23,10 +23,10 @@ class BertService(WebService): def load(self): self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=128) - def preprocess(self, feed={}, fetch=[]): - feed_res = [{ - "words": self.reader.process(ins["words"].encode("utf-8")) - } for ins in feed] + def preprocess(self, feed=[], fetch=[]): + feed_res = [ + self.reader.process(ins["words"].encode("utf-8")) for ins in feed + ] return feed_res, fetch diff --git a/python/examples/imagenet/README.md b/python/examples/imagenet/README.md index 536440e73ea43f55a4c93bf126d62e86aa3983e6..415818e715e22e97399c710a61f2463fd166bd19 100644 --- a/python/examples/imagenet/README.md +++ b/python/examples/imagenet/README.md @@ -44,6 +44,6 @@ python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9696 client send inference request ``` -python image_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt +python resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt ``` *the port of server side in this example is 9696 diff --git a/python/examples/imagenet/README_CN.md b/python/examples/imagenet/README_CN.md index c34ccca32b737467e687dfd5e86c3229f4339075..77ade579ba17ad8247b2f118242642a1d3c79927 100644 --- a/python/examples/imagenet/README_CN.md +++ b/python/examples/imagenet/README_CN.md @@ -44,6 +44,6 @@ python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9696 client端进行预测 ``` -python image_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt +python resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt ``` *server端示例中服务端口为9696端口 diff --git a/python/examples/imagenet/benchmark.py b/python/examples/imagenet/benchmark.py index ece222f74c52614100a119e49c3754e22959b7c8..6b21719e7b665906e7abd02a7a3b8aef50136685 100644 --- a/python/examples/imagenet/benchmark.py +++ b/python/examples/imagenet/benchmark.py @@ -39,8 +39,8 @@ def single_func(idx, resource): client.connect([resource["endpoint"][idx % len(resource["endpoint"])]]) start = time.time() - for i in range(1000): - img = reader.process_image(img_list[i]).reshape(-1) + for i in range(100): + img = reader.process_image(img_list[i]) fetch_map = client.predict(feed={"image": img}, fetch=["score"]) end = time.time() return [[end - start]] @@ -49,7 +49,7 @@ def single_func(idx, resource): if __name__ == "__main__": multi_thread_runner = MultiThreadRunner() - endpoint_list = ["127.0.0.1:9393"] + endpoint_list = ["127.0.0.1:9292"] #card_num = 4 #for i in range(args.thread): # endpoint_list.append("127.0.0.1:{}".format(9295 + i % card_num)) diff --git a/python/examples/imagenet/benchmark_batch.py b/python/examples/imagenet/benchmark_batch.py index e531425770cbf9102b7ebd2f5b082c5c4aa14e71..1646fb9a94d6953f90f9f4907aa74940f13c2730 100644 --- a/python/examples/imagenet/benchmark_batch.py +++ b/python/examples/imagenet/benchmark_batch.py @@ -24,6 +24,7 @@ from paddle_serving_client.utils import MultiThreadRunner from paddle_serving_client.utils import benchmark_args import requests import json +import base64 from image_reader import ImageReader args = benchmark_args() @@ -36,6 +37,10 @@ def single_func(idx, resource): img_list = [] for i in range(1000): img_list.append(open("./image_data/n01440764/" + file_list[i]).read()) + profile_flags = False + if "FLAGS_profile_client" in os.environ and os.environ[ + "FLAGS_profile_client"]: + profile_flags = True if args.request == "rpc": reader = ImageReader() fetch = ["score"] @@ -46,23 +51,43 @@ def single_func(idx, resource): for i in range(1000): if args.batch_size >= 1: feed_batch = [] + i_start = time.time() for bi in range(args.batch_size): img = reader.process_image(img_list[i]) - img = img.reshape(-1) feed_batch.append({"image": img}) + i_end = time.time() + if profile_flags: + print("PROFILE\tpid:{}\timage_pre_0:{} image_pre_1:{}". + format(os.getpid(), + int(round(i_start * 1000000)), + int(round(i_end * 1000000)))) + result = client.predict(feed=feed_batch, fetch=fetch) else: print("unsupport batch size {}".format(args.batch_size)) elif args.request == "http": - raise ("no batch predict for http") + py_version = 2 + server = "http://" + resource["endpoint"][idx % len(resource[ + "endpoint"])] + "/image/prediction" + start = time.time() + for i in range(1000): + if py_version == 2: + image = base64.b64encode( + open("./image_data/n01440764/" + file_list[i]).read()) + else: + image = base64.b64encode(open(image_path, "rb").read()).decode( + "utf-8") + req = json.dumps({"feed": [{"image": image}], "fetch": ["score"]}) + r = requests.post( + server, data=req, headers={"Content-Type": "application/json"}) end = time.time() return [[end - start]] if __name__ == '__main__': multi_thread_runner = MultiThreadRunner() - endpoint_list = ["127.0.0.1:9393"] + endpoint_list = ["127.0.0.1:9292"] #endpoint_list = endpoint_list + endpoint_list + endpoint_list result = multi_thread_runner.run(single_func, args.thread, {"endpoint": endpoint_list}) diff --git a/python/examples/imdb/benchmark.py b/python/examples/imdb/benchmark.py index a734e80ef78a7710ca09a211132e248580c5a48c..b8d7a70f30c5cf2d0ee985a8c30fada8fa9481b3 100644 --- a/python/examples/imdb/benchmark.py +++ b/python/examples/imdb/benchmark.py @@ -16,7 +16,7 @@ import sys import time import requests -from imdb_reader import IMDBDataset +from paddle_serving_app import IMDBDataset from paddle_serving_client import Client from paddle_serving_client.utils import MultiThreadRunner from paddle_serving_client.utils import benchmark_args @@ -37,26 +37,39 @@ def single_func(idx, resource): client.load_client_config(args.model) client.connect([args.endpoint]) for i in range(1000): - if args.batch_size == 1: - word_ids, label = imdb_dataset.get_words_and_label(line) - fetch_map = client.predict( - feed={"words": word_ids}, fetch=["prediction"]) + if args.batch_size >= 1: + feed_batch = [] + for bi in range(args.batch_size): + word_ids, label = imdb_dataset.get_words_and_label(dataset[ + bi]) + feed_batch.append({"words": word_ids}) + result = client.predict(feed=feed_batch, fetch=["prediction"]) + if result is None: + raise ("predict failed.") else: print("unsupport batch size {}".format(args.batch_size)) elif args.request == "http": - for fn in filelist: - fin = open(fn) - for line in fin: - word_ids, label = imdb_dataset.get_words_and_label(line) - r = requests.post( - "http://{}/imdb/prediction".format(args.endpoint), - data={"words": word_ids, - "fetch": ["prediction"]}) + if args.batch_size >= 1: + feed_batch = [] + for bi in range(args.batch_size): + feed_batch.append({"words": dataset[bi]}) + r = requests.post( + "http://{}/imdb/prediction".format(args.endpoint), + json={"feed": feed_batch, + "fetch": ["prediction"]}) + if r.status_code != 200: + print('HTTP status code -ne 200') + raise ("predict failed.") + else: + print("unsupport batch size {}".format(args.batch_size)) end = time.time() return [[end - start]] multi_thread_runner = MultiThreadRunner() result = multi_thread_runner.run(single_func, args.thread, {}) -print(result) +avg_cost = 0 +for cost in result[0]: + avg_cost += cost +print("total cost {} s of each thread".format(avg_cost / args.thread)) diff --git a/python/examples/imdb/benchmark.sh b/python/examples/imdb/benchmark.sh index d77e184180d5c36de6cb865f6b9797511410a3ba..93dbf830c84bd38f72dd0d8a32139ad6098dc6f8 100644 --- a/python/examples/imdb/benchmark.sh +++ b/python/examples/imdb/benchmark.sh @@ -1,9 +1,12 @@ rm profile_log for thread_num in 1 2 4 8 16 do - $PYTHONROOT/bin/python benchmark.py --thread $thread_num --model imdbo_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 +for batch_size in 1 2 4 8 16 32 64 128 256 512 +do + $PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 echo "========================================" echo "batch size : $batch_size" >> profile_log $PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log tail -n 1 profile >> profile_log done +done diff --git a/python/examples/imdb/benchmark_batch.py b/python/examples/imdb/benchmark_batch.py deleted file mode 100644 index 5891970b5decc34f35723187e44b166e0482c6e9..0000000000000000000000000000000000000000 --- a/python/examples/imdb/benchmark_batch.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# pylint: disable=doc-string-missing - -import sys -import time -import requests -from imdb_reader import IMDBDataset -from paddle_serving_client import Client -from paddle_serving_client.utils import MultiThreadRunner -from paddle_serving_client.utils import benchmark_args - -args = benchmark_args() - - -def single_func(idx, resource): - imdb_dataset = IMDBDataset() - imdb_dataset.load_resource("./imdb.vocab") - dataset = [] - with open("./test_data/part-0") as fin: - for line in fin: - dataset.append(line.strip()) - start = time.time() - if args.request == "rpc": - client = Client() - client.load_client_config(args.model) - client.connect([args.endpoint]) - for i in range(1000): - if args.batch_size >= 1: - feed_batch = [] - for bi in range(args.batch_size): - word_ids, label = imdb_dataset.get_words_and_label(dataset[ - bi]) - feed_batch.append({"words": word_ids}) - result = client.predict(feed=feed_batch, fetch=["prediction"]) - if result is None: - raise ("predict failed.") - else: - print("unsupport batch size {}".format(args.batch_size)) - - elif args.request == "http": - if args.batch_size >= 1: - feed_batch = [] - for bi in range(args.batch_size): - feed_batch.append({"words": dataset[bi]}) - r = requests.post( - "http://{}/imdb/prediction".format(args.endpoint), - json={"feed": feed_batch, - "fetch": ["prediction"]}) - if r.status_code != 200: - print('HTTP status code -ne 200') - raise ("predict failed.") - else: - print("unsupport batch size {}".format(args.batch_size)) - end = time.time() - return [[end - start]] - - -multi_thread_runner = MultiThreadRunner() -result = multi_thread_runner.run(single_func, args.thread, {}) -avg_cost = 0 -for cost in result[0]: - avg_cost += cost -print("total cost {} s of each thread".format(avg_cost / args.thread)) diff --git a/python/examples/imdb/benchmark_batch.sh b/python/examples/imdb/benchmark_batch.sh deleted file mode 100644 index 15b65338b21675fd89056cf32f9a247b385a6a36..0000000000000000000000000000000000000000 --- a/python/examples/imdb/benchmark_batch.sh +++ /dev/null @@ -1,12 +0,0 @@ -rm profile_log -for thread_num in 1 2 4 8 16 -do -for batch_size in 1 2 4 8 16 32 64 128 256 512 -do - $PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1 - echo "========================================" - echo "batch size : $batch_size" >> profile_log - $PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log - tail -n 1 profile >> profile_log -done -done diff --git a/python/examples/imdb/test_client.py b/python/examples/imdb/test_client.py index fdc3ced25377487a2844d57c4e6121801e9fa7fa..74364e5854d223e380cb386f9a8bc68b8517305a 100644 --- a/python/examples/imdb/test_client.py +++ b/python/examples/imdb/test_client.py @@ -13,7 +13,7 @@ # limitations under the License. # pylint: disable=doc-string-missing from paddle_serving_client import Client -from imdb_reader import IMDBDataset +from paddle_serving_app import IMDBDataset import sys client = Client() diff --git a/python/examples/imdb/test_client_batch.py b/python/examples/imdb/test_client_batch.py deleted file mode 100644 index 972b2c9609ca690542fa802f187fb30ed0467a04..0000000000000000000000000000000000000000 --- a/python/examples/imdb/test_client_batch.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# pylint: disable=doc-string-missing - -from paddle_serving_client import Client -import sys -import subprocess -from multiprocessing import Pool -import time - - -def batch_predict(batch_size=4): - client = Client() - client.load_client_config(conf_file) - client.connect(["127.0.0.1:9292"]) - fetch = ["acc", "cost", "prediction"] - feed_batch = [] - for line in sys.stdin: - group = line.strip().split() - words = [int(x) for x in group[1:int(group[0])]] - label = [int(group[-1])] - feed = {"words": words, "label": label} - feed_batch.append(feed) - if len(feed_batch) == batch_size: - fetch_batch = client.batch_predict( - feed_batch=feed_batch, fetch=fetch) - for i in range(batch_size): - print("{} {}".format(fetch_batch[i]["prediction"][1], - feed_batch[i]["label"][0])) - feed_batch = [] - if len(feed_batch) > 0: - fetch_batch = client.batch_predict(feed_batch=feed_batch, fetch=fetch) - for i in range(len(feed_batch)): - print("{} {}".format(fetch_batch[i]["prediction"][1], feed_batch[i][ - "label"][0])) - - -if __name__ == '__main__': - conf_file = sys.argv[1] - batch_size = int(sys.argv[2]) - batch_predict(batch_size) diff --git a/python/examples/imdb/text_classify_service.py b/python/examples/imdb/text_classify_service.py index 4420a99facc7bd3db1c8bf1df0c58765467517de..ae54b99030ee777ad127242d26c13cdbc05645e9 100755 --- a/python/examples/imdb/text_classify_service.py +++ b/python/examples/imdb/text_classify_service.py @@ -14,7 +14,7 @@ # pylint: disable=doc-string-missing from paddle_serving_server.web_service import WebService -from imdb_reader import IMDBDataset +from paddle_serving_app import IMDBDataset import sys diff --git a/python/paddle_serving_app/README.md b/python/paddle_serving_app/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1756b83993e67dcbc66b6809631c5e953eef08d7 --- /dev/null +++ b/python/paddle_serving_app/README.md @@ -0,0 +1,169 @@ +([简体中文](./README_CN.md)|English) + +paddle_serving_app is a tool component of the Paddle Serving framework, and includes functions such as pre-training model download and data pre-processing methods. +It is convenient for users to quickly test and deploy model examples, analyze the performance of prediction services, and debug model prediction services. + +## Install + +```shell +pip install paddle_serving_app +``` + +## Get model list + +```shell +python -m paddle_serving_app.package --model_list +``` + +## Download pre-training model + +```shell +python -m paddle_serving_app.package --get_model senta_bilstm +``` + +11 pre-trained models are built into paddle_serving_app, covering 6 kinds of prediction tasks. +The model files can be directly used for deployment, and the `--tutorial` argument can be added to obtain the deployment method. + +| Prediction task | Model name | +| ------------ | ------------------------------------------------ | +| SentimentAnalysis | 'senta_bilstm', 'senta_bow', 'senta_cnn' | +| SemanticRepresentation | 'ernie_base' | +| ChineseWordSegmentation | 'lac' | +| ObjectDetection | 'faster_rcnn', 'yolov3' | +| ImageSegmentation | 'unet', 'deeplabv3' | +| ImageClassification | 'resnet_v2_50_imagenet', 'mobilenet_v2_imagenet' | + +## Data preprocess API + +paddle_serving_app provides a variety of data preprocessing methods for prediction tasks in the field of CV and NLP. + +- class ChineseBertReader + +Preprocessing for Chinese semantic representation task. + + - `__init__(vocab_file, max_seq_len=20)` + + - vocab_file(st ):Path of dictionary file. + + - max_seq_len(in ,optional):The length of sample after processing. The excess part will be truncated, and the insufficient part will be padding 0. Default 20. + + - `process(line)` + + - line(st ):Text input. + + [example](../examples/bert/bert_client.py) + +- class LACReader + +Preprocessing for Chinese word segmentation task. + + - `__init__(dict_floder)` + - dict_floder(st )Path of dictionary file. + - `process(sent)` + - sent(st ):Text input. + - `parse_result` + - words(st ):Original text input. + - crf_decode(np.array):CRF code predicted by model. + + [example](../examples/bert/lac_web_service.py) + +- class SentaReader + + - `__init__(vocab_path)` + - vocab_path(st ):Path of dictionary file. + - `process(cols)` + - cols(st ):Word segmentation result. + + [example](../examples/senta/senta_web_service.py) + +- The image preprocessing method is more flexible than the above method, and can be combined by the following multiple classes,[example](../examples/imagenet/image_rpc_client.py) + +- class Sequentia + + - `__init__(transforms)` + - transforms(list):List of image preprocessing classes + - `__call__(img)` + - img:The input of image preprocessing. The data type is is related to the first preprocessing method in transforms. + +- class File2Image + + - `__call__(img_path)` + - img_path(str):Path of image file. + +- class URL2Image + + - `__call__(img_url)` + - img_url(str):url of image file. + +- class Normalize + + - `__init__(mean,std)` + - mean(float):Mean + - std(float):Variance + - `__call__(img)` + - img(np.array):Image data in (C,H,W) channels. + +- class CenterCrop + + - `__init__(size)` + - size(list/int): + - `__call__(img)` + - img(np.array):Image data. + +- class Resize + + - `__init__(size, max_size=2147483647, interpolation=None)` + - size(list/int):The expected image size, when the input is a list type, it needs to contain the expected length and width. When the input is int type, the short side will be set to the length of size, and the long side will be scaled proportionally. + - `__call__(img)` + - img(numpy array):Image data. + + +## Timeline tools + +The Timeline tool can be used to visualize the start and end time of various stages such as the preparation data of the prediction service, client wait and server op. +This tool is convenient to analyze the proportion of time occupancy in the prediction service. On this basis, prediction services can be optimized in a targeted manner. + +### How to use + +1. Before making predictions on the client side, turn on the timeline function of each stage in the Paddle Serving framework by environment variables. It will print timeline information in log. + + ```shell + export FLAGS_profile_client=1 # Turn on timeline function of client + export FLAGS_profile_server=1 # Turn on timeline function of server + ``` +2. Perform predictions and redirect client-side logs to files, for example, named as profile. + +3. Export the information in the log file into a trace file. + ```shell + python -m paddle_serving_app.trace --profile_file profile --trace_file trace + ``` + +4. Open the `chrome: // tracing /` URL using Chrome browser. +Load the trace file generated in the previous step through the load button, you can +Visualize the time information of each stage of the forecast service. + +As shown in next figure, the figure shows the timeline of GPU prediction service using [bert example](https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/bert). +The server side starts service with 4 GPU cards, the client side starts 4 processes to request, and the batch size is 1. +In the figure, bert_pre represents the data pre-processing stage of the client, and client_infer represents the stage where the client completes the sending of the prediction request to the receiving result. +The process in the figure represents the process number of the client, and the second line of each process shows the timeline of each op of the server. + +![timeline](../../doc/timeline-example.png) + +## Debug tools + +The inference op of Paddle Serving is implemented based on Paddle inference lib. +Before deploying the prediction service, you may need to check the input and output of the prediction service or check the resource consumption. +Therefore, a local prediction tool is built into the paddle_serving_app, which is used in the same way as sending a request to the server through the client. + +Taking [fit_a_line prediction service](../examples/fit_a_line) as an example, the following code can be used to run local prediction. + +```python +from paddle_serving_app import Debugger +import numpy as np + +debugger = Debugger() +debugger.load_model_config("./uci_housing_model", gpu=False) +data = [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, + -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332] +fetch_map = debugger.predict(feed={"x":data}, fetch = ["price"]) +``` diff --git a/python/paddle_serving_app/README_CN.md b/python/paddle_serving_app/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..75dcf9ae78bec0c00b7662f7427d3816feaeca3d --- /dev/null +++ b/python/paddle_serving_app/README_CN.md @@ -0,0 +1,158 @@ +(简体中文|[English](./README.md)) + +paddle_serving_app是Paddle Serving框架的工具组件,包含了预训练模型下载、数据预处理方法等功能。方便用户快速体验和部署模型示例、分析预测服务性能、调试模型预测服务等。 + +## 安装 + +```shell +pip install paddle_serving_app +``` + +## 获取模型列表 + +```shell +python -m paddle_serving_app.package --model_list +``` + +## 下载预训练模型 + +```shell +python -m paddle_serving_app.package --get_model senta_bilstm +``` + +paddle_serving_app中内置了11中预训练模型,涵盖了6种预测任务。获取到的模型文件可以直接用于部署,添加`--tutorial`参数可以获取对应的部署方式。 + +| 预测服务类型 | 模型名称 | +| ------------ | ------------------------------------------------ | +| 中文情感分析 | 'senta_bilstm', 'senta_bow', 'senta_cnn' | +| 语义理解 | 'ernie_base' | +| 中文分词 | 'lac' | +| 图像检测 | 'faster_rcnn', 'yolov3' | +| 图像分割 | 'unet', 'deeplabv3' | +| 图像分类 | 'resnet_v2_50_imagenet', 'mobilenet_v2_imagenet' | + +## 数据预处理API + +paddle_serving_app针对CV和NLP领域的模型任务,提供了多种常见的数据预处理方法。 + +- class ChineseBertReader + + 中文语义理解模型预处理 + + - `__init__(vocab_file, max_seq_len=20)` + + - vocab_file(str):词典文件路径。 + + - max_seq_len(int,可选):处理后的样本长度,超出的部分会截断,不足的部分会padding 0。默认值20。 + + - `process(line)` + - line(str):输入文本 + + [参考示例](../examples/bert/bert_client.py) + +- class LACReader 中文分词预处理 + + - `__init__(dict_floder)` + - dict_floder(str)词典文件目录 + - `process(sent)` + - sent(str):输入文本 + - `parse_result` + - words(str):原始文本 + - crf_decode(np.array):模型预测结果中的CRF编码 + + [参考示例](../examples/lac/lac_web_service.py) + +- class SentaReader + + - `__init__(vocab_path)` + - vocab_path(str):词典文件目录 + - `process(cols)` + - cols(str):分词后的文本 + + [参考示例](../examples/senta/senta_web_service.py) + +- 图像的预处理方法相比于上述的方法更加灵活多变,可以通过以下的多个类进行组合,[参考示例](../examples/imagenet/image_rpc_client.py) + +- class Sequentia + + - `__init__(transforms)` + - transforms(list):图像预处理方法类的列表 + - `__call__(img)` + - img:图像处理的输入,具体类型与transforms中的第一个预处理方法有关 + +- class File2Image + + - `__call__(img_path)` + - img_path(str):图像文件路径 + +- class URL2Image + + - `__call__(img_url)` + - img_url(str):图像url + +- class Normalize + + - `__init__(mean,std)` + - mean(float):均值 + - std(float):方差 + - `__call__(img)` + - img(np.array):(C,H,W)排列的图像数据 + +- class CenterCrop + + - `__init__(size)` + - size(list/int):预期的裁剪后的大小,list类型时需要包含预期的长和宽,int类型时会返回边长为size的正方形图片 + - `__call__(img)` + - img(np.array):输入图像 + +- class Resize + + - `__init__(size, max_size=2147483647, interpolation=None)` + - size(list/int):预期的图像大小,list类型时需要包含预期的长和宽,int类型时,短边会设置为size的长度,长边按比例缩放 + - `__call__(img)` + - img(numpy array):输入图像 + +## Timeline 工具 + +通过Timeline工具可以将预测服务的准备数据、client等待、server端op等各阶段起止时间可视化,方便分析预测服务中的时间占用比重,在此基础上有针对性地优化预测服务。 + +### 使用方式 + +1. client端在进行预测之前,通过环境变量打开Paddle Serving框架中的各阶段日志打点功能 + + ```shell + export FLAGS_profile_client=1 #开启client端各阶段时间打点 + export FLAGS_profile_server=1 #开启server端各阶段时间打点 + ``` + +2. 执行预测,并将client端的日志重定向到文件中,例如profile文件。 + +3. 将日志文件中的信息导出成为trace文件 + + ```shell + python -m paddle_serving_app.trace --profile_file profile --trace_file trace + ``` + +4. 使用chrome浏览器,打开`chrome://tracing/`网址,通过load按钮加载上一步产生的trace文件,即可将预测服务的各阶段时间信息可视化。 + + 效果如下图,图中展示了使用[bert示例](https://github.com/PaddlePaddle/Serving/tree/develop/python/examples/bert)的GPU预测服务,server端开启4卡预测,client端启动4进程,batch size为1时的各阶段timeline。 +其中bert_pre代表client端的数据预处理阶段,client_infer代表client完成预测请求的发送到接收结果的阶段,图中的process代表的是client的进程号,每个进程的第二行展示的是server各个op的timeline。 + + ![timeline](../../doc/timeline-example.png) + +## Debug工具 + +Paddle Serving框架的server预测op使用了Paddle 的预测框架,在部署预测服务之前可能需要对预测服务的输入输出进行检验或者查看资源占用等。因此在paddle_serving_app中内置了本地预测工具,使用方式与通过client向服务端发送请求一致。 + +以[fit_a_line预测服务](../examples/fit_a_line)为例,使用以下代码即可执行本地预测。 + +```python +from paddle_serving_app import Debugger +import numpy as np + +debugger = Debugger() +debugger.load_model_config("./uci_housing_model", gpu=False) +data = [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, + -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332] +fetch_map = debugger.predict(feed={"x":data}, fetch = ["price"]) +``` diff --git a/python/paddle_serving_app/__init__.py b/python/paddle_serving_app/__init__.py index 2c9d658f93265fc69fcf5d5c2a40aa7402b1b17a..2a6225570c3de61ba6e0a0587f81175816cd0f8d 100644 --- a/python/paddle_serving_app/__init__.py +++ b/python/paddle_serving_app/__init__.py @@ -15,5 +15,6 @@ from .reader.chinese_bert_reader import ChineseBertReader from .reader.image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, PadStride from .reader.lac_reader import LACReader from .reader.senta_reader import SentaReader +from .reader.imdb_reader import IMDBDataset from .models import ServingModels from .local_predict import Debugger diff --git a/python/paddle_serving_app/models/model_list.py b/python/paddle_serving_app/models/model_list.py index a2019997968ce21a30669b2acd1421355b1e0fdd..ca33933b7797dc0bd1cb5881e73ba3fc9d82f3c1 100644 --- a/python/paddle_serving_app/models/model_list.py +++ b/python/paddle_serving_app/models/model_list.py @@ -25,7 +25,9 @@ class ServingModels(object): self.model_dict["SemanticRepresentation"] = ["ernie_base"] self.model_dict["ChineseWordSegmentation"] = ["lac"] self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov3"] - self.model_dict["ImageSegmentation"] = ["unet", "deeplabv3"] + self.model_dict["ImageSegmentation"] = [ + "unet", "deeplabv3", "mobilenet_cityspaces" + ] self.model_dict["ImageClassification"] = [ "resnet_v2_50_imagenet", "mobilenet_v2_imagenet" ] diff --git a/python/paddle_serving_app/package.py b/python/paddle_serving_app/package.py index e27914931d4f64c98627cd54025fcf87ac0f241d..250ee99f5130736945a6b77eb4d0bf5a2074a703 100644 --- a/python/paddle_serving_app/package.py +++ b/python/paddle_serving_app/package.py @@ -72,7 +72,7 @@ if __name__ == "__main__": Usage: Download a package for serving directly Example: - python -m paddle_serving_app.models --get senta_bilstm + python -m paddle_serving_app.models --get_model senta_bilstm python -m paddle_serving_app.models --list_model """) pass diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index a5afb9c84743fe401ab62608b7b38b5ccd6623ae..7988bf447b5a0a075171d93d22dd1933aa8532b8 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -13,14 +13,19 @@ # limitations under the License. import cv2 import os -import urllib import numpy as np import base64 +import sys from . import functional as F from PIL import Image, ImageDraw import json _cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"} +py_version = sys.version_info[0] +if py_version == 2: + import urllib +else: + import urllib.request as urllib def generate_colormap(num_classes): @@ -393,7 +398,7 @@ class Normalize(object): class Lambda(object): """Apply a user-defined lambda as a transform. - Very shame to just copy from + Very shame to just copy from https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py#L301 Args: diff --git a/python/paddle_serving_app/reader/imdb_reader.py b/python/paddle_serving_app/reader/imdb_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ef3e163a50b0dc244ac2653df1e38d7f91699b --- /dev/null +++ b/python/paddle_serving_app/reader/imdb_reader.py @@ -0,0 +1,92 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=doc-string-missing + +import sys +import os +import paddle +import re +import paddle.fluid.incubate.data_generator as dg + +py_version = sys.version_info[0] + + +class IMDBDataset(dg.MultiSlotDataGenerator): + def load_resource(self, dictfile): + self._vocab = {} + wid = 0 + if py_version == 2: + with open(dictfile) as f: + for line in f: + self._vocab[line.strip()] = wid + wid += 1 + else: + with open(dictfile, encoding="utf-8") as f: + for line in f: + self._vocab[line.strip()] = wid + wid += 1 + self._unk_id = len(self._vocab) + self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))') + self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0]) + + def get_words_only(self, line): + sent = line.lower().replace("
", " ").strip() + words = [x for x in self._pattern.split(sent) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + return feas + + def get_words_and_label(self, line): + send = '|'.join(line.split('|')[:-1]).lower().replace("
", + " ").strip() + label = [int(line.split('|')[-1])] + + words = [x for x in self._pattern.split(send) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + return feas, label + + def infer_reader(self, infer_filelist, batch, buf_size): + def local_iter(): + for fname in infer_filelist: + with open(fname, "r") as fin: + for line in fin: + feas, label = self.get_words_and_label(line) + yield feas, label + + import paddle + batch_iter = paddle.batch( + paddle.reader.shuffle( + local_iter, buf_size=buf_size), + batch_size=batch) + return batch_iter + + def generate_sample(self, line): + def memory_iter(): + for i in range(1000): + yield self.return_value + + def data_iter(): + feas, label = self.get_words_and_label(line) + yield ("words", feas), ("label", label) + + return data_iter + + +if __name__ == "__main__": + imdb = IMDBDataset() + imdb.load_resource("imdb.vocab") + imdb.run_from_stdin() diff --git a/python/paddle_serving_app/trace.py b/python/paddle_serving_app/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7f35b672d8d9bd7e9b8c64c5004eca7b9f6795 --- /dev/null +++ b/python/paddle_serving_app/trace.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import sys +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser("Convert profile log to trace") + parser.add_argument( + "--profile_file", + type=str, + default="", + required=True, + help="Profile log") + parser.add_argument( + "--trace_file", type=str, default="trace", help="Trace file") + return parser.parse_args() + + +def prase(pid_str, time_str, counter): + pid = pid_str.split(":")[1] + event_list = time_str.split(" ") + trace_list = [] + for event in event_list: + name, ts = event.split(":") + name_list = name.split("_") + ph = "B" if (name_list[-1] == "0") else "E" + if len(name_list) == 2: + name = name_list[0] + else: + name = name_list[0] + "_" + name_list[1] + event_dict = {} + event_dict["name"] = name + event_dict["tid"] = 0 + event_dict["pid"] = pid + event_dict["ts"] = ts + event_dict["ph"] = ph + + trace_list.append(event_dict) + return trace_list + + +if __name__ == "__main__": + args = parse_args() + profile_file = args.profile_file + trace_file = args.trace_file + all_list = [] + counter = 0 + with open(profile_file) as f: + for line in f.readlines(): + line = line.strip().split("\t") + if line[0] == "PROFILE": + trace_list = prase(line[1], line[2], counter) + counter += 1 + for trace in trace_list: + all_list.append(trace) + + trace = json.dumps(all_list, indent=2, separators=(',', ':')) + with open(trace_file, "w") as f: + f.write(trace) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 16fce4f3106c840f500e0cdcdf35ff97e0c3c844..8c189d415b5718788da2ff0e6757ba3af259e750 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -201,7 +201,12 @@ class Client(object): def shape_check(self, feed, key): if key in self.lod_tensor_set: return - if len(feed[key]) != self.feed_tensor_len[key]: + if isinstance(feed[key], + list) and len(feed[key]) != self.feed_tensor_len[key]: + raise SystemExit("The shape of feed tensor {} not match.".format( + key)) + if type(feed[key]).__module__ == np.__name__ and np.size(feed[ + key]) != self.feed_tensor_len[key]: raise SystemExit("The shape of feed tensor {} not match.".format( key)) @@ -252,23 +257,16 @@ class Client(object): for key in feed_i: if key not in self.feed_names_: raise ValueError("Wrong feed name: {}.".format(key)) - if not isinstance(feed_i[key], np.ndarray): - self.shape_check(feed_i, key) + #if not isinstance(feed_i[key], np.ndarray): + self.shape_check(feed_i, key) if self.feed_types_[key] == int_type: if i == 0: int_feed_names.append(key) if isinstance(feed_i[key], np.ndarray): - if key in self.lod_tensor_set: - raise ValueError( - "LodTensor var can not be ndarray type.") int_shape.append(list(feed_i[key].shape)) else: int_shape.append(self.feed_shapes_[key]) if isinstance(feed_i[key], np.ndarray): - if key in self.lod_tensor_set: - raise ValueError( - "LodTensor var can not be ndarray type.") - #int_slot.append(np.reshape(feed_i[key], (-1)).tolist()) int_slot.append(feed_i[key]) self.has_numpy_input = True else: @@ -278,17 +276,10 @@ class Client(object): if i == 0: float_feed_names.append(key) if isinstance(feed_i[key], np.ndarray): - if key in self.lod_tensor_set: - raise ValueError( - "LodTensor var can not be ndarray type.") float_shape.append(list(feed_i[key].shape)) else: float_shape.append(self.feed_shapes_[key]) if isinstance(feed_i[key], np.ndarray): - if key in self.lod_tensor_set: - raise ValueError( - "LodTensor var can not be ndarray type.") - #float_slot.append(np.reshape(feed_i[key], (-1)).tolist()) float_slot.append(feed_i[key]) self.has_numpy_input = True else: diff --git a/python/paddle_serving_client/io/__init__.py b/python/paddle_serving_client/io/__init__.py index 4f174866e5521577ba35f39216f7dd0793879a6c..93ae37056320c2c7d779c5bbfc4d004a1be4f639 100644 --- a/python/paddle_serving_client/io/__init__.py +++ b/python/paddle_serving_client/io/__init__.py @@ -104,10 +104,10 @@ def save_model(server_model_folder, def inference_model_to_serving(dirname, - model_filename=None, - params_filename=None, serving_server="serving_server", - serving_client="serving_client"): + serving_client="serving_client", + model_filename=None, + params_filename=None): place = fluid.CPUPlace() exe = fluid.Executor(place) inference_program, feed_target_names, fetch_targets = \ diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py index 971359fca0df3a122b28889e0711c86364a1c45d..3cb96a8f04922362fdb4b4c497f7679355e3879f 100644 --- a/python/paddle_serving_server/__init__.py +++ b/python/paddle_serving_server/__init__.py @@ -274,7 +274,8 @@ class Server(object): self.model_config_paths[node.name] = path print("You have specified multiple model paths, please ensure " "that the input and output of multiple models are the same.") - workflow_oi_config_path = self.model_config_paths.items()[0][1] + workflow_oi_config_path = list(self.model_config_paths.items())[0][ + 1] else: raise Exception("The type of model_config_paths must be str or " "dict({op: model_path}), not {}.".format( diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py index 5a06bd712a836617047b0cc947956fc5d2213daa..7acc926c7f7fc465da20a7609bc767a5289d2e61 100644 --- a/python/paddle_serving_server_gpu/__init__.py +++ b/python/paddle_serving_server_gpu/__init__.py @@ -320,7 +320,8 @@ class Server(object): self.model_config_paths[node.name] = path print("You have specified multiple model paths, please ensure " "that the input and output of multiple models are the same.") - workflow_oi_config_path = self.model_config_paths.items()[0][1] + workflow_oi_config_path = list(self.model_config_paths.items())[0][ + 1] else: raise Exception("The type of model_config_paths must be str or " "dict({op: model_path}), not {}.".format( diff --git a/tools/Dockerfile.centos6.devel b/tools/Dockerfile.centos6.devel index dd5a2ef786ed8a9c239a99cabbcfe2d482e6341c..5223693d846bdbc90bdefe58c26db29d6a81359d 100644 --- a/tools/Dockerfile.centos6.devel +++ b/tools/Dockerfile.centos6.devel @@ -43,5 +43,5 @@ RUN yum -y install wget && \ source /root/.bashrc && \ cd .. && rm -rf Python-3.6.8* && \ pip3 install google protobuf setuptools wheel flask numpy==1.16.4 && \ - yum -y install epel-release && yum -y install patchelf && \ + yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \ yum clean all diff --git a/tools/Dockerfile.centos6.gpu.devel b/tools/Dockerfile.centos6.gpu.devel index c34780c151e960134af5f8b448e0465b8285e8b2..1432d49abe9a4aec3b558d855c9cfcf30efef461 100644 --- a/tools/Dockerfile.centos6.gpu.devel +++ b/tools/Dockerfile.centos6.gpu.devel @@ -43,5 +43,5 @@ RUN yum -y install wget && \ source /root/.bashrc && \ cd .. && rm -rf Python-3.6.8* && \ pip3 install google protobuf setuptools wheel flask numpy==1.16.4 && \ - yum -y install epel-release && yum -y install patchelf && \ + yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \ yum clean all diff --git a/tools/Dockerfile.devel b/tools/Dockerfile.devel index 6cb228f587054d5b579df0d85109d41c15c128e9..385e568273eab54f7dfa51a20bb7dcd89cfa98a8 100644 --- a/tools/Dockerfile.devel +++ b/tools/Dockerfile.devel @@ -20,5 +20,5 @@ RUN yum -y install wget >/dev/null \ && rm get-pip.py \ && yum install -y python3 python3-devel \ && pip3 install google protobuf setuptools wheel flask \ - && yum -y install epel-release && yum -y install patchelf \ + && yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\ && yum clean all diff --git a/tools/Dockerfile.gpu.devel b/tools/Dockerfile.gpu.devel index 8cd7a6dbbddd5e1b60b7833086aa25cd849da519..2ffbe4601e1f7e9b05c87f9562b3e0ffc4b967ff 100644 --- a/tools/Dockerfile.gpu.devel +++ b/tools/Dockerfile.gpu.devel @@ -21,5 +21,5 @@ RUN yum -y install wget >/dev/null \ && rm get-pip.py \ && yum install -y python3 python3-devel \ && pip3 install google protobuf setuptools wheel flask \ - && yum -y install epel-release && yum -y install patchelf \ + && yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\ && yum clean all diff --git a/tools/python_tag.py b/tools/python_tag.py index 75947cff0b1b39d4c262a306bbe2bc878ae7d3ba..7c0fb5aa9928bb83c51df698b2f66df17793feb1 100644 --- a/tools/python_tag.py +++ b/tools/python_tag.py @@ -15,6 +15,6 @@ from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag import re with open("setup.cfg", "w") as f: - line = "[bdist_wheel]\npython-tag={0}{1}\nplat-name=linux_x86_64".format( + line = "[bdist_wheel]\npython-tag={0}{1}\nplat-name=manylinux1_x86_64".format( get_abbr_impl(), get_impl_ver()) f.write(line) diff --git a/tools/serving_build.sh b/tools/serving_build.sh index a522efe19cb9f4170341f291d8c30db0e6749ad1..43e55174ab30374d853ed1bb25aa4a9cc637afd5 100644 --- a/tools/serving_build.sh +++ b/tools/serving_build.sh @@ -343,7 +343,7 @@ function python_test_imdb() { sleep 5 check_cmd "head test_data/part-0 | python test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab" # test batch predict - check_cmd "python benchmark_batch.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc --endpoint 127.0.0.1:9292" + check_cmd "python benchmark.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc --endpoint 127.0.0.1:9292" echo "imdb CPU RPC inference pass" kill_server_process rm -rf work_dir1 @@ -359,7 +359,7 @@ function python_test_imdb() { exit 1 fi # test batch predict - check_cmd "python benchmark_batch.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request http --endpoint 127.0.0.1:9292" + check_cmd "python benchmark.py --thread 4 --batch_size 8 --model imdb_bow_client_conf/serving_client_conf.prototxt --request http --endpoint 127.0.0.1:9292" setproxy # recover proxy state kill_server_process ps -ef | grep "text_classify_service.py" | grep -v grep | awk '{print $2}' | xargs kill