Unverified commit 1d8ebc26, author: MRXLT, committer: GitHub

Merge branch 'develop' into doc-fix

# Pipeline Serving
([简体中文](PIPELINE_SERVING_CN.md)|English)
Paddle Serving is usually used to deploy a single model, but end-to-end deep learning models cannot yet solve every problem. In practice, several deep learning models often have to be combined.
Paddle Serving provides Pipeline Serving, a user-friendly programming framework for multi-model composite services. It aims to lower the barrier to programming, improve resource utilization (especially of GPUs), and improve overall prediction efficiency.
## Architecture Design
The Server side is built on gRPC and a graph execution engine. The relationship between the two is shown in the following figure.
<center>
<img src='pipeline_serving-image1.png' height = "250" align="middle"/>
</center>
### Graph Execution Engine
The graph execution engine consists of OPs and Channels; connected OPs share one Channel.
- A Channel can be understood as a buffer queue. Each OP accepts input from only one Channel and can write to multiple Channels (each output is the same); a Channel can hold outputs from multiple OPs, and data from the same Channel can be used as input by multiple OPs.
- Users only need to define the relationships between OPs. The graph engine analyzes the dependencies of the entire graph and declares the Channels at compile time.
- After a Request enters the graph execution engine service, the graph engine generates a Request ID, and the Response is returned through the corresponding Request ID.
- When large data needs to be transferred between OPs, consider an external in-memory database (RAM DB) for global storage, and transfer the data by passing index keys through the Channel.
<center>
<img src='pipeline_serving-image2.png' height = "300" align="middle"/>
</center>
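
The bullets above can be illustrated with a minimal sketch. It is not the Paddle Serving implementation: `submit`, `run_op`, and the use of `queue.Queue` as a Channel are assumptions made only for illustration. It shows two OPs connected by Channels and a Request ID that travels with the data so the response can be matched to its request.

```python
import itertools
import queue

# Hypothetical stand-ins for illustration; not the Paddle Serving classes.
_request_ids = itertools.count()


def submit(channel, payload):
    """Tag a request with a fresh Request ID and put it on the entry Channel."""
    request_id = next(_request_ids)
    channel.put((request_id, payload))
    return request_id


def run_op(func, in_channel, out_channel):
    """Drain one Channel, apply the OP's function, and forward the results."""
    while not in_channel.empty():
        request_id, data = in_channel.get()
        out_channel.put((request_id, func(data)))


entry, middle, exit_ = queue.Queue(), queue.Queue(), queue.Queue()
first_id = submit(entry, "hello")
second_id = submit(entry, "world")

run_op(str.upper, entry, middle)          # first OP
run_op(lambda s: s + "!", middle, exit_)  # second OP

# Responses are matched to their requests through the Request ID.
responses = {}
while not exit_.empty():
    request_id, result = exit_.get()
    responses[request_id] = result
```

In this sketch the OPs run one after another in a single thread; the real engine runs them concurrently.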
### OP Design
- By default, a single OP accesses a single Paddle Serving Service based on the data in its input Channel and puts the result into its output Channel.
- OPs support user customization: the preprocess, process, and postprocess functions can be inherited and implemented by the user.
- The concurrency of an OP can be set to increase the number of requests it processes in parallel.
- An OP can be started by a thread or a process.
### Channel Design
- Channel is the data structure for sharing data between OPs, responsible for sharing data or sharing data status information.
- Outputs from multiple OPs can be stored in the same Channel, and data from the same Channel can be used by multiple OPs.
- The following illustration shows the design of Channel in the graph execution engine, using input buffer and output buffer to align data between multiple OP inputs and multiple OP outputs, with a queue in the middle to buffer.
<center>
<img src='pipeline_serving-image3.png' height = "500" align="middle"/>
</center>
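
The alignment role of the input buffer can be sketched as follows. This is a simplified illustration under stated assumptions, not the library's Channel: the class name `AligningChannel` and its methods are hypothetical, and only the input side (several producer OPs) is modelled. An item moves to the internal queue only once every producer has delivered its part for the same Request ID.

```python
from collections import defaultdict, deque


class AligningChannel:
    """Hypothetical sketch: align outputs of several producer OPs by Request ID."""

    def __init__(self, producer_names):
        self._producers = set(producer_names)
        self._input_buffer = defaultdict(dict)  # request_id -> {op_name: data}
        self._queue = deque()                   # complete items, ready for consumers

    def push(self, op_name, request_id, data):
        if op_name not in self._producers:
            raise KeyError("unknown producer: {}".format(op_name))
        slot = self._input_buffer[request_id]
        slot[op_name] = data
        # The item is complete only when every producer has delivered.
        if set(slot) == self._producers:
            self._queue.append((request_id, self._input_buffer.pop(request_id)))

    def pop(self):
        """Return the next complete item, or None if nothing is ready yet."""
        return self._queue.popleft() if self._queue else None


channel = AligningChannel(["bow", "cnn"])
channel.push("bow", 0, {"prediction": 0.2})
pending = channel.pop()                      # "cnn" has not delivered yet
channel.push("cnn", 0, {"prediction": 0.6})
ready = channel.pop()
```

The dictionary handed to the consumer is keyed by producer name, which matches the shape of `input_dicts` described for `preprocess` later in this document.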
### Extreme Case Consideration
- Request timeout

  Any step of the graph execution engine may time out. The engine controls this through the `timeout` value, and a request that times out at any step returns a timeout response.

- Channel stores too much data

  A Channel may hold too much data, which makes copying expensive. The graph execution engine can store OP results in external storage, such as a high-speed in-memory KV system.

- Whether the input buffers and output buffers in a Channel grow indefinitely

  - They do not grow indefinitely. The input to the entire graph execution engine is placed in a Channel's internal queue, which acts directly as the traffic-control buffer queue for the entire service.
  - For the input buffer, adjust the concurrency of OP1 and OP2 according to their amount of computation, so that the number of items arriving from each input OP is relatively balanced.
  - For the output buffer, use a similar approach: adjust the concurrency of OP3 and OP4 to control the length of the output buffer.
  - Note: The length of the input buffer depends on how quickly each item in the internal queue becomes ready, and the length of the output buffer depends on how quickly downstream OPs fetch data from it.
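
The idea of keeping large results out of the Channel can be sketched as below. A plain dictionary stands in for the external in-memory KV system, and `put_large`, `producer_op`, and `consumer_op` are hypothetical names; a real deployment would use an actual KV service.

```python
import uuid

# Assumption: a plain dict stands in for an external in-memory KV store.
kv_store = {}


def put_large(data):
    """Store a large object and return the small key that identifies it."""
    key = uuid.uuid4().hex
    kv_store[key] = data
    return key


def producer_op(n):
    big_result = list(range(n))
    # Only the small key travels through the Channel, not the data itself.
    return {"result_key": put_large(big_result)}


def consumer_op(message):
    # pop() removes the entry so the store does not grow without bound.
    data = kv_store.pop(message["result_key"])
    return sum(data)


message = producer_op(1000)
total = consumer_op(message)
```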
## Detailed Design
### User Interface Design
#### 1. General OP Definition
As the basic unit of graph execution engine, the general OP constructor is as follows:
```python
def __init__(name=None,
             input_ops=[],
             server_endpoints=[],
             fetch_list=[],
             client_config=None,
             concurrency=1,
             timeout=-1,
             retry=1)
```
The meaning of each parameter is as follows:
| Parameter | Meaning |
| :--------------: | :----------------------------------------------------------: |
| name | (str) String used to identify the OP type, which must be globally unique. |
| input_ops | (list) A list of all previous OPs of the current OP. |
| server_endpoints | (list) List of endpoints for remote Paddle Serving Service. If this parameter is not set, the OP will not access the remote Paddle Serving Service, that is, the process operation will not be performed. |
| fetch_list | (list) List of fetch variable names for remote Paddle Serving Service. |
| client_config | (str) The path of the client configuration file corresponding to the Paddle Serving Service. |
| concurrency | (int) The number of concurrent OPs. |
| timeout | (int) Timeout of the process operation, in seconds. If the value is less than zero, there is no timeout. |
| retry | (int) Number of attempts when a timeout occurs. When the value is 1, no retries are made. |
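
The `retry` row says that a value of 1 means no retries, which suggests that `retry` counts attempts in total. A minimal sketch of that reading follows; `call_with_retry` and `FlakyCall` are hypothetical helpers written for this illustration, not part of Paddle Serving.

```python
def call_with_retry(func, retry=1):
    """Call func up to `retry` times in total, retrying only on TimeoutError."""
    if retry < 1:
        raise ValueError("retry must be at least 1")
    for attempt in range(1, retry + 1):
        try:
            return func()
        except TimeoutError:
            if attempt == retry:
                raise  # attempts exhausted: surface the timeout


class FlakyCall:
    """Simulated remote call that times out a fixed number of times first."""

    def __init__(self, failures):
        self.failures = failures
        self.calls = 0

    def __call__(self):
        self.calls += 1
        if self.calls <= self.failures:
            raise TimeoutError("simulated timeout")
        return "ok"
```

With `retry=2`, a call that times out once succeeds on the second attempt; with `retry=1`, the same call fails immediately.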
#### 2. General OP Secondary Development Interface
| Interface or Variable | Description |
| :--------------------------------------------: | :----------------------------------------------------------: |
| def preprocess(self, input_dicts) | Process the data obtained from the channel, and the processed data will be used as the input of the **process** function. |
| def process(self, feed_dict) | The RPC prediction process is based on the Paddle Serving Client, and the processed data will be used as the input of the **postprocess** function. |
| def postprocess(self, input_dicts, fetch_dict) | After processing the prediction results, the processed data will be put into the subsequent Channel to be obtained by the subsequent OP. |
| def init_op(self) | Used to load resources (such as word dictionary). |
| self.concurrency_idx | Concurrency index of current thread / process (different kinds of OP are calculated separately). |
In one running cycle, an OP executes three operations in order: preprocess, process, and postprocess (when the `server_endpoints` parameter is not set, the process operation is not executed). Users can override these three functions. The default implementation is as follows:
```python
def preprocess(self, input_dicts):
    # multiple previous Op
    if len(input_dicts) != 1:
        raise NotImplementedError(
            'this Op has multiple previous inputs. Please override this func.'
        )
    (_, input_dict), = input_dicts.items()
    return input_dict

def process(self, feed_dict):
    err, err_info = ChannelData.check_npdata(feed_dict)
    if err != 0:
        raise NotImplementedError(
            "{} Please override preprocess func.".format(err_info))
    call_result = self.client.predict(
        feed=feed_dict, fetch=self._fetch_names)
    return call_result

def postprocess(self, input_dicts, fetch_dict):
    return fetch_dict
```
The parameter of **preprocess** is the data `input_dicts` in the previous Channel. This variable is a dictionary with the name of the previous OP as key and the output of the corresponding OP as value.
The parameter of **process** is `feed_dict`, the input of the Paddle Serving Client prediction interface (that is, the return value of the preprocess function). This variable is a dictionary with feed_name as the key and the data in ndarray format as the value.
The parameters of **postprocess** are `input_dicts` and `fetch_dict`. `input_dicts` is consistent with the parameter of preprocess, and `fetch_dict` is the return value of the process function (if process is not executed, this value is the return value of preprocess).
Users can also override the **init_op** function to load custom resources (such as a word dictionary). The default implementation is as follows:
```python
def init_op(self):
    pass
```
It should be noted that in the threaded version of OP, each OP will only call this function once, so the loaded resources must be thread safe.
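
The running cycle described in this section can be sketched with a simplified stand-in class. `MiniOp`, `run_once`, and `AverageOp` are hypothetical names, and the sketch omits Channels, concurrency, and the RPC client; it only shows the order of the three calls and that process is skipped when `server_endpoints` is empty.

```python
class MiniOp:
    """Hypothetical, simplified stand-in for an OP; not the library class."""

    def __init__(self, name, server_endpoints=None):
        self.name = name
        self.server_endpoints = server_endpoints or []
        self.init_op()  # load resources once

    def init_op(self):
        pass

    def preprocess(self, input_dicts):
        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_dict):
        raise NotImplementedError("a real OP would call the remote service here")

    def postprocess(self, input_dicts, fetch_dict):
        return fetch_dict

    def run_once(self, input_dicts):
        data = self.preprocess(input_dicts)
        if self.server_endpoints:  # process is skipped without endpoints
            data = self.process(data)
        return self.postprocess(input_dicts, data)


class AverageOp(MiniOp):
    """Overrides preprocess because it has more than one previous OP."""

    def preprocess(self, input_dicts):
        values = [d["prediction"] for d in input_dicts.values()]
        return {"prediction": sum(values) / len(values)}


op = AverageOp("combine")
result = op.run_once({"bow": {"prediction": 0.25}, "cnn": {"prediction": 0.75}})
```

Because `AverageOp` has no endpoints, the value returned by its `preprocess` is passed straight to `postprocess`.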
#### 3. RequestOp Definition
RequestOp is used to process RPC data received by Pipeline Server, and the processed data will be added to the graph execution engine. Its constructor is as follows:
```python
def __init__(self)
```
#### 4. RequestOp Secondary Development Interface
| Interface or Variable | Description |
| :---------------------------------------: | :----------------------------------------------------------: |
| def init_op(self) | It is used to load resources (such as dictionaries), and is consistent with general OP. |
| def unpack_request_package(self, request) | Process received RPC data. |
The default implementation of **unpack_request_package** builds a dictionary from the keys and values in the RPC request:
```python
def unpack_request_package(self, request):
    dictdata = {}
    for idx, key in enumerate(request.key):
        data = request.value[idx]
        try:
            data = eval(data)
        except Exception as e:
            pass
        dictdata[key] = data
    return dictdata
```
The return value must be a dictionary.
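
Because the default implementation calls `eval` on request values, a custom RequestOp that receives untrusted input may prefer a stricter parser. The sketch below swaps in `ast.literal_eval` from the standard library; this is an alternative technique, not what the default implementation does. `unpack_key_values` is a hypothetical helper that takes plain lists instead of an RPC request object. Note that `literal_eval` only accepts Python literals, so values such as `array([...])` reprs stay as strings.

```python
import ast


def unpack_key_values(keys, values):
    """Build a dict from parallel key/value lists, parsing literals safely."""
    dictdata = {}
    for key, raw in zip(keys, values):
        try:
            data = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError):
            data = raw  # not a Python literal: keep the original string
        dictdata[key] = data
    return dictdata


parsed = unpack_key_values(["ids", "words"], ["[1, 2, 3]", "i am very sad"])
```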
#### 5. ResponseOp Definition
ResponseOp is used to process the prediction results of the graph execution engine. The processed data will be used as the RPC return value of Pipeline Server. Its constructor is as follows:
```python
def __init__(self, input_ops)
```
`input_ops` specifies the last OP of the graph execution engine. Users can construct different DAGs by setting different `input_ops`, without modifying the topology of the OPs.
#### 6. ResponseOp Secondary Development Interface
| Interface or Variable | Description |
| :------------------------------------------: | :----------------------------------------------------------: |
| def init_op(self) | It is used to load resources (such as dictionaries), and is consistent with general OP. |
| def pack_response_package(self, channeldata) | Process the prediction results of the graph execution engine as the return of RPC. |
The default implementation of **pack_response_package** is to convert the dictionary of prediction results into key and value in RPC response:
```python
def pack_response_package(self, channeldata):
    resp = pipeline_service_pb2.Response()
    resp.ecode = channeldata.ecode
    if resp.ecode == ChannelDataEcode.OK.value:
        if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = channeldata.parse()
            np.set_printoptions(threshold=np.nan)
            for name, var in feed.items():
                resp.value.append(var.__repr__())
                resp.key.append(name)
        elif channeldata.datatype == ChannelDataType.DICT.value:
            feed = channeldata.parse()
            for name, var in feed.items():
                if not isinstance(var, str):
                    resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                    resp.error_info = self._log(
                        "fetch var type must be str({}).".format(type(var)))
                    break
                resp.value.append(var)
                resp.key.append(name)
        else:
            resp.ecode = ChannelDataEcode.TYPE_ERROR.value
            resp.error_info = self._log(
                "Error type({}) in datatype.".format(channeldata.datatype))
    else:
        resp.error_info = channeldata.error_info
    return resp
```
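
The dictionary branch of the default implementation can be tried out in isolation with the sketch below. The `Response` dataclass and the numeric error codes are assumptions standing in for the protobuf message and `ChannelDataEcode`; only the control flow mirrors the code above.

```python
from dataclasses import dataclass, field

# Hypothetical error codes; the real values live in ChannelDataEcode.
OK, TYPE_ERROR = 0, 1


@dataclass
class Response:
    """Stand-in for the protobuf Response message."""
    ecode: int = OK
    error_info: str = ""
    key: list = field(default_factory=list)
    value: list = field(default_factory=list)


def pack_dict(result):
    """Copy a dict of strings into parallel key/value lists."""
    resp = Response()
    for name, var in result.items():
        if not isinstance(var, str):
            # Same rule as the default implementation: values must be str.
            resp.ecode = TYPE_ERROR
            resp.error_info = "fetch var type must be str({}).".format(type(var))
            break
        resp.value.append(var)
        resp.key.append(name)
    return resp


good = pack_dict({"label": "negative"})
bad = pack_dict({"score": 0.5})
```

A custom `pack_response_package` that needs to return non-string values would have to serialize them to strings first, as the ndarray branch does with `__repr__()`.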
#### 7. PipelineServer Definition
The definition of PipelineServer is relatively simple, as follows:
```python
server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server(config_yml_path)
server.run_server()
```
Here `response_op` is the ResponseOp mentioned above. PipelineServer initializes the Channels according to the topology of the OPs and builds the computation graph. `config_yml_path` is the configuration file of PipelineServer. An example file is as follows:
```yaml
port: 18080 # gRPC port
worker_num: 1 # gRPC thread pool size (the number of processes in the process version servicer). The default is 1
build_dag_each_worker: false # Whether to use process server or not. The default is false
dag:
    is_thread_op: true # Whether to use the thread version of OP. The default is true
    client_type: brpc # Use brpc or grpc client. The default is brpc
    retry: 1 # The number of times DAG executor retries after failure. The default value is 1, that is, no retrying
    use_profile: false # Whether to print the log on the server side. The default is false
```
## Example
Here, we build a simple IMDB model ensemble example to show how to use Pipeline Serving. The relevant code can be found in the `python/examples/pipeline/imdb_model_ensemble` folder. The Server-side structure in the example is shown in the following figure:
<center>
<img src='pipeline_serving-image4.png' height = "200" align="middle"/>
</center>
### Get the model file and start the Paddle Serving Service
```shell
cd python/examples/pipeline/imdb_model_ensemble
sh get_data.sh
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
```
### Start PipelineServer
Run the following code
```python
from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.channel import ChannelDataEcode
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset

logging.basicConfig(level=logging.DEBUG)
_LOGGER = logging.getLogger()


class ImdbRequestOp(RequestOp):
    def init_op(self):
        self.imdb_dataset = IMDBDataset()
        self.imdb_dataset.load_resource('imdb.vocab')

    def unpack_request_package(self, request):
        dictdata = {}
        for idx, key in enumerate(request.key):
            if key != "words":
                continue
            words = request.value[idx]
            word_ids, _ = self.imdb_dataset.get_words_and_label(words)
            dictdata[key] = np.array(word_ids)
        return dictdata


class CombineOp(Op):
    def preprocess(self, input_data):
        combined_prediction = 0
        for op_name, data in input_data.items():
            _LOGGER.info("{}: {}".format(op_name, data["prediction"]))
            combined_prediction += data["prediction"]
        data = {"prediction": combined_prediction / 2}
        return data


read_op = ImdbRequestOp()
bow_op = Op(name="bow",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9393"],
            fetch_list=["prediction"],
            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
            concurrency=1,
            timeout=-1,
            retry=1)
cnn_op = Op(name="cnn",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9292"],
            fetch_list=["prediction"],
            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
            concurrency=1,
            timeout=-1,
            retry=1)
combine_op = CombineOp(
    name="combine",
    input_ops=[bow_op, cnn_op],
    concurrency=5,
    timeout=-1,
    retry=1)

# use default ResponseOp implementation
response_op = ResponseOp(input_ops=[combine_op])

server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server('config.yml')
server.run_server()
```
### Perform prediction through PipelineClient
```python
from paddle_serving_client.pipeline import PipelineClient
import numpy as np

client = PipelineClient()
client.connect(['127.0.0.1:18080'])

words = 'i am very sad | 0'

futures = []
for i in range(3):
    futures.append(
        client.predict(
            feed_dict={"words": words},
            fetch=["prediction"],
            asyn=True))

for f in futures:
    res = f.result()
    if res["ecode"] != 0:
        print(res)
        exit(1)
```
## How to optimize through the timeline tool
To help with performance optimization, Pipeline Serving provides a timeline tool that records the time spent in each stage of the whole service.
### Output profile information on server side
On the server side, profiling is controlled by the `use_profile` field in the YAML configuration file:
```yaml
dag:
    use_profile: true
```
After the function is enabled, the server prints the corresponding log information to standard output during prediction. To show the time spent in each stage more intuitively, a script is provided for further analysis and processing of the log file.
First save the server output to a file (named `profile` in the command below). The script converts the timing information in the log into JSON format and saves it to a trace file, which can be visualized with the tracing function of the Chrome browser.
```shell
python timeline_trace.py profile trace
```
To view the result, open the Chrome browser, enter `chrome://tracing/` in the address bar to go to the tracing page, click the Load button, and open the saved trace file. The time information of each stage of the prediction service is then visualized.
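
What the conversion does to a single log line can be sketched as follows. The function mirrors the parsing steps of the `Analyst` class included in this commit, but `parse_profile_line` is a hypothetical standalone helper and the sample log line is made up for illustration; real log lines may differ in detail.

```python
import json


def parse_profile_line(line):
    """Turn one tab-separated PROFILE log line into Chrome trace events."""
    fields = line.strip().split("\t")
    if len(fields) < 3 or fields[0] != "PROFILE":
        return []
    pid = fields[1].split(":")[1]
    events = []
    for item in fields[2].split(" "):
        name, ts = item.split(":")
        name_list = name.split("_")
        # A trailing "_0" marks the beginning of a stage, anything else the end.
        ph = "B" if name_list[-1] == "0" else "E"
        name = "_".join(name_list[:-1])
        # An optional "#<tid>" suffix carries the thread index.
        parts = name.split("#")
        if len(parts) > 1:
            tid = parts[-1]
            name = "#".join(parts[:-1])
        else:
            tid = 0
        events.append({"name": name, "tid": tid, "pid": pid, "ts": ts, "ph": ph})
    return events


# Hypothetical sample line, not taken from a real server log.
sample = "PROFILE\tpid:4242\tbow#0_0:1000 bow#0_1:3500"
events = parse_profile_line(sample)
trace_json = json.dumps(events, indent=2)
```

Each stage produces a begin (`"B"`) and an end (`"E"`) event with the same name, which is how the tracing page draws it as one bar.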
### Output profile information on client side
The profile function can be enabled by setting `profile=True` in the `predict` interface on the client side.
After the function is enabled, the client will print the log information corresponding to the prediction to the standard output during the prediction process, and the subsequent analysis and processing are the same as that of the server.
# Pipeline Serving
(简体中文|[English](PIPELINE_SERVING.md))
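Paddle Serving is usually used for one-click deployment of a single model, but end-to-end deep learning models cannot yet solve every problem; combining several deep learning models is still the usual way to solve real-world problems.
Paddle Serving provides Pipeline Serving, a user-friendly programming framework for multi-model composite services. It aims to lower the barrier to programming, improve resource utilization (especially of GPU devices), and improve overall prediction efficiency.
## Overall Architecture Design
The Server side is built on gRPC and a graph execution engine. The relationship between the two is shown in the following figure.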
<center>
<img src='pipeline_serving-image1.png' height = "250" align="middle"/>
</center>
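### Graph Execution Engine
The graph execution engine consists of OPs and Channels; connected OPs share one Channel.
- A Channel can be understood as a buffer queue. Each OP accepts input from only one Channel and can write to multiple Channels (each output is the same); a Channel can hold outputs from multiple OPs, and data from the same Channel can be used as input by multiple OPs.
- Users only need to define the relationships between OPs. At compile time, the graph engine analyzes the dependencies of the entire graph and declares the Channels.
- After a Request enters the graph execution engine service, a Request ID is generated, and the Response is returned through the corresponding Request ID.
- When very large data needs to be transferred between OPs, consider an external in-memory database (RAM DB) for global storage, and transfer the data by passing index keys through the Channel.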
<center>
<img src='pipeline_serving-image2.png' height = "300" align="middle"/>
</center>
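### OP Design
- By default, a single OP accesses one single-model Paddle Serving service based on the data in its input Channel and puts the result into its output Channel.
- A single OP supports user customization: the preprocess, process, and postprocess functions can all be inherited and implemented by the user.
- The concurrency of a single OP can be set to increase the number of requests processed in parallel.
- An OP can be started by a thread or a process.
### Channel Design
- A Channel is the data structure for sharing data between OPs; it is responsible for sharing data or data status information.
- Outputs from multiple OPs can be stored in the same Channel, and data in the same Channel can be used by multiple OPs.
- The following figure shows the design of the Channel in the graph execution engine. An input buffer and an output buffer align data across multiple OP inputs or multiple OP outputs, with a Queue in the middle as a buffer.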
<center>
<img src='pipeline_serving-image3.png' height = "500" align="middle"/>
</center>
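### Extreme Case Consideration
- Handling request timeouts

  Any step of the graph execution engine may time out. The engine controls this through the `timeout` value, and a request that times out at any step returns a timeout response.

- Channel stores too much data

  A Channel may hold very large data, which makes operations such as copying expensive. The graph execution engine can store OP results in external storage, such as a high-speed in-memory KV system.

- Whether the input buffer and output buffer in the Channel design grow indefinitely

  - They do not. The input to the entire graph execution engine is placed in a Channel's internal queue, which acts directly as the traffic-control buffer queue for the entire service.
  - For the input buffer, adjust the concurrency of OP1 and OP2 according to their amount of computation, so that the number of items arriving from each input OP is relatively balanced.
  - For the output buffer, use a similar approach: adjust the concurrency of OP3 and OP4 so that the length of the output buffer stays under control.
  - Note: The length of the input buffer depends on how quickly each item in the internal queue becomes fully ready, and the length of the output buffer depends on how quickly downstream OPs fetch data from it.
## Detailed Design
### User Interface Design
#### 1. General OP Definition
As the basic unit of the graph execution engine, the general OP has the following constructor: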
```python
def __init__(name=None,
             input_ops=[],
             server_endpoints=[],
             fetch_list=[],
             client_config=None,
             concurrency=1,
             timeout=-1,
             retry=1)
```
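The meaning of each parameter is as follows:
| Parameter | Meaning |
| :--------------: | :----------------------------------------------------------: |
| name | (str) String used to identify the OP type, which must be globally unique. |
| input_ops | (list) A list of all previous OPs of the current OP. |
| server_endpoints | (list) List of endpoints of the remote Paddle Serving Service. If this parameter is not set, the remote Paddle Serving Service is not accessed, that is, the process operation is not performed. |
| fetch_list | (list) Fetch list of the remote Paddle Serving Service. |
| client_config | (str) Path of the Client-side configuration file corresponding to the Paddle Serving Service. |
| concurrency | (int) Concurrency of the OP. |
| timeout | (int) Timeout of the process operation, in seconds. If the value is less than zero, there is no timeout. |
| retry | (int) Number of attempts when a timeout occurs. When the value is 1, no retries are made. |
#### 2. General OP Secondary Development Interface
| Interface or Variable | Description |
| :--------------------------------------------: | :----------------------------------------------------------: |
| def preprocess(self, input_dicts) | Process the data obtained from the Channel; the processed data is used as the input of the **process** function. |
| def process(self, feed_dict) | Perform RPC prediction through the Paddle Serving Client; the result is used as the input of the **postprocess** function. |
| def postprocess(self, input_dicts, fetch_dict) | Process the prediction result; the processed data is put into the subsequent Channel to be obtained by the subsequent OP. |
| def init_op(self) | Used to load resources (such as a dictionary). |
| self.concurrency_idx | Concurrency index of the current thread (process); different kinds of OP are counted separately. |
In one running cycle, an OP executes three operations in order: preprocess, process, and postprocess (when the `server_endpoints` parameter is not set, the process operation is not executed). Users can override these three functions. The default implementation is as follows: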
```python
def preprocess(self, input_dicts):
    # multiple previous Op
    if len(input_dicts) != 1:
        raise NotImplementedError(
            'this Op has multiple previous inputs. Please override this func.'
        )
    (_, input_dict), = input_dicts.items()
    return input_dict

def process(self, feed_dict):
    err, err_info = ChannelData.check_npdata(feed_dict)
    if err != 0:
        raise NotImplementedError(
            "{} Please override preprocess func.".format(err_info))
    call_result = self.client.predict(
        feed=feed_dict, fetch=self._fetch_names)
    return call_result

def postprocess(self, input_dicts, fetch_dict):
    return fetch_dict
```
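The parameter of **preprocess** is `input_dicts`, the data in the previous Channel. This variable is a dictionary with the name of each previous OP as the key and the output of that OP as the value.
The parameter of **process** is `feed_dict`, the input of the Paddle Serving Client prediction interface (that is, the return value of the preprocess function). This variable is a dictionary with feed_name as the key and the data in ndarray format as the value.
The parameters of **postprocess** are `input_dicts` and `fetch_dict`. `input_dicts` is the same as the parameter of preprocess, and `fetch_dict` is the return value of the process function (if process is not executed, this value is the return value of preprocess).
Users can also override the **init_op** function to load custom resources (such as a dictionary). The default implementation is as follows: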
```python
def init_op(self):
    pass
```
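Note that in the threaded version of OP, each OP calls this function only once, so the loaded resources must be thread safe.
#### 3. RequestOp Definition
RequestOp is used to process the RPC data received by Pipeline Server; the processed data is added to the graph execution engine. Its constructor is as follows: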
```python
def __init__(self)
```
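#### 4. RequestOp Secondary Development Interface
| Interface or Variable | Description |
| :---------------------------------------: | :----------------------------------------: |
| def init_op(self) | Used to load resources (such as a dictionary); same as in the general OP. |
| def unpack_request_package(self, request) | Process the received RPC data. |
The default implementation of **unpack_request_package** builds a dictionary from the keys and values in the RPC request: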
```python
def unpack_request_package(self, request):
    dictdata = {}
    for idx, key in enumerate(request.key):
        data = request.value[idx]
        try:
            data = eval(data)
        except Exception as e:
            pass
        dictdata[key] = data
    return dictdata
```
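The return value must be a dictionary.
#### 5. ResponseOp Definition
ResponseOp is used to process the prediction results of the graph execution engine; the processed data is used as the RPC return value of Pipeline Server. Its constructor is as follows: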
```python
def __init__(self, input_ops)
```
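Here `input_ops` specifies the last OP of the graph execution engine. Users can construct different DAGs by setting different `input_ops`, without modifying the topology of the OPs.
#### 6. ResponseOp Secondary Development Interface
| Interface or Variable | Description |
| :------------------------------------------: | :-----------------------------------------: |
| def init_op(self) | Used to load resources (such as a dictionary); same as in the general OP. |
| def pack_response_package(self, channeldata) | Process the prediction results of the graph execution engine as the RPC return value. |
The default implementation of **pack_response_package** converts the dictionary of prediction results into the key and value fields of the RPC response: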
```python
def pack_response_package(self, channeldata):
    resp = pipeline_service_pb2.Response()
    resp.ecode = channeldata.ecode
    if resp.ecode == ChannelDataEcode.OK.value:
        if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = channeldata.parse()
            np.set_printoptions(threshold=np.nan)
            for name, var in feed.items():
                resp.value.append(var.__repr__())
                resp.key.append(name)
        elif channeldata.datatype == ChannelDataType.DICT.value:
            feed = channeldata.parse()
            for name, var in feed.items():
                if not isinstance(var, str):
                    resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                    resp.error_info = self._log(
                        "fetch var type must be str({}).".format(type(var)))
                    break
                resp.value.append(var)
                resp.key.append(name)
        else:
            resp.ecode = ChannelDataEcode.TYPE_ERROR.value
            resp.error_info = self._log(
                "Error type({}) in datatype.".format(channeldata.datatype))
    else:
        resp.error_info = channeldata.error_info
    return resp
```
#### 7. PipelineServer Definition
The definition of PipelineServer is relatively simple, as follows:
```python
server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server(config_yml_path)
server.run_server()
```
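Here `response_op` is the ResponseOp mentioned above. PipelineServer initializes the Channels according to the topology of the OPs and builds the computation graph. `config_yml_path` is the configuration file of PipelineServer. An example file is as follows: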
```yaml
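port: 18080 # gRPC port
worker_num: 1 # gRPC thread pool size (the number of processes in the process version of the Servicer). The default is 1
build_dag_each_worker: false # Whether to use the process version of the Servicer. The default is false
dag:
    is_thread_op: true # Whether to use the thread version of OP. The default is true
    client_type: brpc # Use brpc or grpc client. The default is brpc
    retry: 1 # The number of times the DAG Executor retries after failure. The default is 1, that is, no retrying
    use_profile: false # Whether to print the log on the Server side. The default is false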
```
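## Example
Here, a simple IMDB model ensemble example shows how to use Pipeline Serving. The relevant code can be found in the `python/examples/pipeline/imdb_model_ensemble` folder. The Server-side structure in the example is shown in the following figure: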
<center>
<img src='pipeline_serving-image4.png' height = "200" align="middle"/>
</center>
### Get the model file and start the Paddle Serving Service
```shell
cd python/examples/pipeline/imdb_model_ensemble
sh get_data.sh
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
```
### Start PipelineServer
Run the following code:
```python
from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.channel import ChannelDataEcode
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset

logging.basicConfig(level=logging.DEBUG)
_LOGGER = logging.getLogger()


class ImdbRequestOp(RequestOp):
    def init_op(self):
        self.imdb_dataset = IMDBDataset()
        self.imdb_dataset.load_resource('imdb.vocab')

    def unpack_request_package(self, request):
        dictdata = {}
        for idx, key in enumerate(request.key):
            if key != "words":
                continue
            words = request.value[idx]
            word_ids, _ = self.imdb_dataset.get_words_and_label(words)
            dictdata[key] = np.array(word_ids)
        return dictdata


class CombineOp(Op):
    def preprocess(self, input_data):
        combined_prediction = 0
        for op_name, data in input_data.items():
            _LOGGER.info("{}: {}".format(op_name, data["prediction"]))
            combined_prediction += data["prediction"]
        data = {"prediction": combined_prediction / 2}
        return data


read_op = ImdbRequestOp()
bow_op = Op(name="bow",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9393"],
            fetch_list=["prediction"],
            client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
            concurrency=1,
            timeout=-1,
            retry=1)
cnn_op = Op(name="cnn",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9292"],
            fetch_list=["prediction"],
            client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
            concurrency=1,
            timeout=-1,
            retry=1)
combine_op = CombineOp(
    name="combine",
    input_ops=[bow_op, cnn_op],
    concurrency=5,
    timeout=-1,
    retry=1)

# use default ResponseOp implementation
response_op = ResponseOp(input_ops=[combine_op])

server = PipelineServer()
server.set_response_op(response_op)
server.prepare_server('config.yml')
server.run_server()
```
### Perform prediction through PipelineClient
```python
from paddle_serving_client.pipeline import PipelineClient
import numpy as np

client = PipelineClient()
client.connect(['127.0.0.1:18080'])

words = 'i am very sad | 0'

futures = []
for i in range(3):
    futures.append(
        client.predict(
            feed_dict={"words": words},
            fetch=["prediction"],
            asyn=True))

for f in futures:
    res = f.result()
    if res["ecode"] != 0:
        print(res)
        exit(1)
```
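## How to optimize through the Timeline tool
To help with performance optimization, Pipeline Serving provides a Timeline tool that records the time spent in each stage of the whole service.
### Output Profile information on the Server side
On the Server side, profiling is controlled by the `use_profile` field in the YAML configuration file: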
```yaml
dag:
    use_profile: true
```
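After the function is enabled, the Server prints the corresponding log information to standard output during prediction. To show the time spent in each stage more intuitively, a script is provided for further analysis and processing of the log file.
First save the Server output to a file (named `profile` in the command below). The script converts the timing information in the log into JSON format and saves it to a trace file, which can be visualized with the tracing function of the Chrome browser.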
```shell
python timeline_trace.py profile trace
```
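To view the result, open the Chrome browser, enter `chrome://tracing/` in the address bar to go to the tracing page, click the Load button, and open the saved trace file. The time information of each stage of the prediction service is then visualized.
### Output Profile information on the Client side
On the Client side, the Profile function is enabled by setting `profile=True` in the `predict` interface.
After the function is enabled, the Client prints the log information for that prediction to standard output during prediction. The subsequent analysis and processing are the same as on the Server side.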
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.pipeline import Analyst
import json
import logging
import sys
logging.basicConfig(level=logging.INFO)
if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python analyse.py <log_filename> <trace_filename>")
        exit(1)
    log_filename = sys.argv[1]
    trace_filename = sys.argv[2]
    analyst = Analyst(log_filename)
    analyst.save_trace(trace_filename)
    op_analyst = analyst.get_op_analyst()
    op_concurrency = op_analyst.concurrency_analysis("analyse.yaml")
    print(json.dumps(op_concurrency, indent=2, separators=(',', ':')))
use_multithread: true
client_type: brpc
retry: 1
profile: false
prot: 8080
worker_num: 2
port: 18080
worker_num: 1
build_dag_each_worker: false
dag:
    is_thread_op: true
    client_type: brpc
    retry: 1
    use_profile: false
......@@ -13,18 +13,19 @@
# limitations under the License.
from paddle_serving_client.pipeline import PipelineClient
import numpy as np
from line_profiler import LineProfiler
client = PipelineClient()
client.connect('localhost:8080')
lp = LineProfiler()
lp_wrapper = lp(client.predict)
client.connect(['127.0.0.1:18080'])
words = 'i am very sad | 0'
for i in range(1):
fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)
futures = []
for i in range(100):
futures.append(
client.predict(
feed_dict={"words": words}, fetch=["prediction"], asyn=True))
#lp.print_stats()
for f in futures:
res = f.result()
if res["ecode"] != 0:
print("predict failed: {}".format(res))
......@@ -21,16 +21,13 @@ import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
level=logging.DEBUG)
_LOGGER = logging.getLogger()
class ImdbRequestOp(RequestOp):
def load_user_resources(self):
def init_op(self):
self.imdb_dataset = IMDBDataset()
self.imdb_dataset.load_resource('imdb.vocab')
......@@ -91,7 +88,7 @@ cnn_op = Op(name="cnn",
combine_op = CombineOp(
name="combine",
input_ops=[bow_op, cnn_op],
concurrency=1,
concurrency=5,
timeout=-1,
retry=1)
......
......@@ -43,6 +43,8 @@ if __name__ == "__main__":
for line in f.readlines():
line = line.strip().split("\t")
if line[0] == "PROFILE":
if len(line) < 2:
continue
trace_list = prase(line[1], line[2], counter)
counter += 1
for trace in trace_list:
......
......@@ -15,3 +15,4 @@
from operator import Op, RequestOp, ResponseOp
from pipeline_server import PipelineServer
from pipeline_client import PipelineClient
from analyse import Analyst
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import json
import copy
import re
import logging
_LOGGER = logging.getLogger()
class Analyst(object):
    def __init__(self, profile_file):
        self._profile_file = profile_file
        self._trace = None
        self.ave_call = None
        self.ave_prepack = None
        self.ave_postpack = None
        self.op_analyst = None
        self.start_time = None
        self.end_time = None

    def _prase_line(self, pid_str, time_str, counter):
        pid = pid_str.split(":")[1]
        event_list = time_str.split(" ")
        trace_list = []
        for event in event_list:
            name, ts = event.split(":")
            name_list = name.split("_")
            ph = "B" if (name_list[-1] == "0") else "E"
            if len(name_list) == 2:
                name = name_list[0]
            else:
                name = "_".join(name_list[:-1])
            name_list = name.split("#")
            if len(name_list) > 1:
                tid = name_list[-1]
                name = "#".join(name_list[:-1])
            else:
                tid = 0
            event_dict = {}
            event_dict["name"] = name
            event_dict["tid"] = tid
            event_dict["pid"] = pid
            event_dict["ts"] = ts
            event_dict["ph"] = ph
            trace_list.append(event_dict)
        return trace_list

    def get_trace(self):
        if self._trace is not None:
            return self._trace
        all_list = []
        counter = 0
        with open(self._profile_file) as f:
            for line in f.readlines():
                line = line.strip().split("\t")
                if line[0] == "PROFILE":
                    trace_list = self._prase_line(line[1], line[2], counter)
                    counter += 1
                    for trace in trace_list:
                        all_list.append(trace)
        self._trace = all_list
        return self._trace

    def save_trace(self, trace_file):
        self.get_trace()
        trace = json.dumps(self._trace, indent=2, separators=(',', ':'))
        with open(trace_file, "w") as f:
            f.write(trace)

    def print_profile(self):
        self.get_profile()
        print("graph engine call: {}".format(self.ave_call))
        print("rpc prepack: {}".format(self.ave_prepack))
        print("rpc postpack: {}".format(self.ave_postpack))
        print("OP: {}".format(self.op_analyst))

    def get_op_analyst(self):
        self.get_profile()
        return self.op_analyst

    def get_profile(self):
        if self.ave_call is not None and \
                self.ave_prepack is not None and \
                self.ave_postpack is not None and \
                self.op_analyst is not None:
            return (self.ave_call, self.ave_prepack, self.ave_postpack,
                    self.op_analyst)
        trace = self.get_trace()
        time_dict = {}
        time_list_dict = {}
        start, end = None, None
        for event in trace:
            name = "{}#{}".format(event["name"], event["tid"])
            event_t = int(event["ts"])
            if name in time_dict:
                ts = event_t - time_dict.pop(name)
                ts = ts / 1e3  # ms
                if name not in time_list_dict:
                    time_list_dict[name] = []
                time_list_dict[name].append(ts)
            else:
                time_dict[name] = event_t
if start is None:
start = event_t
elif start > event_t:
start = event_t
if end is None:
end = event_t
elif end < event_t:
end = event_t
self.start_time = start
self.end_time = end
op_analyst = OpAnalyst(start, end)
# reduce prepack_n, postpack_n, call_n
pat_prepack = re.compile(r"prepack_\d+#@G")
prepack_time_list = []
pat_postpack = re.compile(r"postpack_\d+#@G")
postpack_time_list = []
pat_call = re.compile(r"call_\d+#DAG")
call_time_list = []
for name in time_list_dict:
if pat_prepack.match(name):
prepack_time_list.extend(time_list_dict[name])
elif pat_postpack.match(name):
postpack_time_list.extend(time_list_dict[name])
elif pat_call.match(name):
call_time_list.extend(time_list_dict[name])
else:
op_analyst.add(name, time_list_dict[name])
self.ave_call = sum(call_time_list) * 1.0 / len(call_time_list)
self.ave_prepack = sum(prepack_time_list) * 1.0 / len(prepack_time_list)
self.ave_postpack = sum(postpack_time_list) * 1.0 / len(
postpack_time_list)
self.op_analyst = op_analyst
return (self.ave_call, self.ave_prepack, self.ave_postpack,
self.op_analyst)
class OpAnalyst(object):
def __init__(self, start_time, end_time):
self.op_time_list_dict = {}
self._qps = None
self._close = False
self.start_time = start_time
self.end_time = end_time
def add(self, name_str, ts_list):
if self._close:
_LOGGER.error("OpAnalyst is closed.")
return
op_name, curr_idx, step = self._parse(name_str)
if op_name not in self.op_time_list_dict:
self.op_time_list_dict[op_name] = {}
if curr_idx not in self.op_time_list_dict[op_name]:
self.op_time_list_dict[op_name][curr_idx] = {}
if step not in self.op_time_list_dict[op_name][curr_idx]:
self.op_time_list_dict[op_name][curr_idx][step] = []
self.op_time_list_dict[op_name][curr_idx][step].extend(ts_list)
def _parse(self, name):
step, name_str = name.split("#")
name_str = name_str[1:-1]
op_name, curr_idx = name_str.split("|")
return op_name, curr_idx, step
def _reduce_profile(self):
"""
Average the time (in ms) of each step over all concurrent instances of each OP.
"""
if self._close:
return
for op_name in self.op_time_list_dict:
total_time = None
for curr_idx in self.op_time_list_dict[op_name]:
ave_dict = {}
for step in self.op_time_list_dict[op_name][curr_idx]:
ave_dict[step] = sum(self.op_time_list_dict[op_name][
curr_idx][step]) * 1.0 / len(self.op_time_list_dict[
op_name][curr_idx][step])
if total_time is None:
total_time = ave_dict
else:
for step in ave_dict:
total_time[step] += ave_dict[step]
for step in total_time:
total_time[step] = total_time[step] * 1.0 / len(
self.op_time_list_dict[op_name])
self.op_time_list_dict[op_name] = total_time
self._close = True
def _get_qps(self):
"""
Calculate the QPS of each step from the average time
(in ms) that the step consumes in each OP.
"""
if self._qps is not None:
return self._qps
self._reduce_profile()
self._qps = {}
for op_name, times in self.op_time_list_dict.items():
self._qps[op_name] = {
step: 1000.0 / ts
for step, ts in times.items()
}
return self._qps
def __str__(self):
self._reduce_profile()
return json.dumps(
self.op_time_list_dict, indent=2, separators=(', ', ':'))
def qps(self, op_name=None):
"""
Get the average QPS of each step of each OP (in q/s)
"""
self._get_qps()
if op_name is None:
return self._qps
else:
return self._qps[op_name]
def times(self, op_name=None):
"""
Get the average time of each step of each OP (in ms)
"""
self._reduce_profile()
if op_name is None:
return self.op_time_list_dict
else:
return self.op_time_list_dict[op_name]
def concurrency_analysis(self, op_config_yaml):
"""
Estimate the theoretical QPS and the concurrency that
each OP requires, based on the measured OP times and
op_config_yaml.
Because models deployed on the same card interfere
with each other, the analysis only supports the case
where each model is deployed on a different card.
The format of the yaml file is as follows:
```yaml
<op_name>:
<step(prep, midp or postp)>: <GPU id>
```
For example:
```yaml
cnn:
midp: 0
bow:
midp: 1
```
"""
import yaml
with open(op_config_yaml) as f:
op_config = yaml.safe_load(f)
# check that each model is deployed on a different card
card_set = set()
# and finding the most time consuming part (GPU)
op_times = self.times()
most_time = 0
most_time_op_name = None
for op in op_config:
for step, cards in op_config[op].items():
if isinstance(cards, int):
cards = [cards]
elif isinstance(cards, str):
cards = [int(x) for x in cards.split(',')]
else:
raise Exception("Error cards type.")
for card in cards:
if card in card_set:
raise Exception(
"Analysis failed: models deployed on the same "
"card interfere with each other, so each model "
"must be deployed on a different card.")
else:
card_set.add(card)
times_each_card = op_times[op][step] / len(cards)
if most_time < times_each_card:
most_time = times_each_card
most_time_op_name = op
# calculate base qps
base_qps = 1.0 / most_time # q/ms
_LOGGER.info("Most Time Consuming (GPU): {} ms (op: {})"
.format(most_time, most_time_op_name))
_LOGGER.info("Theoretically Expected QPS: {} q/s".format(base_qps *
1000))
# reduce op times
op_times = {
op_name: sum(step_times.values())
for op_name, step_times in op_times.items()
}
# calculate op concurrency
op_concurrency = {
op_name: round(base_qps * times, 3)
for op_name, times in op_times.items()
}
return op_concurrency
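The arithmetic in `concurrency_analysis` can be illustrated with a small standalone sketch. Everything here is hypothetical: the function name `estimate_concurrency`, the step times, and the card assignment are invented for illustration, and cards are given as plain lists rather than the int or comma-separated string forms the method accepts.

```python
# Invented average step times in ms (not measured data).
op_times = {
    "bow": {"prep": 1.0, "midp": 4.0, "postp": 1.0},
    "cnn": {"prep": 1.0, "midp": 8.0, "postp": 1.0},
}
# Invented card assignment: op -> step -> list of GPU ids.
op_config = {"bow": {"midp": [0]}, "cnn": {"midp": [1, 2]}}


def estimate_concurrency(op_times, op_config):
    # The slowest GPU step, after dividing by its card count, bounds QPS.
    most_time = 0.0
    for op, steps in op_config.items():
        for step, cards in steps.items():
            most_time = max(most_time, op_times[op][step] / len(cards))
    base_qps = 1.0 / most_time  # queries per millisecond
    # Concurrency needed = throughput * total time one request spends in the OP.
    concurrency = {
        op: round(base_qps * sum(step_times.values()), 3)
        for op, step_times in op_times.items()
    }
    return base_qps * 1000, concurrency


expected_qps, op_concurrency = estimate_concurrency(op_times, op_config)
```

With these numbers both GPU steps cost 4.0 ms per card, so the expected throughput is 250 q/s, and the estimated concurrency is 1.5 for `bow` and 2.5 for `cnn`.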
......@@ -37,7 +37,8 @@ class ChannelDataEcode(enum.Enum):
TYPE_ERROR = 3
RPC_PACKAGE_ERROR = 4
CLIENT_ERROR = 5
UNKNOW = 6
CLOSED_ERROR = 6
UNKNOW = 7
class ChannelDataType(enum.Enum):
......@@ -53,7 +54,8 @@ class ChannelData(object):
dictdata=None,
data_id=None,
ecode=None,
error_info=None):
error_info=None,
client_need_profile=False):
'''
There are several ways to use it:
......@@ -87,6 +89,13 @@ class ChannelData(object):
self.id = data_id
self.ecode = ecode
self.error_info = error_info
self.client_need_profile = client_need_profile
self.profile_data_set = set()
def add_profile(self, profile_set):
if self.client_need_profile is False:
self.client_need_profile = True
self.profile_data_set |= profile_set
@staticmethod
def check_dictdata(dictdata):
......@@ -156,7 +165,7 @@ class ChannelData(object):
ChannelDataType(self.datatype).name, self.ecode, self.id)
class ProcessChannel(multiprocessing.queues.Queue):
class ProcessChannel(object):
"""
(Process version) The channel used for communication between Ops.
......@@ -186,18 +195,17 @@ class ProcessChannel(multiprocessing.queues.Queue):
"""
def __init__(self, manager, name=None, maxsize=0, timeout=None):
# https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/
if sys.version_info.major == 2:
super(ProcessChannel, self).__init__(maxsize=maxsize)
elif sys.version_info.major == 3:
super(ProcessChannel, self).__init__(
maxsize=maxsize, ctx=multiprocessing.get_context())
else:
raise Exception("Error Python version")
# For multiprocessing queues: after putting an object on
# an empty queue there may be an infinitesimal delay
# before the queue's :meth:`~Queue.empty` method returns False.
# See:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
self._que = manager.Queue(maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self.name = name
self._stop = False
self._stop = manager.Value('i', 0)
self._cv = multiprocessing.Condition()
......@@ -253,15 +261,17 @@ class ProcessChannel(multiprocessing.queues.Queue):
))
elif len(self._producers) == 1:
with self._cv:
while self._stop is False:
while self._stop.value == 0:
try:
self.put({op_name: channeldata}, timeout=0)
self._que.put({op_name: channeldata}, timeout=0)
break
except Queue.Full:
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} channel size: {}".format(op_name,
self.qsize())))
self._que.qsize())))
self._cv.notify_all()
_LOGGER.debug(self._log("{} notify all".format(op_name)))
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
......@@ -300,15 +310,17 @@ class ProcessChannel(multiprocessing.queues.Queue):
self._log("{} push data succ, but not push to queue.".
format(op_name)))
else:
while self._stop is False:
while self._stop.value == 0:
try:
_LOGGER.debug(
self._log("{} push data succ: {}".format(
op_name, put_data.__str__())))
self.put(put_data, timeout=0)
self._que.put(put_data, timeout=0)
break
except Queue.Full:
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
......@@ -325,25 +337,21 @@ class ProcessChannel(multiprocessing.queues.Queue):
elif len(self._consumer_cursors) == 1:
resp = None
with self._cv:
while self._stop is False and resp is None:
while self._stop.value == 0 and resp is None:
try:
_LOGGER.debug(
self._log("{} try to get(with channel empty: {})".
format(op_name, self.empty())))
# For multiprocessing queues: after putting an object on
# an empty queue there may be an infinitesimal delay
# before the queue's :meth:`~Queue.empty` method returns False.
# See:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
resp = self.get(timeout=1e-3)
format(op_name, self._que.empty())))
resp = self._que.get(timeout=0)
break
except Queue.Empty:
_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel empty: {})".
format(op_name, self.empty())))
format(op_name, self._que.empty())))
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
......@@ -366,7 +374,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
with self._cv:
# When the data required by the current Op is not in output_buf,
# fetch one item from the queue and append it to output_buf.
while self._stop is False and self._consumer_cursors[
while self._stop.value == 0 and self._consumer_cursors[
op_name] - self._base_cursor.value >= len(self._output_buf):
_LOGGER.debug(
self._log(
......@@ -376,22 +384,18 @@ class ProcessChannel(multiprocessing.queues.Queue):
try:
_LOGGER.debug(
self._log("{} try to get(with channel size: {})".format(
op_name, self.qsize())))
# For multiprocessing queues: after putting an object on
# an empty queue there may be an infinitesimal delay
# before the queue's :meth:`~Queue.empty` method returns False.
# See:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
channeldata = self.get(timeout=1e-3)
op_name, self._que.qsize())))
channeldata = self._que.get(timeout=0)
self._output_buf.append(channeldata)
break
except Queue.Empty:
_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel size: {})".
format(op_name, self.qsize())))
format(op_name, self._que.qsize())))
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
consumer_cursor = self._consumer_cursors[op_name]
base_cursor = self._base_cursor.value
......@@ -438,10 +442,10 @@ class ProcessChannel(multiprocessing.queues.Queue):
return resp # reference, read only
def stop(self):
#TODO
self.close()
self._stop = True
self._cv.notify_all()
_LOGGER.debug(self._log("stop."))
self._stop.value = 1
with self._cv:
self._cv.notify_all()
class ThreadChannel(Queue.Queue):
......@@ -540,6 +544,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Full:
self._cv.wait()
if self._stop:
raise ChannelStopError()
self._cv.notify_all()
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
return True
......@@ -578,6 +584,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Full:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
......@@ -600,6 +608,8 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
......@@ -630,12 +640,14 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
if self._stop:
raise ChannelStopError()
consumer_cursor = self._consumer_cursors[op_name]
base_cursor = self._base_cursor
data_idx = consumer_cursor - base_cursor
resp = self._output_buf[data_idx]
_LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
resp = None
self._cursor_count[consumer_cursor] -= 1
if consumer_cursor == base_cursor and self._cursor_count[
......@@ -643,7 +655,7 @@ class ThreadChannel(Queue.Queue):
# When all the different Ops get the data that data_idx points
# to, pop the data from output_buf.
self._cursor_count.pop(consumer_cursor)
self._output_buf.pop(0)
resp = self._output_buf.pop(0)
self._base_cursor += 1
# to avoid cursor overflow
if self._base_cursor >= self._reset_max_cursor:
......@@ -654,6 +666,9 @@ class ThreadChannel(Queue.Queue):
cursor - self._reset_max_cursor: count
for cursor, count in self._cursor_count.items()
}
else:
resp = copy.deepcopy(self._output_buf[data_idx])
_LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
self._consumer_cursors[op_name] += 1
new_consumer_cursor = self._consumer_cursors[op_name]
......@@ -664,11 +679,15 @@ class ThreadChannel(Queue.Queue):
self._cv.notify_all()
_LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
# return resp # reference, read only
return copy.deepcopy(resp)
return resp
def stop(self):
#TODO
self.close()
_LOGGER.debug(self._log("stop."))
self._stop = True
self._cv.notify_all()
with self._cv:
self._cv.notify_all()
class ChannelStopError(RuntimeError):
def __init__(self):
pass
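The stop protocol added in this file (set a stop flag, wake every waiter, and let blocked callers raise `ChannelStopError`) can be sketched in isolation. This is a simplified single-consumer illustration under stated assumptions: `TinyChannel`, `run_consumer`, and `demo` are hypothetical names, and the sketch uses Python 3's `queue` module, whereas the source imports it under the name `Queue` for Python 2/3 compatibility.

```python
import queue
import threading
import time


class ChannelStopError(RuntimeError):
    pass


class TinyChannel(object):
    def __init__(self):
        self._que = queue.Queue()
        self._cv = threading.Condition()
        self._stop = False

    def push(self, item):
        with self._cv:
            if self._stop:
                raise ChannelStopError()
            self._que.put(item)
            self._cv.notify_all()

    def front(self):
        with self._cv:
            while not self._stop:
                try:
                    return self._que.get(timeout=0)
                except queue.Empty:
                    # Releases the lock until push() or stop() notifies.
                    self._cv.wait()
            raise ChannelStopError()

    def stop(self):
        self._stop = True
        # notify_all() must be called while holding the condition's lock.
        with self._cv:
            self._cv.notify_all()


def run_consumer(channel, results):
    while True:
        try:
            results.append(channel.front())
        except ChannelStopError:
            results.append("stopped")
            break


def demo():
    channel, results = TinyChannel(), []
    worker = threading.Thread(target=run_consumer, args=(channel, results))
    worker.start()
    channel.push("a")
    channel.push("b")
    # Wait (bounded) until both items are consumed before stopping.
    deadline = time.time() + 5
    while len(results) < 2 and time.time() < deadline:
        time.sleep(0.01)
    channel.stop()
    worker.join(timeout=5)
    return results
```

Calling `notify_all()` inside `with self._cv:` matters: the standard library raises `RuntimeError` when a condition is notified without holding its lock, which is why both `stop()` methods in the diff wrap the call.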
This diff is collapsed.
......@@ -20,11 +20,15 @@ from concurrent import futures
import logging
import func_timeout
import os
import sys
import numpy as np
from numpy import *
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
ChannelData, ChannelDataType, ChannelStopError)
from .util import NameGenerator
from .profiler import TimeProfiler
_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")
......@@ -42,7 +46,6 @@ class Op(object):
retry=1):
if name is None:
name = _op_name_gen.next()
self._is_run = False
self.name = name # to identify the type of OP, it must be globally unique
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
......@@ -58,14 +61,17 @@ class Op(object):
self._retry = max(1, retry)
self._input = None
self._outputs = []
self._profiler = None
self._server_use_profile = False
# only for multithread
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
def init_profiler(self, profiler):
self._profiler = profiler
def use_profiler(self, use_profile):
self._server_use_profile = use_profile
def _profiler_record(self, string):
if self._profiler is None:
......@@ -90,9 +96,6 @@ class Op(object):
self._fetch_names = fetch_names
return client
def _get_input_channel(self):
return self._input
def get_input_ops(self):
return self._input_ops
......@@ -115,8 +118,11 @@ class Op(object):
channel.add_consumer(self.name)
self._input = channel
def _get_output_channels(self):
return self._outputs
def clean_input_channel(self):
self._input = None
def _get_input_channel(self):
return self._input
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
......@@ -126,6 +132,12 @@ class Op(object):
channel.add_producer(self.name)
self._outputs.append(channel)
def clean_output_channels(self):
self._outputs = []
def _get_output_channels(self):
return self._outputs
def preprocess(self, input_dicts):
# multiple previous Op
if len(input_dicts) != 1:
......@@ -136,12 +148,12 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, client_predict_handler, feed_dict):
def process(self, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
call_result = client_predict_handler(
call_result = self.client.predict(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
......@@ -149,29 +161,48 @@ class Op(object):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
self._is_run = False
def _parse_channeldata(self, channeldata_dict):
data_id, error_channeldata = None, None
client_need_profile, profile_set = False, set()
parsed_data = {}
key = list(channeldata_dict.keys())[0]
data_id = channeldata_dict[key].id
client_need_profile = channeldata_dict[key].client_need_profile
for name, data in channeldata_dict.items():
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
parsed_data[name] = data.parse()
return data_id, error_channeldata, parsed_data
def _push_to_output_channels(self, data, channels, name=None):
if client_need_profile:
profile_set |= data.profile_data_set
return (data_id, error_channeldata, parsed_data, client_need_profile,
profile_set)
def _push_to_output_channels(self,
data,
channels,
name=None,
client_need_profile=False,
profile_set=None):
if name is None:
name = self.name
self._add_profile_into_channeldata(data, client_need_profile,
profile_set)
for channel in channels:
channel.push(data, name)
def _add_profile_into_channeldata(self, data, client_need_profile,
profile_set):
profile_str = self._profiler.gen_profile_str()
if self._server_use_profile:
sys.stderr.write(profile_str)
if client_need_profile and profile_set is not None:
profile_set.add(profile_str)
data.add_profile(profile_set)
def start_with_process(self, client_type):
proces = []
for concurrency_idx in range(self.concurrency):
......@@ -226,15 +257,13 @@ class Op(object):
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, client_predict_handler, preped_data, data_id,
log_func):
def _run_process(self, preped_data, data_id, log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(client_predict_handler,
preped_data)
midped_data = self.process(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -243,11 +272,7 @@ class Op(object):
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout,
self.process,
args=(
client_predict_handler,
preped_data, ))
self._timeout, self.process, args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -314,7 +339,7 @@ class Op(object):
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
is_thread_op):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......@@ -325,80 +350,130 @@ class Op(object):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
client = None
client_predict_handler = None
# create client based on client_type
# init op
self.concurrency_idx = concurrency_idx
try:
client = self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
if client is not None:
client_predict_handler = client.predict
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
# load user resources
try:
if use_multithread:
if is_thread_op:
with self._for_init_op_lock:
if not self._succ_init_op:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op()
self._succ_init_op = True
self._succ_close_op = False
else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(True)
# init client
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
while True:
#self._profiler_record("get#{}_0".format(op_info_prefix))
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.debug(log("stop."))
if is_thread_op:
with self._for_close_op_lock:
if not self._succ_close_op:
self._profiler = None
self.client = None
self._succ_init_op = False
self._succ_close_op = True
break
#self._profiler_record("get#{}_1".format(op_info_prefix))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
(data_id, error_channeldata, parsed_data, client_need_profile,
profile_set) = self._parse_channeldata(channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
continue
# preprocess
self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid))
self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid))
self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
continue
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
continue
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
self._profiler_record("postp#{}_0".format(op_info_prefix))
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
try:
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
continue
# push data to channel (if run succ)
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
self._push_to_output_channels(output_data, output_channels)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
#self._profiler_record("push#{}_0".format(op_info_prefix))
try:
self._push_to_output_channels(
output_data,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
#self._profiler_record("push#{}_1".format(op_info_prefix))
def _log(self, info):
return "{} {}".format(self.name, info)
......@@ -407,11 +482,10 @@ class Op(object):
class RequestOp(Op):
""" RequestOp does not run preprocess, process, or postprocess. """
def __init__(self, concurrency=1):
# PipelineService.name = "#G"
super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
# load user resources
def __init__(self):
# PipelineService.name = "@G"
super(RequestOp, self).__init__(name="@G", input_ops=[])
# init op
try:
self.init_op()
except Exception as e:
......@@ -433,10 +507,9 @@ class RequestOp(Op):
class ResponseOp(Op):
""" ResponseOp does not run preprocess, process, or postprocess. """
def __init__(self, input_ops, concurrency=1):
super(ResponseOp, self).__init__(
name="#R", input_ops=input_ops, concurrency=concurrency)
# load user resources
def __init__(self, input_ops):
super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
# init op
try:
self.init_op()
except Exception as e:
......@@ -451,6 +524,7 @@ class ResponseOp(Op):
feed = channeldata.parse()
# ndarray to string:
# https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
# print arrays in full; newer NumPy releases reject threshold=np.nan
np.set_printoptions(threshold=sys.maxsize)
for name, var in feed.items():
resp.value.append(var.__repr__())
resp.key.append(name)
......@@ -505,7 +579,7 @@ class VirtualOp(Op):
self._outputs.append(channel)
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
is_thread_op):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......@@ -516,14 +590,17 @@ class VirtualOp(Op):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
while True:
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
try:
for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
......@@ -13,6 +13,7 @@
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import sys
import numpy as np
from numpy import *
import logging
......@@ -26,13 +27,19 @@ _LOGGER = logging.getLogger()
class PipelineClient(object):
def __init__(self):
self._channel = None
self._profile_key = "pipeline.profile"
self._profile_value = "1"
def connect(self, endpoint):
self._channel = grpc.insecure_channel(endpoint)
def connect(self, endpoints):
options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
('grpc.max_send_message_length', 512 * 1024 * 1024),
('grpc.lb_policy_name', 'round_robin')]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
self._channel = grpc.insecure_channel(g_endpoint, options=options)
self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
self._channel)
def _pack_request_package(self, feed_dict):
def _pack_request_package(self, feed_dict, profile):
req = pipeline_service_pb2.Request()
for key, value in feed_dict.items():
req.key.append(key)
......@@ -45,6 +52,9 @@ class PipelineClient(object):
else:
raise TypeError("only str and np.ndarray types are supported: {}".
format(type(value)))
if profile:
req.key.append(self._profile_key)
req.value.append(self._profile_value)
return req
def _unpack_response_package(self, resp, fetch):
......@@ -52,6 +62,10 @@ class PipelineClient(object):
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key == self._profile_key:
if resp.value[idx] != "":
sys.stderr.write(resp.value[idx])
continue
if fetch is not None and key not in fetch:
continue
data = resp.value[idx]
......@@ -62,13 +76,13 @@ class PipelineClient(object):
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch=None, asyn=False):
def predict(self, feed_dict, fetch=None, asyn=False, profile=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
req = self._pack_request_package(feed_dict, profile)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp, fetch)
......
This diff is collapsed.
......@@ -23,6 +23,7 @@ elif sys.version_info.major == 3:
 else:
     raise Exception("Error Python version")
 import time
+import threading
 _LOGGER = logging.getLogger()
......@@ -33,6 +34,7 @@ class TimeProfiler(object):
         self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
         self._time_record = Queue.Queue()
         self._enable = False
+        self._lock = threading.Lock()
     def enable(self, enable):
         self._enable = enable
......@@ -40,26 +42,34 @@ class TimeProfiler(object):
     def record(self, name_with_tag):
         if self._enable is False:
             return
+        timestamp = int(round(time.time() * 1000000))
         name_with_tag = name_with_tag.split("_")
         tag = name_with_tag[-1]
         name = '_'.join(name_with_tag[:-1])
-        self._time_record.put((name, tag, int(round(time.time() * 1000000))))
+        with self._lock:
+            self._time_record.put((name, tag, timestamp))
     def print_profile(self):
         if self._enable is False:
             return
+        sys.stderr.write(self.gen_profile_str())
+    def gen_profile_str(self):
+        if self._enable is False:
+            return
         print_str = self._print_head
         tmp = {}
-        while not self._time_record.empty():
-            name, tag, timestamp = self._time_record.get()
-            if name in tmp:
-                ptag, ptimestamp = tmp.pop(name)
-                print_str += "{}_{}:{} ".format(name, ptag, ptimestamp)
-                print_str += "{}_{}:{} ".format(name, tag, timestamp)
-            else:
-                tmp[name] = (tag, timestamp)
-        print_str = "\n{}\n".format(print_str)
-        sys.stderr.write(print_str)
-        for name, item in tmp.items():
-            tag, timestamp = item
-            self._time_record.put((name, tag, timestamp))
+        with self._lock:
+            while not self._time_record.empty():
+                name, tag, timestamp = self._time_record.get()
+                if name in tmp:
+                    ptag, ptimestamp = tmp.pop(name)
+                    print_str += "{}_{}:{} ".format(name, ptag, ptimestamp)
+                    print_str += "{}_{}:{} ".format(name, tag, timestamp)
+                else:
+                    tmp[name] = (tag, timestamp)
+            print_str = "\n{}\n".format(print_str)
+            for name, item in tmp.items():
+                tag, timestamp = item
+                self._time_record.put((name, tag, timestamp))
+        return print_str
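Review note: the locking pattern in the `TimeProfiler` hunk can be tried out in isolation with the sketch below. `MiniProfiler` is a hypothetical stand-in written for Python 3 (`queue` instead of `Queue`), not the Paddle Serving class. It only mirrors what the diff shows: the timestamp is taken before the lock is acquired, queue access is serialized, and records without a matching partner are put back for the next report.

```python
import queue
import threading
import time


class MiniProfiler:
    """Illustrative stand-in for TimeProfiler (hypothetical, not the real class)."""

    def __init__(self):
        self._records = queue.Queue()
        self._lock = threading.Lock()

    def record(self, name_with_tag):
        # Read the clock before waiting on the lock so that lock
        # contention does not skew the recorded time.
        timestamp = int(round(time.time() * 1000000))
        # "opA_0" -> name "opA", tag "0" (split on the last underscore).
        name, _, tag = name_with_tag.rpartition("_")
        with self._lock:
            self._records.put((name, tag, timestamp))

    def gen_profile_str(self):
        parts = []
        pending = {}
        with self._lock:
            while not self._records.empty():
                name, tag, timestamp = self._records.get()
                if name in pending:
                    ptag, ptimestamp = pending.pop(name)
                    parts.append("{}_{}:{}".format(name, ptag, ptimestamp))
                    parts.append("{}_{}:{}".format(name, tag, timestamp))
                else:
                    pending[name] = (tag, timestamp)
            # Unpaired records go back on the queue for the next report.
            for name, (tag, timestamp) in pending.items():
                self._records.put((name, tag, timestamp))
        return " ".join(parts)


profiler = MiniProfiler()
workers = [
    threading.Thread(target=profiler.record, args=("op{}_0".format(i),))
    for i in range(4)
]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
for i in range(4):
    profiler.record("op{}_1".format(i))
report = profiler.gen_profile_str()
```

Holding one lock around the whole drain-and-requeue step means a concurrent `record` call cannot slip a new entry between the `empty()` check and the requeue of unpaired records.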
numpy>=1.12, <=1.16.4 ; python_version<"3.5"
protobuf>=3.12.2
grpcio-tools>=1.28.1
grpcio>=1.28.1
func-timeout>=4.3.5
pyyaml>=1.3.0
......@@ -48,3 +48,5 @@ RUN yum -y update >/dev/null \
RUN yum install -y java \
&& wget http://repos.fedorapeople.org/repos/dchen/apache-maven/epel-apache-maven.repo -O /etc/yum.repos.d/epel-apache-maven.repo \
&& yum install -y apache-maven
RUN yum install -y lsof
......@@ -137,6 +137,15 @@ function kill_server_process() {
sleep 1
}
function kill_process_by_port() {
if [ $# != 1 ]; then
echo "usage: kill_process_by_port <PORT>"
exit 1
fi
local PORT=$1
# skip the lsof header row, then kill every PID that has a socket on the port
lsof -i:$PORT | awk 'NR == 1 {next} {print $2}' | xargs kill
}
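Review note: the `awk` stage in `kill_process_by_port` skips the `lsof` header row and prints the second column. The Python sketch below does the same parsing on a made-up sample. The sample text is invented for illustration and is not captured output, and `pids_from_lsof` is a hypothetical helper that is not part of the CI script.

```python
def pids_from_lsof(output):
    """Return the unique PIDs found in `lsof -i:<port>` style output."""
    pids = []
    for line_no, line in enumerate(output.splitlines()):
        fields = line.split()
        if line_no == 0 or len(fields) < 2:
            continue  # skip the header row and blank lines
        pid = int(fields[1])  # second column is the PID
        if pid not in pids:
            pids.append(pid)
    return pids


# Invented sample: one server listening on 9393 and one client connected to it.
sample = """COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME
python 4242 root 5u IPv4 12345 0t0 TCP *:9393 (LISTEN)
python 4242 root 6u IPv4 12346 0t0 TCP localhost:9393->localhost:51000 (ESTABLISHED)
python 4250 root 5u IPv4 12347 0t0 TCP localhost:51000->localhost:9393 (ESTABLISHED)
"""

pids = pids_from_lsof(sample)
```

In this sample the client process (4250) appears as well, because `lsof -i:<port>` matches a socket whose local or remote port equals the given number. A shell pipeline like the one above would therefore signal connected clients too, which is usually acceptable in a CI teardown but worth knowing.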
function python_test_fit_a_line() {
# pwd: /Serving/python/examples
cd fit_a_line # pwd: /Serving/python/examples/fit_a_line
......@@ -579,6 +588,7 @@ function python_test_grpc_impl() {
check_cmd "python test_batch_client.py > /dev/null"
check_cmd "python test_timeout_client.py > /dev/null"
kill_server_process
kill_process_by_port 9393
check_cmd "python test_server.py uci_housing_model > /dev/null &"
sleep 5 # wait for the server to start
......@@ -589,6 +599,7 @@ function python_test_grpc_impl() {
check_cmd "python test_batch_client.py > /dev/null"
check_cmd "python test_timeout_client.py > /dev/null"
kill_server_process
kill_process_by_port 9393
cd .. # pwd: /Serving/python/examples/grpc_impl_example
......@@ -637,6 +648,7 @@ function python_test_grpc_impl() {
check_cmd "python test_batch_client.py > /dev/null"
check_cmd "python test_timeout_client.py > /dev/null"
kill_server_process
kill_process_by_port 9393
check_cmd "python test_server_gpu.py uci_housing_model > /dev/null &"
sleep 5 # wait for the server to start
......@@ -647,7 +659,8 @@ function python_test_grpc_impl() {
check_cmd "python test_batch_client.py > /dev/null"
check_cmd "python test_timeout_client.py > /dev/null"
kill_server_process
ps -ef | grep "test_server_gpu" | grep -v serving_build | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 9393
#ps -ef | grep "test_server_gpu" | grep -v serving_build | grep -v grep | awk '{print $2}' | xargs kill
cd .. # pwd: /Serving/python/examples/grpc_impl_example
......@@ -749,6 +762,128 @@ function python_test_resnet50(){
cd ..
}
function python_test_pipeline(){
# pwd: /Serving/python/examples
local TYPE=$1
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
unsetproxy
cd pipeline/imdb_model_ensemble
case $TYPE in
CPU)
# start paddle serving service (brpc)
sh get_data.sh
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 --workdir test9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 --workdir test9393 &> bow.log &
sleep 5
# test: thread servicer & thread op
cat << EOF > config.yml
port: 18080
worker_num: 2
build_dag_each_worker: false
dag:
is_thread_op: true
client_type: brpc
retry: 1
use_profile: false
EOF
python test_pipeline_server.py > /dev/null &
sleep 5
check_cmd "python test_pipeline_client.py"
ps -ef | grep "pipeline_server" | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 18080
# test: thread servicer & process op
cat << EOF > config.yml
port: 18080
worker_num: 2
build_dag_each_worker: false
dag:
is_thread_op: false
client_type: brpc
retry: 1
use_profile: false
EOF
python test_pipeline_server.py > /dev/null &
sleep 5
check_cmd "python test_pipeline_client.py"
ps -ef | grep "pipeline_server" | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 18080
# test: process servicer & thread op
cat << EOF > config.yml
port: 18080
worker_num: 2
build_dag_each_worker: true
dag:
is_thread_op: true
client_type: brpc
retry: 1
use_profile: false
EOF
python test_pipeline_server.py > /dev/null &
sleep 5
check_cmd "python test_pipeline_client.py"
ps -ef | grep "pipeline_server" | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 18080
# test: process servicer & process op
cat << EOF > config.yml
port: 18080
worker_num: 2
build_dag_each_worker: true
dag:
is_thread_op: false
client_type: brpc
retry: 1
use_profile: false
EOF
python test_pipeline_server.py > /dev/null &
sleep 5
check_cmd "python test_pipeline_client.py"
ps -ef | grep "pipeline_server" | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 18080
kill_server_process
kill_process_by_port 9292
kill_process_by_port 9393
# start paddle serving service (grpc)
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 --use_multilang --workdir test9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 --use_multilang --workdir test9393 &> bow.log &
sleep 5
cat << EOF > config.yml
port: 18080
worker_num: 2
build_dag_each_worker: false
dag:
is_thread_op: false
client_type: grpc
retry: 1
use_profile: false
EOF
python test_pipeline_server.py > /dev/null &
sleep 5
check_cmd "python test_pipeline_client.py"
ps -ef | grep "pipeline_server" | grep -v grep | awk '{print $2}' | xargs kill
kill_process_by_port 18080
kill_server_process
kill_process_by_port 9292
kill_process_by_port 9393
;;
GPU)
echo "pipeline ignore GPU test"
;;
*)
echo "error type"
exit 1
;;
esac
cd ../../
setproxy
unset SERVING_BIN
}
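Review note: the brpc cases in `python_test_pipeline` are labelled as the combinations of two switches, thread or process servicer (`build_dag_each_worker`) and thread or process OP (`is_thread_op`). A minimal Python sketch of generating that matrix follows. `render_config` and `TEMPLATE` are hypothetical names, and the nesting of `is_thread_op`, `client_type`, `retry`, and `use_profile` under `dag` is an assumption, since indentation is not preserved in this view.

```python
from itertools import product

# Assumed layout: the last four keys are nested under `dag`.
TEMPLATE = """port: {port}
worker_num: 2
build_dag_each_worker: {process_servicer}
dag:
    is_thread_op: {thread_op}
    client_type: {client_type}
    retry: 1
    use_profile: false
"""


def yaml_bool(value):
    """Spell a Python bool the way YAML expects it."""
    return "true" if value else "false"


def render_config(process_servicer, thread_op, client_type="brpc", port=18080):
    """Render one config.yml body for a servicer/OP combination."""
    return TEMPLATE.format(
        port=port,
        process_servicer=yaml_bool(process_servicer),
        thread_op=yaml_bool(thread_op),
        client_type=client_type,
    )


# Keys are (process_servicer, thread_op) pairs.
configs = {
    (process_servicer, thread_op): render_config(process_servicer, thread_op)
    for process_servicer, thread_op in product([False, True], repeat=2)
}
```

Generating the values from booleans avoids hand-typed spellings in each heredoc and makes it easy to confirm that every test label maps to a distinct configuration.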
function python_app_api_test(){
# pwd: /Serving/python/examples
#test image reader
......@@ -784,6 +919,7 @@ function python_run_test() {
python_test_yolov4 $TYPE # pwd: /Serving/python/examples
python_test_grpc_impl $TYPE # pwd: /Serving/python/examples
python_test_resnet50 $TYPE # pwd: /Serving/python/examples
python_test_pipeline $TYPE # pwd: /Serving/python/examples
echo "test python $TYPE part finished as expected."
cd ../.. # pwd: /Serving
}