提交 80bdc849 编写于 作者: B barrierye

Merge branch 'develop' of https://github.com/PaddlePaddle/Serving into develop

...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
syntax = "proto2"; syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
message Tensor { message Tensor {
optional bytes data = 1; optional bytes data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
......
# Paddle Serving Client Java SDK
([简体中文](JAVA_SDK_CN.md)|English)
Paddle Serving provides Java SDK,which supports predict on the Client side with Java language. This document shows how to use the Java SDK.
## Getting started
### Prerequisites
```
- Java 8 or higher
- Apache Maven
```
The following table shows compatibilities between Paddle Serving Server and Java SDK.
| Paddle Serving Server version | Java SDK version |
| :---------------------------: | :--------------: |
| 0.3.2 | 0.0.1 |
### Install Java SDK
You can download jar and install it to the local Maven repository:
```shell
wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar
mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar
```
Or compile from the source code and install it to the local Maven repository:
```shell
cd Serving/java
mvn compile
mvn install
```
### Maven configure
```text
<dependency>
<groupId>io.paddle.serving.client</groupId>
<artifactId>paddle-serving-sdk-java</artifactId>
<version>0.0.1</version>
</dependency>
```
## Example
Here we will show how to use Java SDK for Boston house price prediction. Please refer to [examples](../java/examples) folder for more examples.
### Get model
```shell
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
tar -xzf uci_housing.tar.gz
```
### Start Python Server
```shell
python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang
```
#### Client side code example
```java
import io.paddle.serving.client.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
public class PaddleServingClientExample {
public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return ;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
System.out.println("predict failed.");
return ;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return ;
}
}
```
# Paddle Serving Client Java SDK
(简体中文|[English](JAVA_SDK.md))
Paddle Serving 提供了 Java SDK,支持 Client 端用 Java 语言进行预测,本文档说明了如何使用 Java SDK。
## 快速开始
### 环境要求
```
- Java 8 or higher
- Apache Maven
```
下表显示了 Paddle Serving Server 和 Java SDK 之间的兼容性
| Paddle Serving Server version | Java SDK version |
| :---------------------------: | :--------------: |
| 0.3.2 | 0.0.1 |
### 安装
您可以直接下载 jar,安装到本地 Maven 库:
```shell
wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar
mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar
```
或者从源码进行编译,安装到本地 Maven 库:
```shell
cd Serving/java
mvn compile
mvn install
```
### Maven 配置
```text
<dependency>
<groupId>io.paddle.serving.client</groupId>
<artifactId>paddle-serving-sdk-java</artifactId>
<version>0.0.1</version>
</dependency>
```
## 使用样例
这里将展示如何使用 Java SDK 进行房价预测,更多例子详见 [examples](../java/examples) 文件夹。
### 获取房价预测模型
```shell
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
tar -xzf uci_housing.tar.gz
```
### 启动 Python 端 Server
```shell
python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang
```
### Client 端代码示例
```java
import io.paddle.serving.client.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
public class PaddleServingClientExample {
public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return ;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
System.out.println("predict failed.");
return ;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return ;
}
}
```
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>io.paddle.serving.client</groupId>
<artifactId>paddle-serving-sdk-java-examples</artifactId>
<version>0.0.1</version>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
<version>3.8.1</version>
</plugin>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<mainClass>my.fully.qualified.class.Main</mainClass>
</manifest>
</archive>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-my-jar-with-dependencies</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<nd4j.backend>nd4j-native</nd4j.backend>
<nd4j.version>1.0.0-beta7</nd4j.version>
<datavec.version>1.0.0-beta7</datavec.version>
<paddle.serving.client.version>0.0.1</paddle.serving.client.version>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>io.paddle.serving.client</groupId>
<artifactId>paddle-serving-sdk-java</artifactId>
<version>${paddle.serving.client.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.30</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${datavec.version}</version>
</dependency>
</dependencies>
</project>
import io.paddle.serving.client.*;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
public class PaddleServingClientExample {
boolean fit_a_line() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean yolov4(String filename) {
// https://deeplearning4j.konduit.ai/
int height = 608;
int width = 608;
int channels = 3;
NativeImageLoader loader = new NativeImageLoader(height, width, channels);
INDArray BGRimage = null;
try {
BGRimage = loader.asMatrix(new File(filename));
} catch (java.io.IOException e) {
System.out.println("load image fail.");
return false;
}
// shape: (channels, height, width)
BGRimage = BGRimage.reshape(channels, height, width);
INDArray RGBimage = Nd4j.create(BGRimage.shape());
// BGR2RGB
CustomOp op = DynamicCustomOp.builder("reverse")
.addInputs(BGRimage)
.addOutputs(RGBimage)
.addIntegerArguments(0)
.build();
Nd4j.getExecutioner().exec(op);
// Div(255.0)
INDArray image = RGBimage.divi(255.0);
INDArray im_size = Nd4j.createFromArray(new int[]{height, width});
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("image", image);
put("im_size", im_size);
}};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
succ = client.setRpcTimeoutMs(20000); // cpu
if (succ != true) {
System.out.println("set timeout failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean batch_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>() {{
add(feed_data);
add(feed_data);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_batch, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean asyn_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
PredictFuture future = client.asyn_predict(feed_data, fetch);
Map<String, INDArray> fetch_map = future.get();
if (fetch_map == null) {
System.out.println("Get future reslut failed");
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean model_ensemble() {
long[] data = {8, 233, 52, 601};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("words", npdata);
}};
List<String> fetch = Arrays.asList("prediction");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, HashMap<String, INDArray>> fetch_map
= client.ensemble_predict(feed_data, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
return true;
}
boolean bert() {
float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] segment_ids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("input_mask", Nd4j.createFromArray(input_mask));
put("position_ids", Nd4j.createFromArray(position_ids));
put("input_ids", Nd4j.createFromArray(input_ids));
put("segment_ids", Nd4j.createFromArray(segment_ids));
}};
List<String> fetch = Arrays.asList("pooled_output");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean cube_local() {
long[] embedding_14 = {250644};
long[] embedding_2 = {890346};
long[] embedding_10 = {3939};
long[] embedding_17 = {421122};
long[] embedding_23 = {664215};
long[] embedding_6 = {704846};
float[] dense_input = {0.0f, 0.006633499170812604f, 0.03f, 0.0f,
0.145078125f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
long[] embedding_24 = {269955};
long[] embedding_12 = {295309};
long[] embedding_7 = {437731};
long[] embedding_3 = {990128};
long[] embedding_1 = {7753};
long[] embedding_4 = {286835};
long[] embedding_8 = {27346};
long[] embedding_9 = {636474};
long[] embedding_18 = {880474};
long[] embedding_16 = {681378};
long[] embedding_22 = {410878};
long[] embedding_13 = {255651};
long[] embedding_5 = {25207};
long[] embedding_11 = {10891};
long[] embedding_20 = {238459};
long[] embedding_21 = {26235};
long[] embedding_15 = {691460};
long[] embedding_25 = {544187};
long[] embedding_19 = {537425};
long[] embedding_0 = {737395};
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("embedding_14.tmp_0", Nd4j.createFromArray(embedding_14));
put("embedding_2.tmp_0", Nd4j.createFromArray(embedding_2));
put("embedding_10.tmp_0", Nd4j.createFromArray(embedding_10));
put("embedding_17.tmp_0", Nd4j.createFromArray(embedding_17));
put("embedding_23.tmp_0", Nd4j.createFromArray(embedding_23));
put("embedding_6.tmp_0", Nd4j.createFromArray(embedding_6));
put("dense_input", Nd4j.createFromArray(dense_input));
put("embedding_24.tmp_0", Nd4j.createFromArray(embedding_24));
put("embedding_12.tmp_0", Nd4j.createFromArray(embedding_12));
put("embedding_7.tmp_0", Nd4j.createFromArray(embedding_7));
put("embedding_3.tmp_0", Nd4j.createFromArray(embedding_3));
put("embedding_1.tmp_0", Nd4j.createFromArray(embedding_1));
put("embedding_4.tmp_0", Nd4j.createFromArray(embedding_4));
put("embedding_8.tmp_0", Nd4j.createFromArray(embedding_8));
put("embedding_9.tmp_0", Nd4j.createFromArray(embedding_9));
put("embedding_18.tmp_0", Nd4j.createFromArray(embedding_18));
put("embedding_16.tmp_0", Nd4j.createFromArray(embedding_16));
put("embedding_22.tmp_0", Nd4j.createFromArray(embedding_22));
put("embedding_13.tmp_0", Nd4j.createFromArray(embedding_13));
put("embedding_5.tmp_0", Nd4j.createFromArray(embedding_5));
put("embedding_11.tmp_0", Nd4j.createFromArray(embedding_11));
put("embedding_20.tmp_0", Nd4j.createFromArray(embedding_20));
put("embedding_21.tmp_0", Nd4j.createFromArray(embedding_21));
put("embedding_15.tmp_0", Nd4j.createFromArray(embedding_15));
put("embedding_25.tmp_0", Nd4j.createFromArray(embedding_25));
put("embedding_19.tmp_0", Nd4j.createFromArray(embedding_19));
put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
}};
List<String> fetch = Arrays.asList("prob");
Client client = new Client();
String target = "localhost:9393";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
if (fetch_map == null) {
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
public static void main( String[] args ) {
// DL4J(Deep Learning for Java)Document:
// https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md
PaddleServingClientExample e = new PaddleServingClientExample();
boolean succ = false;
if (args.length < 1) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type>.");
System.out.println("<test-type>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4");
return;
}
String testType = args[0];
System.out.format("[Example] %s\n", testType);
if ("fit_a_line".equals(testType)) {
succ = e.fit_a_line();
} else if ("bert".equals(testType)) {
succ = e.bert();
} else if ("model_ensemble".equals(testType)) {
succ = e.model_ensemble();
} else if ("asyn_predict".equals(testType)) {
succ = e.asyn_predict();
} else if ("batch_predict".equals(testType)) {
succ = e.batch_predict();
} else if ("cube_local".equals(testType)) {
succ = e.cube_local();
} else if ("cube_quant".equals(testType)) {
succ = e.cube_local();
} else if ("yolov4".equals(testType)) {
if (args.length < 2) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample yolov4 <image-filepath>.");
return;
}
succ = e.yolov4(args[1]);
} else {
System.out.format("test-type(%s) not match.\n", testType);
return;
}
if (succ == true) {
System.out.println("[Example] succ.");
} else {
System.out.println("[Example] fail.");
}
}
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>io.paddle.serving.client</groupId>
<artifactId>paddle-serving-sdk-java</artifactId>
<version>0.0.1</version>
<packaging>jar</packaging>
<name>paddle-serving-sdk-java</name>
<description>Java SDK for Paddle Sering Client.</description>
<url>https://github.com/PaddlePaddle/Serving</url>
<licenses>
<license>
<name>Apache License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
</license>
</licenses>
<developers>
<developer>
<name>PaddlePaddle Author</name>
<email>guru4elephant@gmail.com</email>
<organization>PaddlePaddle</organization>
<organizationUrl>https://github.com/PaddlePaddle/Serving</organizationUrl>
</developer>
</developers>
<scm>
<connection>scm:git:https://github.com/PaddlePaddle/Serving.git</connection>
<developerConnection>scm:git:https://github.com/PaddlePaddle/Serving.git</developerConnection>
<url>https://github.com/PaddlePaddle/Serving</url>
</scm>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<grpc.version>1.27.2</grpc.version>
<protobuf.version>3.11.0</protobuf.version>
<protoc.version>3.11.0</protoc.version>
<nd4j.backend>nd4j-native</nd4j.backend>
<nd4j.version>1.0.0-beta7</nd4j.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-bom</artifactId>
<version>${grpc.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
</dependency>
<dependency>
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
<version>1.2</version>
<scope>provided</scope> <!-- not needed at runtime -->
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-testing</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>${protobuf.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.google.errorprone</groupId>
<artifactId>error_prone_annotations</artifactId>
<version>2.3.4</version> <!-- prefer to use 2.3.3 or later -->
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.5.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-text</artifactId>
<version>1.6</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.4</version>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20190722</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.30</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.12.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.1.1</version>
<configuration>
<javadocExecutable>${java.home}/bin/javadoc</javadocExecutable>
</configuration>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.6.2</version>
</extension>
</extensions>
<plugins>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<version>1.6.8</version>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-release-plugin</artifactId>
<version>2.5.3</version>
<configuration>
<autoVersionSubmodules>true</autoVersionSubmodules>
<useReleaseProfile>false</useReleaseProfile>
<releaseProfiles>release</releaseProfiles>
<goals>deploy</goals>
</configuration>
</plugin>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.6.1</version>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}
</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}
</pluginArtifact>
</configuration>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>compile-custom</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>3.0.0-M2</version>
<executions>
<execution>
<id>enforce</id>
<configuration>
<rules>
<requireUpperBoundDeps/>
</rules>
</configuration>
<goals>
<goal>enforce</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
package io.paddle.serving.client;
import java.util.*;
import java.util.function.Function;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import com.google.protobuf.ByteString;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*;
import io.paddle.serving.client.PredictFuture;
class Profiler {
int pid_;
String print_head_ = null;
List<String> time_record_ = null;
boolean enable_ = false;
Profiler() {
RuntimeMXBean runtimeMXBean = ManagementFactory.getRuntimeMXBean();
pid_ = Integer.valueOf(runtimeMXBean.getName().split("@")[0]).intValue();
print_head_ = "\nPROFILE\tpid:" + pid_ + "\t";
time_record_ = new ArrayList<String>();
time_record_.add(print_head_);
}
void record(String name) {
if (enable_) {
long ctime = System.currentTimeMillis() * 1000;
time_record_.add(name + ":" + String.valueOf(ctime) + " ");
}
}
void printProfile() {
if (enable_) {
String profile_str = String.join("", time_record_);
time_record_ = new ArrayList<String>();
time_record_.add(print_head_);
}
}
void enable(boolean flag) {
enable_ = flag;
}
}
public class Client {
private ManagedChannel channel_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceFutureStub futureStub_;
private double rpcTimeoutS_;
private List<String> feedNames_;
private Map<String, Integer> feedTypes_;
private Map<String, List<Integer>> feedShapes_;
private List<String> fetchNames_;
private Map<String, Integer> fetchTypes_;
private Set<String> lodTensorSet_;
private Map<String, Integer> feedTensorLen_;
private Profiler profiler_;
public Client() {
channel_ = null;
blockingStub_ = null;
futureStub_ = null;
rpcTimeoutS_ = 2;
feedNames_ = null;
feedTypes_ = null;
feedShapes_ = null;
fetchNames_ = null;
fetchTypes_ = null;
lodTensorSet_ = null;
feedTensorLen_ = null;
profiler_ = new Profiler();
boolean is_profile = false;
String FLAGS_profile_client = System.getenv("FLAGS_profile_client");
if (FLAGS_profile_client != null && FLAGS_profile_client.equals("1")) {
is_profile = true;
}
profiler_.enable(is_profile);
}
public boolean setRpcTimeoutMs(int rpc_timeout) {
if (futureStub_ == null || blockingStub_ == null) {
System.out.println("set timeout must be set after connect.");
return false;
}
rpcTimeoutS_ = rpc_timeout / 1000.0;
SetTimeoutRequest timeout_req = SetTimeoutRequest.newBuilder()
.setTimeoutMs(rpc_timeout)
.build();
SimpleResponse resp;
try {
resp = blockingStub_.setTimeout(timeout_req);
} catch (StatusRuntimeException e) {
System.out.format("Set RPC timeout failed: %s\n", e.toString());
return false;
}
return resp.getErrCode() == 0;
}
public boolean connect(String target) {
// TODO: target must be NameResolver-compliant URI
// https://grpc.github.io/grpc-java/javadoc/io/grpc/ManagedChannelBuilder.html
try {
channel_ = ManagedChannelBuilder.forTarget(target)
.defaultLoadBalancingPolicy("round_robin")
.maxInboundMessageSize(Integer.MAX_VALUE)
.usePlaintext()
.build();
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_);
} catch (Exception e) {
System.out.format("Connect failed: %s\n", e.toString());
return false;
}
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
GetClientConfigResponse resp;
try {
resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (Exception e) {
System.out.format("Get Client config failed: %s\n", e.toString());
return false;
}
String model_config_str = resp.getClientConfigStr();
_parseModelConfig(model_config_str);
return true;
}
private void _parseModelConfig(String model_config_str) {
GeneralModelConfig.Builder model_conf_builder = GeneralModelConfig.newBuilder();
try {
com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder);
} catch (com.google.protobuf.TextFormat.ParseException e) {
System.out.format("Parse client config failed: %s\n", e.toString());
}
GeneralModelConfig model_conf = model_conf_builder.build();
feedNames_ = new ArrayList<String>();
fetchNames_ = new ArrayList<String>();
feedTypes_ = new HashMap<String, Integer>();
feedShapes_ = new HashMap<String, List<Integer>>();
fetchTypes_ = new HashMap<String, Integer>();
lodTensorSet_ = new HashSet<String>();
feedTensorLen_ = new HashMap<String, Integer>();
List<FeedVar> feed_var_list = model_conf.getFeedVarList();
for (FeedVar feed_var : feed_var_list) {
feedNames_.add(feed_var.getAliasName());
}
List<FetchVar> fetch_var_list = model_conf.getFetchVarList();
for (FetchVar fetch_var : fetch_var_list) {
fetchNames_.add(fetch_var.getAliasName());
}
for (int i = 0; i < feed_var_list.size(); ++i) {
FeedVar feed_var = feed_var_list.get(i);
String var_name = feed_var.getAliasName();
feedTypes_.put(var_name, feed_var.getFeedType());
feedShapes_.put(var_name, feed_var.getShapeList());
if (feed_var.getIsLodTensor()) {
lodTensorSet_.add(var_name);
} else {
int counter = 1;
for (int dim : feedShapes_.get(var_name)) {
counter *= dim;
}
feedTensorLen_.put(var_name, counter);
}
}
for (int i = 0; i < fetch_var_list.size(); i++) {
FetchVar fetch_var = fetch_var_list.get(i);
String var_name = fetch_var.getAliasName();
fetchTypes_.put(var_name, fetch_var.getFetchType());
if (fetch_var.getIsLodTensor()) {
lodTensorSet_.add(var_name);
}
}
}
private InferenceRequest _packInferenceRequest(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet());
InferenceRequest.Builder req_builder = InferenceRequest.newBuilder()
.addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch)
.setIsPython(false);
for (HashMap<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) {
Tensor.Builder tensor_builder = Tensor.newBuilder();
INDArray variable = feed_data.get(name);
long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape);
int v_type = feedTypes_.get(name);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
if (v_type == 0) { // int64
while (iter.hasNext()) {
long[] next_index = iter.next();
long x = flattened_list.getLong(next_index);
tensor_builder.addInt64Data(x);
}
} else if (v_type == 1) { // float32
while (iter.hasNext()) {
long[] next_index = iter.next();
float x = flattened_list.getFloat(next_index);
tensor_builder.addFloatData(x);
}
} else if (v_type == 2) { // int32
while (iter.hasNext()) {
long[] next_index = iter.next();
// the interface of INDArray is strange:
// https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html
int[] int_next_index = new int[next_index.length];
for(int i = 0; i < next_index.length; i++) {
int_next_index[i] = (int)next_index[i];
}
int x = flattened_list.getInt(int_next_index);
tensor_builder.addIntData(x);
}
} else {
throw new IllegalArgumentException("error tensor value type.");
}
tensor_builder.addAllShape(feedShapes_.get(name));
inst_builder.addTensorArray(tensor_builder.build());
}
req_builder.addInsts(inst_builder.build());
}
return req_builder.build();
}
private Map<String, HashMap<String, INDArray>>
_unpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
Boolean need_variant_tag) throws IllegalArgumentException {
return Client._staticUnpackInferenceResponse(
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}
private static Map<String, HashMap<String, INDArray>>
_staticUnpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
Map<String, Integer> fetchTypes,
Set<String> lodTensorSet,
Boolean need_variant_tag) throws IllegalArgumentException {
if (resp.getErrCode() != 0) {
return null;
}
String tag = resp.getTag();
HashMap<String, HashMap<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) {
String engine_name = model_result.getEngineName();
FetchInst inst = model_result.getInsts(0);
HashMap<String, INDArray> result_map
= new HashMap<String, INDArray>();
int index = 0;
for (String name: fetch) {
Tensor variable = inst.getTensorArray(index);
int v_type = fetchTypes.get(name);
INDArray data = null;
if (v_type == 0) { // int64
List<Long> list = variable.getInt64DataList();
long[] array = new long[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
data = Nd4j.createFromArray(array);
} else if (v_type == 1) { // float32
List<Float> list = variable.getFloatDataList();
float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
data = Nd4j.createFromArray(array);
} else if (v_type == 2) { // int32
List<Integer> list = variable.getIntDataList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
data = Nd4j.createFromArray(array);
} else {
throw new IllegalArgumentException("error tensor value type.");
}
// shape
List<Integer> shape_lsit = variable.getShapeList();
int[] shape_array = new int[shape_lsit.size()];
for (int i = 0; i < shape_lsit.size(); ++i) {
shape_array[i] = shape_lsit.get(i);
}
data = data.reshape(shape_array);
// put data to result_map
result_map.put(name, data);
// lod
if (lodTensorSet.contains(name)) {
List<Integer> list = variable.getLodList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name + ".lod", Nd4j.createFromArray(array));
}
index += 1;
}
multi_result_map.put(engine_name, result_map);
}
// TODO: tag(ABtest not support now)
return multi_result_map;
}
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return predict(feed, fetch, false);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return ensemble_predict(feed, fetch, false);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return asyn_predict(feed, fetch, false);
}
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return ensemble_predict(feed_batch, fetch, need_variant_tag);
}
public PredictFuture asyn_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return asyn_predict(feed_batch, fetch, need_variant_tag);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch, fetch, false);
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return ensemble_predict(feed_batch, fetch, false);
}
public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return asyn_predict(feed_batch, fetch, false);
}
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
try {
profiler_.record("java_prepro_0");
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
profiler_.record("java_prepro_1");
profiler_.record("java_client_infer_0");
InferenceResponse resp = blockingStub_.inference(req);
profiler_.record("java_client_infer_1");
profiler_.record("java_postpro_0");
Map<String, HashMap<String, INDArray>> ensemble_result
= _unpackInferenceResponse(resp, fetch, need_variant_tag);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("predict failed: please use ensemble_predict impl.\n");
return null;
}
profiler_.record("java_postpro_1");
profiler_.printProfile();
return list.get(0).getValue();
} catch (StatusRuntimeException e) {
System.out.format("predict failed: %s\n", e.toString());
return null;
}
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
try {
profiler_.record("java_prepro_0");
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
profiler_.record("java_prepro_1");
profiler_.record("java_client_infer_0");
InferenceResponse resp = blockingStub_.inference(req);
profiler_.record("java_client_infer_1");
profiler_.record("java_postpro_0");
Map<String, HashMap<String, INDArray>> ensemble_result
= _unpackInferenceResponse(resp, fetch, need_variant_tag);
profiler_.record("java_postpro_1");
profiler_.printProfile();
return ensemble_result;
} catch (StatusRuntimeException e) {
System.out.format("predict failed: %s\n", e.toString());
return null;
}
}
public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
PredictFuture predict_future = new PredictFuture(future,
(InferenceResponse resp) -> {
return Client._staticUnpackInferenceResponse(
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}
);
return predict_future;
}
}
package io.paddle.serving.client;
import java.util.*;
import java.util.function.Function;
import io.grpc.StatusRuntimeException;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import io.paddle.serving.client.Client;
import io.paddle.serving.grpc.*;
public class PredictFuture {
private ListenableFuture<InferenceResponse> callFuture_;
private Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> callBackFunc_;
PredictFuture(ListenableFuture<InferenceResponse> call_future,
Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> call_back_func) {
callFuture_ = call_future;
callBackFunc_ = call_back_func;
}
public Map<String, INDArray> get() {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("predict failed: %s\n", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
= callBackFunc_.apply(resp);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("predict failed: please use get_ensemble impl.\n");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> ensemble_get() {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("predict failed: %s\n", e.toString());
return null;
}
return callBackFunc_.apply(resp);
}
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.configure";
option java_outer_classname = "ConfigureProto";
package paddle.serving.configure;
message FeedVar {
optional string name = 1;
optional string alias_name = 2;
optional bool is_lod_tensor = 3 [ default = false ];
optional int32 feed_type = 4 [ default = 0 ];
repeated int32 shape = 5;
}
message FetchVar {
optional string name = 1;
optional string alias_name = 2;
optional bool is_lod_tensor = 3 [ default = false ];
optional int32 fetch_type = 4 [ default = 0 ];
repeated int32 shape = 5;
}
message GeneralModelConfig {
repeated FeedVar feed_var = 1;
repeated FetchVar fetch_var = 2;
};
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7; // only for fetch tensor currently
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message InferenceRequest {
repeated FeedInst insts = 1;
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
};
message InferenceResponse {
repeated ModelOutput outputs = 1;
optional string tag = 2;
required int32 err_code = 3;
};
message ModelOutput {
repeated FetchInst insts = 1;
optional string engine_name = 2;
}
message SetTimeoutRequest { required int32 timeout_ms = 1; }
message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {}
message GetClientConfigResponse { required string client_config_str = 1; }
service MultiLangGeneralModelService {
rpc Inference(InferenceRequest) returns (InferenceResponse) {}
rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {}
rpc GetClientConfig(GetClientConfigRequest)
returns (GetClientConfigResponse) {}
};
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="INFO">
<Appenders>
<Console name="Console" target="SYSTEM_OUT">
<PatternLayout pattern="%highlight{%d{yyyy-MM-dd HH:mm:ss} %C %M %n%p: %m%n}{STYLE=Logback}"/>
</Console>
</Appenders>
<Loggers>
<Root level="INFO">
<AppenderRef ref="Console"/>
</Root>
</Loggers>
</Configuration>
...@@ -45,7 +45,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data ...@@ -45,7 +45,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
CPU :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz CPU :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz
Model :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/ctr_criteo_with_cube/network_conf.py) Model :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py)
server core/thread num : 4/8 server core/thread num : 4/8
......
...@@ -43,7 +43,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data ...@@ -43,7 +43,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
设备 :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz 设备 :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz
模型 :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/ctr_criteo_with_cube/network_conf.py) 模型 :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py)
server core/thread num : 4/8 server core/thread num : 4/8
......
...@@ -33,5 +33,9 @@ server = Server() ...@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4) server.set_num_threads(4)
server.load_model_config(sys.argv[1]) server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu") server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server() server.run_server()
...@@ -33,5 +33,9 @@ server = Server() ...@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4) server.set_num_threads(4)
server.load_model_config(sys.argv[1], sys.argv[2]) server.load_model_config(sys.argv[1], sys.argv[2])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu") server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server() server.run_server()
wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz
tar -zxvf text_classification_data.tar.gz
tar -zxvf imdb_model.tar.gz
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import sys
import os
import paddle
import re
import paddle.fluid.incubate.data_generator as dg
py_version = sys.version_info[0]
class IMDBDataset(dg.MultiSlotDataGenerator):
def load_resource(self, dictfile):
self._vocab = {}
wid = 0
if py_version == 2:
with open(dictfile) as f:
for line in f:
self._vocab[line.strip()] = wid
wid += 1
else:
with open(dictfile, encoding="utf-8") as f:
for line in f:
self._vocab[line.strip()] = wid
wid += 1
self._unk_id = len(self._vocab)
self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])
def get_words_only(self, line):
sent = line.lower().replace("<br />", " ").strip()
words = [x for x in self._pattern.split(sent) if x and x != " "]
feas = [
self._vocab[x] if x in self._vocab else self._unk_id for x in words
]
return feas
def get_words_and_label(self, line):
send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
" ").strip()
label = [int(line.split('|')[-1])]
words = [x for x in self._pattern.split(send) if x and x != " "]
feas = [
self._vocab[x] if x in self._vocab else self._unk_id for x in words
]
return feas, label
def infer_reader(self, infer_filelist, batch, buf_size):
def local_iter():
for fname in infer_filelist:
with open(fname, "r") as fin:
for line in fin:
feas, label = self.get_words_and_label(line)
yield feas, label
import paddle
batch_iter = paddle.batch(
paddle.reader.shuffle(
local_iter, buf_size=buf_size),
batch_size=batch)
return batch_iter
def generate_sample(self, line):
def memory_iter():
for i in range(1000):
yield self.return_value
def data_iter():
feas, label = self.get_words_and_label(line)
yield ("words", feas), ("label", label)
return data_iter
if __name__ == "__main__":
imdb = IMDBDataset()
imdb.load_resource("imdb.vocab")
imdb.run_from_stdin()
...@@ -34,4 +34,6 @@ for i in range(3): ...@@ -34,4 +34,6 @@ for i in range(3):
fetch = ["prediction"] fetch = ["prediction"]
fetch_maps = client.predict(feed=feed, fetch=fetch) fetch_maps = client.predict(feed=feed, fetch=fetch)
for model, fetch_map in fetch_maps.items(): for model, fetch_map in fetch_maps.items():
if model == "serving_status_code":
continue
print("step: {}, model: {}, res: {}".format(i, model, fetch_map)) print("step: {}, model: {}, res: {}".format(i, model, fetch_map))
# Yolov4 Detection Service
([简体中文](README_CN.md)|English)
## Get Model
```
python -m paddle_serving_app.package --get_model yolov4
tar -xzvf yolov4.tar.gz
```
## Start RPC Service
```
python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang
```
## Prediction
```
python test_client.py 000000570688.jpg
```
After the prediction is completed, a json file to save the prediction result and a picture with the detection result box will be generated in the `./outpu folder.
# Yolov4 检测服务
(简体中文|[English](README.md))
## 获取模型
```
python -m paddle_serving_app.package --get_model yolov4
tar -xzvf yolov4.tar.gz
```
## 启动RPC服务
```
python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang
```
## 预测
```
python test_client.py 000000570688.jpg
```
预测完成会在`./output`文件夹下生成保存预测结果的json文件以及标出检测结果框的图片。
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from paddle_serving_client import MultiLangClient as Client
from paddle_serving_app.reader import *
import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize(
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
(2, 0, 1))
])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
client = Client()
client.connect(['127.0.0.1:9393'])
# client.set_rpc_timeout_ms(10000)
im = preprocess(sys.argv[1])
fetch_map = client.predict(
feed={
"image": im,
"im_size": np.array(list(im.shape[1:])),
},
fetch=["save_infer_model/scale_0.tmp_0"])
fetch_map.pop("serving_status_code")
fetch_map["image"] = sys.argv[1]
postprocess(fetch_map)
# OCR # OCR
(English|[简体中文](./README_CN.md))
## Get Model ## Get Model
``` ```
python -m paddle_serving_app.package --get_model ocr_rec python -m paddle_serving_app.package --get_model ocr_rec
...@@ -8,38 +10,78 @@ python -m paddle_serving_app.package --get_model ocr_det ...@@ -8,38 +10,78 @@ python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz tar -xzvf ocr_det.tar.gz
``` ```
## RPC Service ## Get Dataset (Optional)
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar
tar xf test_imgs.tar
```
## Web Service
### Start Service ### Start Service
For the following two code block, please check your devices and pick one
for GPU device
``` ```
python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
python ocr_web_server.py
``` ```
for CPU device
### Client Prediction
``` ```
python -m paddle_serving_server.serve --model ocr_rec_model --port 9292 python ocr_web_client.py
python -m paddle_serving_server.serve --model ocr_det_model --port 9293
``` ```
If you want a faster web service, please try Web Debugger Service
### Client Prediction ## Web Debugger Service
```
python ocr_debugger_server.py
```
## Web Debugger Client Prediction
``` ```
python ocr_rpc_client.py python ocr_web_client.py
``` ```
## Web Service ## Benchmark
### Start Service CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40
GPU: Nvidia Tesla V100 * 1
Dataset: RCTW 500 sample images
| engine | client read image(ms) | client-server tras time(ms) | server read image(ms) | det pre(ms) | det infer(ms) | det post(ms) | rec pre(ms) | rec infer(ms) | rec post(ms) | server-client trans time(ms) | server side time consumption(ms) | server side overhead(ms) | total time(ms) |
|------------------------------|----------------|----------------------------|------------------|--------------------|------------------|--------------------|--------------------|------------------|--------------------|--------------------------|--------------------|--------------|---------------|
| Serving web service | 8.69 | 13.41 | 109.97 | 2.82 | 87.76 | 4.29 | 3.98 | 78.51 | 3.66 | 4.12 | 181.02 | 136.49 | 317.51 |
| Serving Debugger web service | 8.73 | 16.42 | 115.27 | 2.93 | 20.63 | 3.97 | 4.48 | 13.84 | 3.60 | 6.91 | 49.45 | 147.33 | 196.78 |
## Appendix: Det or Rec only
if you are going to detect images not recognize it or directly recognize the words from images. We also provide Det and Rec server for you.
### Det Server
``` ```
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 python det_web_server.py
python ocr_web_server.py #or
python det_debugger_server.py
``` ```
### Client Prediction ### Det Client
```
# also use ocr_web_client.py
python ocr_web_client.py
```
### Rec Server
```
python rec_web_server.py
#or
python rec_debugger_server.py
```
### Rec Client
``` ```
sh ocr_web_client.sh python rec_web_client.py
``` ```
# OCR 服务
([English](./README.md)|简体中文)
## 获取模型
```
python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz
python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz
```
## 获取数据集(可选)
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar
tar xf test_imgs.tar
```
### 客户端预测
```
python ocr_rpc_client.py
```
## Web Service服务
### 启动服务
```
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
python ocr_web_server.py
```
### 启动客户端
```
python ocr_web_client.py
```
如果用户需要更快的执行速度,请尝试Debugger版Web服务
## 启动Debugger版Web服务
```
python ocr_debugger_server.py
```
## 启动客户端
```
python ocr_web_client.py
```
## 性能指标
CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40
GPU: Nvidia Tesla V100单卡
数据集:RCTW 500张测试数据集
| engine | 客户端读图(ms) | 客户端发送请求到服务端(ms) | 服务端读图(ms) | 检测预处理耗时(ms) | 检测模型耗时(ms) | 检测后处理耗时(ms) | 识别预处理耗时(ms) | 识别模型耗时(ms) | 识别后处理耗时(ms) | 服务端回传客户端时间(ms) | 服务端整体耗时(ms) | 空跑耗时(ms) | 整体耗时(ms) |
|------------------------------|----------------|----------------------------|------------------|--------------------|------------------|--------------------|--------------------|------------------|--------------------|--------------------------|--------------------|--------------|---------------|
| Serving web service | 8.69 | 13.41 | 109.97 | 2.82 | 87.76 | 4.29 | 3.98 | 78.51 | 3.66 | 4.12 | 181.02 | 136.49 | 317.51 |
| Serving Debugger web service | 8.73 | 16.42 | 115.27 | 2.93 | 20.63 | 3.97 | 4.48 | 13.84 | 3.60 | 6.91 | 49.45 | 147.33 | 196.78 |
## 附录: 检测/识别单服务启动
如果您想单独启动检测或者识别服务,我们也提供了启动单服务的代码
### 启动检测服务
```
python det_web_server.py
#or
python det_debugger_server.py
```
### 检测服务客户端
```
# also use ocr_web_client.py
python ocr_web_client.py
```
### 启动识别服务
```
python rec_web_server.py
#or
python rec_debugger_server.py
```
### 识别服务客户端
```
python rec_web_client.py
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_det(self):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"]
def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
return {"dt_boxes": dt_boxes.tolist()}
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_det_model")
ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.init_det()
ocr_service.run_debugger_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_det(self):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape
print(det_img)
return {"image": det_img}, ["concat_1.tmp_0"]
def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
return {"dt_boxes": dt_boxes.tolist()}
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_det_model")
ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.init_det()
ocr_service.run_rpc_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
from paddle_serving_server_gpu.web_service import WebService
from paddle_serving_app.local_predict import Debugger
import time
import re
import base64
class OCRService(WebService):
def init_det_debugger(self, det_model_config):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.det_client = Debugger()
self.det_client.load_model_config(
det_model_config, gpu=True, profile=False)
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im)
_, new_h, new_w = det_img.shape
det_img = det_img[np.newaxis, :]
det_img = det_img.copy()
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage()
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
if len(img_list) == 0:
return [], []
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": res_lst}
return res
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_debugger(det_model_config="ocr_det_model")
ocr_service.run_debugger_service(gpu=True)
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
import time
import re
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
#img = cv2.imread(img)
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def read_det_box_file(filename):
with open(filename, 'r') as f:
line = f.readline()
a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int(
line.split(' ')[2])
dt_boxes = np.zeros((a, b, c)).astype(np.float32)
line = f.readline()
for i in range(a):
for j in range(b):
line = f.readline()
dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float(
line.split(' ')[0]), float(line.split(' ')[1])
line = f.readline()
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def main():
client1 = Client()
client1.load_client_config("ocr_det_client/serving_client_conf.prototxt")
client1.connect(["127.0.0.1:9293"])
client2 = Client()
client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client2.connect(["127.0.0.1:9292"])
read_image_file = File2Image()
preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
filter_func = FilterBoxes(10, 10)
ocr_reader = OCRReader()
files = [
"./imgs/{}".format(f) for f in os.listdir('./imgs')
if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f)
]
#files = ["2.jpg"]*30
#files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)]
time_all = 0
time_det_all = 0
time_rec_all = 0
for name in files:
#print(name)
im = read_image_file(name)
ori_h, ori_w, _ = im.shape
time1 = time.time()
img = preprocess(im)
_, new_h, new_w = img.shape
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
#print(new_h, new_w, ori_h, ori_w)
time_before_det = time.time()
outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"])
time_after_det = time.time()
time_det_all += (time_after_det - time_before_det)
#print(outputs)
dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio)
#norm_img = norm_img[np.newaxis, :]
feed = {"image": norm_img}
feed_list.append(feed)
#fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch = ["ctc_greedy_decoder_0.tmp_0"]
time_before_rec = time.time()
if len(feed_list) == 0:
continue
fetch_map = client2.predict(feed=feed_list, fetch=fetch)
time_after_rec = time.time()
time_rec_all += (time_after_rec - time_before_rec)
rec_res = ocr_reader.postprocess(fetch_map)
#for res in rec_res:
# print(res[0].encode("utf-8"))
time2 = time.time()
time_all += (time2 - time1)
print("rpc+det time: {}".format(time_all / len(files)))
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import requests
import json
import cv2
import base64
import os, sys
import time
def cv2_to_base64(image):
#data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(image).decode(
'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction"
test_img_dir = "imgs/"
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction
...@@ -21,10 +21,11 @@ import os ...@@ -21,10 +21,11 @@ import os
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
from paddle_serving_server_gpu.web_service import WebService from paddle_serving_server_gpu.web_service import WebService
import time import time
import re import re
import base64
class OCRService(WebService): class OCRService(WebService):
...@@ -37,74 +38,16 @@ class OCRService(WebService): ...@@ -37,74 +38,16 @@ class OCRService(WebService):
self.det_client = Client() self.det_client = Client()
self.det_client.load_client_config(det_client_config) self.det_client.load_client_config(det_client_config)
self.det_client.connect(["127.0.0.1:{}".format(det_port)]) self.det_client.connect(["127.0.0.1:{}".format(det_port)])
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
img_url = feed[0]["image"] data = base64.b64decode(feed[0]["image"].encode('utf8'))
#print(feed, img_url) data = np.fromstring(data, np.uint8)
read_from_url = URL2Image() im = cv2.imdecode(data, cv2.IMREAD_COLOR)
im = read_from_url(img_url)
ori_h, ori_w, _ = im.shape ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
#print("det_img", det_img, det_img.shape)
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"]) feed={"image": det_img}, fetch=["concat_1.tmp_0"])
#print("det_out", det_out)
def sorted_boxes(dt_boxes):
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
_, new_h, new_w = det_img.shape _, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
...@@ -114,10 +57,12 @@ class OCRService(WebService): ...@@ -114,10 +57,12 @@ class OCRService(WebService):
"unclip_ratio": 1.5, "unclip_ratio": 1.5,
"min_size": 3 "min_size": 3
}) })
sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage()
feed_list = [] feed_list = []
img_list = [] img_list = []
max_wh_ratio = 0 max_wh_ratio = 0
...@@ -128,29 +73,25 @@ class OCRService(WebService): ...@@ -128,29 +73,25 @@ class OCRService(WebService):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list: for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img} feed = {"image": norm_img}
feed_list.append(feed) feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0"] fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
#print("feed_list", feed_list)
return feed_list, fetch return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
#print(fetch_map) rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
ocr_reader = OCRReader()
rec_res = ocr_reader.postprocess(fetch_map)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
fetch_map["res"] = res_lst res = {"res": res_lst}
del fetch_map["ctc_greedy_decoder_0.tmp_0"] return res
del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"]
return fetch_map
ocr_service = OCRService(name="ocr") ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model") ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292) ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.init_det_client( ocr_service.init_det_client(
det_port=9293, det_port=9293,
det_client_config="ocr_det_client/serving_client_conf.prototxt") det_client_config="ocr_det_client/serving_client_conf.prototxt")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_rec(self):
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
img_list = []
for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im)
max_wh_ratio = 0
for i, boximg in enumerate(img_list):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
for i, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[i] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": res_lst}
return res
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.set_gpus("0")
ocr_service.init_rec()
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.run_debugger_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import requests
import json
import cv2
import base64
import os, sys
import time
def cv2_to_base64(image):
#data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(image).decode(
'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction"
test_img_dir = "rec_img/"
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
#data = {"feed": [{"image": image}], "fetch": ["res"]}
data = {"feed": [{"image": image}] * 3, "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_rec(self):
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = []
for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im)
feed_list = []
max_wh_ratio = 0
for i, boximg in enumerate(img_list):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": res_lst}
return res
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.set_gpus("0")
ocr_service.init_rec()
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.run_rpc_service()
ocr_service.run_web_service()
...@@ -70,9 +70,10 @@ class Debugger(object): ...@@ -70,9 +70,10 @@ class Debugger(object):
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
if profile: if profile:
config.enable_profile() config.enable_profile()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.set_cpu_math_library_num_threads(cpu_num) config.set_cpu_math_library_num_threads(cpu_num)
config.switch_ir_optim(False) config.switch_ir_optim(False)
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config) self.predictor = create_paddle_predictor(config)
def predict(self, feed=None, fetch=None): def predict(self, feed=None, fetch=None):
...@@ -113,8 +114,8 @@ class Debugger(object): ...@@ -113,8 +114,8 @@ class Debugger(object):
"Fetch names should not be empty or out of saved fetch list.") "Fetch names should not be empty or out of saved fetch list.")
return {} return {}
inputs = [] input_names = self.predictor.get_input_names()
for name in self.feed_names_: for name in input_names:
if isinstance(feed[name], list): if isinstance(feed[name], list):
feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[ feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
name]) name])
...@@ -122,11 +123,21 @@ class Debugger(object): ...@@ -122,11 +123,21 @@ class Debugger(object):
feed[name] = feed[name].astype("int64") feed[name] = feed[name].astype("int64")
else: else:
feed[name] = feed[name].astype("float32") feed[name] = feed[name].astype("float32")
inputs.append(PaddleTensor(feed[name][np.newaxis, :])) input_tensor = self.predictor.get_input_tensor(name)
input_tensor.copy_from_cpu(feed[name])
outputs = self.predictor.run(inputs) output_tensors = []
output_names = self.predictor.get_output_names()
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
output_tensors.append(output_tensor)
outputs = []
self.predictor.zero_copy_run()
for output_tensor in output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
fetch_map = {} fetch_map = {}
for name in fetch: for i, name in enumerate(fetch):
fetch_map[name] = outputs[self.fetch_names_to_idx_[ fetch_map[name] = outputs[i]
name]].as_ndarray() if len(output_tensors[i].lod()) > 0:
fetch_map[name + ".lod"] = output_tensors[i].lod()[0]
return fetch_map return fetch_map
...@@ -15,7 +15,7 @@ from .chinese_bert_reader import ChineseBertReader ...@@ -15,7 +15,7 @@ from .chinese_bert_reader import ChineseBertReader
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize
from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor
from .image_reader import RCNNPostprocess, SegPostprocess, PadStride from .image_reader import RCNNPostprocess, SegPostprocess, PadStride
from .image_reader import DBPostProcess, FilterBoxes from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
from .lac_reader import LACReader from .lac_reader import LACReader
from .senta_reader import SentaReader from .senta_reader import SentaReader
from .imdb_reader import IMDBDataset from .imdb_reader import IMDBDataset
......
...@@ -797,6 +797,59 @@ class Transpose(object): ...@@ -797,6 +797,59 @@ class Transpose(object):
return format_string return format_string
class SortedBoxes(object):
"""
Sorted bounding boxes from Detection
"""
def __init__(self):
pass
def __call__(self, dt_boxes):
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
class GetRotateCropImage(object):
"""
Rotate and Crop image from OCR Det output
"""
def __init__(self):
pass
def __call__(self, img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
class ImageReader(): class ImageReader():
def __init__(self, def __init__(self,
image_shape=[3, 224, 224], image_shape=[3, 224, 224],
......
...@@ -120,29 +120,21 @@ class CharacterOps(object): ...@@ -120,29 +120,21 @@ class CharacterOps(object):
class OCRReader(object): class OCRReader(object):
def __init__(self): def __init__(self,
args = self.parse_args() algorithm="CRNN",
image_shape = [int(v) for v in args.rec_image_shape.split(",")] image_shape=[3, 32, 320],
char_type="ch",
batch_num=1,
char_dict_path="./ppocr_keys_v1.txt"):
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = args.rec_char_type self.character_type = char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = batch_num
char_ops_params = {} char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_type"] = char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path char_ops_params["character_dict_path"] = char_dict_path
char_ops_params['loss_type'] = 'ctc' char_ops_params['loss_type'] = 'ctc'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def parse_args(self):
parser = argparse.ArgumentParser()
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=1)
parser.add_argument(
"--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt")
return parser.parse_args()
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
if self.character_type == "ch": if self.character_type == "ch":
...@@ -154,15 +146,14 @@ class OCRReader(object): ...@@ -154,15 +146,14 @@ class OCRReader(object):
resized_w = imgW resized_w = imgW
else: else:
resized_w = int(math.ceil(imgH * ratio)) resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
seq = Sequential([ resized_image = resized_image.astype('float32')
Resize(imgH, resized_w), Transpose((2, 0, 1)), Div(255), resized_image = resized_image.transpose((2, 0, 1)) / 255
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], True) resized_image -= 0.5
]) resized_image /= 0.5
resized_image = seq(img)
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def preprocess(self, img_list): def preprocess(self, img_list):
...@@ -191,11 +182,17 @@ class OCRReader(object): ...@@ -191,11 +182,17 @@ class OCRReader(object):
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno] beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1] end = rec_idx_lod[rno + 1]
if isinstance(rec_idx_batch, list):
rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]]
else: #nd array
rec_idx_tmp = rec_idx_batch[beg:end, 0] rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp) preds_text = self.char_ops.decode(rec_idx_tmp)
if with_score: if with_score:
beg = predict_lod[rno] beg = predict_lod[rno]
end = predict_lod[rno + 1] end = predict_lod[rno + 1]
if isinstance(outputs["softmax_0.tmp_0"], list):
outputs["softmax_0.tmp_0"] = np.array(outputs[
"softmax_0.tmp_0"]).astype(np.float32)
probs = outputs["softmax_0.tmp_0"][beg:end, :] probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
blank = probs.shape[1] blank = probs.shape[1]
......
...@@ -399,6 +399,7 @@ class MultiLangClient(object): ...@@ -399,6 +399,7 @@ class MultiLangClient(object):
self.channel_ = None self.channel_ = None
self.stub_ = None self.stub_ = None
self.rpc_timeout_s_ = 2 self.rpc_timeout_s_ = 2
self.profile_ = _Profiler()
def add_variant(self, tag, cluster, variant_weight): def add_variant(self, tag, cluster, variant_weight):
# TODO # TODO
...@@ -520,7 +521,7 @@ class MultiLangClient(object): ...@@ -520,7 +521,7 @@ class MultiLangClient(object):
tensor.float_data.extend( tensor.float_data.extend(
var.reshape(-1).astype('float32').tolist()) var.reshape(-1).astype('float32').tolist())
elif v_type == 2: elif v_type == 2:
tensor.int32_data.extend( tensor.int_data.extend(
var.reshape(-1).astype('int32').tolist()) var.reshape(-1).astype('int32').tolist())
else: else:
raise Exception("error tensor value type.") raise Exception("error tensor value type.")
...@@ -530,7 +531,7 @@ class MultiLangClient(object): ...@@ -530,7 +531,7 @@ class MultiLangClient(object):
elif v_type == 1: elif v_type == 1:
tensor.float_data.extend(self._flatten_list(var)) tensor.float_data.extend(self._flatten_list(var))
elif v_type == 2: elif v_type == 2:
tensor.int32_data.extend(self._flatten_list(var)) tensor.int_data.extend(self._flatten_list(var))
else: else:
raise Exception("error tensor value type.") raise Exception("error tensor value type.")
else: else:
...@@ -582,6 +583,7 @@ class MultiLangClient(object): ...@@ -582,6 +583,7 @@ class MultiLangClient(object):
ret = list(multi_result_map.values())[0] ret = list(multi_result_map.values())[0]
else: else:
ret = multi_result_map ret = multi_result_map
ret["serving_status_code"] = 0 ret["serving_status_code"] = 0
return ret if not need_variant_tag else [ret, tag] return ret if not need_variant_tag else [ret, tag]
...@@ -601,18 +603,30 @@ class MultiLangClient(object): ...@@ -601,18 +603,30 @@ class MultiLangClient(object):
need_variant_tag=False, need_variant_tag=False,
asyn=False, asyn=False,
is_python=True): is_python=True):
req = self._pack_inference_request(feed, fetch, is_python=is_python)
if not asyn: if not asyn:
try: try:
self.profile_.record('py_prepro_0')
req = self._pack_inference_request(
feed, fetch, is_python=is_python)
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_) resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
return self._unpack_inference_response( self.profile_.record('py_client_infer_1')
self.profile_.record('py_postpro_0')
ret = self._unpack_inference_response(
resp, resp,
fetch, fetch,
is_python=is_python, is_python=is_python,
need_variant_tag=need_variant_tag) need_variant_tag=need_variant_tag)
self.profile_.record('py_postpro_1')
self.profile_.print_profile()
return ret
except grpc.RpcError as e: except grpc.RpcError as e:
return {"serving_status_code": e.code()} return {"serving_status_code": e.code()}
else: else:
req = self._pack_inference_request(feed, fetch, is_python=is_python)
call_future = self.stub_.Inference.future( call_future = self.stub_.Inference.future(
req, timeout=self.rpc_timeout_s_) req, timeout=self.rpc_timeout_s_)
return MultiLangPredictFuture( return MultiLangPredictFuture(
......
...@@ -524,7 +524,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -524,7 +524,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
elif v_type == 1: # float32 elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32") data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32 elif v_type == 2: # int32
data = np.array(list(var.int32_data), dtype="int32") data = np.array(list(var.int_data), dtype="int32")
else: else:
raise Exception("error type.") raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape) data.shape = list(feed_inst.tensor_array[idx].shape)
...@@ -540,6 +540,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -540,6 +540,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
results, tag = ret results, tag = ret
resp.tag = tag resp.tag = tag
resp.err_code = 0 resp.err_code = 0
if not self.is_multi_model_: if not self.is_multi_model_:
results = {'general_infer_0': results} results = {'general_infer_0': results}
for model_name, model_result in results.items(): for model_name, model_result in results.items():
...@@ -558,7 +559,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -558,7 +559,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
tensor.float_data.extend(model_result[name].reshape(-1) tensor.float_data.extend(model_result[name].reshape(-1)
.tolist()) .tolist())
elif v_type == 2: # int32 elif v_type == 2: # int32
tensor.int32_data.extend(model_result[name].reshape(-1) tensor.int_data.extend(model_result[name].reshape(-1)
.tolist()) .tolist())
else: else:
raise Exception("error type.") raise Exception("error type.")
......
...@@ -571,7 +571,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -571,7 +571,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
elif v_type == 1: # float32 elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32") data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: elif v_type == 2:
data = np.array(list(var.int32_data), dtype="int32") data = np.array(list(var.int_data), dtype="int32")
else: else:
raise Exception("error type.") raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape) data.shape = list(feed_inst.tensor_array[idx].shape)
...@@ -587,6 +587,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -587,6 +587,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
results, tag = ret results, tag = ret
resp.tag = tag resp.tag = tag
resp.err_code = 0 resp.err_code = 0
if not self.is_multi_model_: if not self.is_multi_model_:
results = {'general_infer_0': results} results = {'general_infer_0': results}
for model_name, model_result in results.items(): for model_name, model_result in results.items():
...@@ -605,7 +606,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. ...@@ -605,7 +606,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
tensor.float_data.extend(model_result[name].reshape(-1) tensor.float_data.extend(model_result[name].reshape(-1)
.tolist()) .tolist())
elif v_type == 2: # int32 elif v_type == 2: # int32
tensor.int32_data.extend(model_result[name].reshape(-1) tensor.int_data.extend(model_result[name].reshape(-1)
.tolist()) .tolist())
else: else:
raise Exception("error type.") raise Exception("error type.")
......
...@@ -127,9 +127,9 @@ class WebService(object): ...@@ -127,9 +127,9 @@ class WebService(object):
request.json["fetch"]) request.json["fetch"])
if isinstance(feed, dict) and "fetch" in feed: if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"] del feed["fetch"]
if len(feed) == 0:
raise ValueError("empty input")
fetch_map = self.client.predict(feed=feed, fetch=fetch) fetch_map = self.client.predict(feed=feed, fetch=fetch)
for key in fetch_map:
fetch_map[key] = fetch_map[key].tolist()
result = self.postprocess( result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result} result = {"result": result}
...@@ -164,6 +164,33 @@ class WebService(object): ...@@ -164,6 +164,33 @@ class WebService(object):
self.app_instance = app_instance self.app_instance = app_instance
# TODO: maybe change another API name: maybe run_local_predictor?
def run_debugger_service(self, gpu=False):
import socket
localIP = socket.gethostbyname(socket.gethostname())
print("web service address:")
print("http://{}:{}/{}/prediction".format(localIP, self.port,
self.name))
app_instance = Flask(__name__)
@app_instance.before_first_request
def init():
self._launch_local_predictor(gpu)
service_name = "/" + self.name + "/prediction"
@app_instance.route(service_name, methods=["POST"])
def run():
return self.get_prediction(request)
self.app_instance = app_instance
def _launch_local_predictor(self, gpu):
from paddle_serving_app.local_predict import Debugger
self.client = Debugger()
self.client.load_model_config(
"{}".format(self.model_config), gpu=gpu, profile=False)
def run_web_service(self): def run_web_service(self):
self.app_instance.run(host="0.0.0.0", self.app_instance.run(host="0.0.0.0",
port=self.port, port=self.port,
......
FROM centos:7.3.1611 FROM centos:7.3.1611
RUN yum -y install wget >/dev/null \ RUN yum -y install wget >/dev/null \
&& yum -y install gcc gcc-c++ make glibc-static which >/dev/null \ && yum -y install gcc gcc-c++ make glibc-static which >/dev/null \
&& yum -y install git openssl-devel curl-devel bzip2-devel python-devel >/dev/null \ && yum -y install git openssl-devel curl-devel bzip2-devel python-devel >/dev/null \
&& yum -y install libSM-1.2.2-2.el7.x86_64 --setopt=protected_multilib=false \ && yum -y install libSM-1.2.2-2.el7.x86_64 --setopt=protected_multilib=false \
&& yum -y install libXrender-0.9.10-1.el7.x86_64 --setopt=protected_multilib=false \ && yum -y install libXrender-0.9.10-1.el7.x86_64 --setopt=protected_multilib=false \
&& yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false \ && yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false
&& wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \
RUN wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \
&& tar xzf cmake-3.2.0-Linux-x86_64.tar.gz \ && tar xzf cmake-3.2.0-Linux-x86_64.tar.gz \
&& mv cmake-3.2.0-Linux-x86_64 /usr/local/cmake3.2.0 \ && mv cmake-3.2.0-Linux-x86_64 /usr/local/cmake3.2.0 \
&& echo 'export PATH=/usr/local/cmake3.2.0/bin:$PATH' >> /root/.bashrc \ && echo 'export PATH=/usr/local/cmake3.2.0/bin:$PATH' >> /root/.bashrc \
&& rm cmake-3.2.0-Linux-x86_64.tar.gz \ && rm cmake-3.2.0-Linux-x86_64.tar.gz
&& wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
&& tar xzf go1.14.linux-amd64.tar.gz \ && tar xzf go1.14.linux-amd64.tar.gz \
&& mv go /usr/local/go \ && mv go /usr/local/go \
&& echo 'export GOROOT=/usr/local/go' >> /root/.bashrc \ && echo 'export GOROOT=/usr/local/go' >> /root/.bashrc \
&& echo 'export PATH=/usr/local/go/bin:$PATH' >> /root/.bashrc \ && echo 'export PATH=/usr/local/go/bin:$PATH' >> /root/.bashrc \
&& rm go1.14.linux-amd64.tar.gz \ && rm go1.14.linux-amd64.tar.gz
&& yum -y install python-devel sqlite-devel >/dev/null \
RUN yum -y install python-devel sqlite-devel >/dev/null \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \ && python get-pip.py >/dev/null \
&& pip install google protobuf setuptools wheel flask >/dev/null \ && pip install google protobuf setuptools wheel flask >/dev/null \
&& rm get-pip.py \ && rm get-pip.py
&& wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
&& yum -y install bzip2 >/dev/null \ && yum -y install bzip2 >/dev/null \
&& tar -jxf patchelf-0.10.tar.bz2 \ && tar -jxf patchelf-0.10.tar.bz2 \
&& cd patchelf-0.10 \ && cd patchelf-0.10 \
&& ./configure --prefix=/usr \ && ./configure --prefix=/usr \
&& make >/dev/null && make install >/dev/null \ && make >/dev/null && make install >/dev/null \
&& cd .. \ && cd .. \
&& rm -rf patchelf-0.10* \ && rm -rf patchelf-0.10*
&& yum install -y python3 python3-devel \
&& pip3 install google protobuf setuptools wheel flask \ RUN yum install -y python3 python3-devel \
&& yum -y update >/dev/null \ && pip3 install google protobuf setuptools wheel flask
RUN yum -y update >/dev/null \
&& yum -y install dnf >/dev/null \ && yum -y install dnf >/dev/null \
&& yum -y install dnf-plugins-core >/dev/null \ && yum -y install dnf-plugins-core >/dev/null \
&& dnf copr enable alonid/llvm-3.8.0 -y \ && dnf copr enable alonid/llvm-3.8.0 -y \
&& dnf install llvm-3.8.0 clang-3.8.0 compiler-rt-3.8.0 -y \ && dnf install llvm-3.8.0 clang-3.8.0 compiler-rt-3.8.0 -y \
&& echo 'export PATH=/opt/llvm-3.8.0/bin:$PATH' >> /root/.bashrc && echo 'export PATH=/opt/llvm-3.8.0/bin:$PATH' >> /root/.bashrc
RUN yum install -y java \
&& wget http://repos.fedorapeople.org/repos/dchen/apache-maven/epel-apache-maven.repo -O /etc/yum.repos.d/epel-apache-maven.repo \
&& yum install -y apache-maven
...@@ -182,26 +182,26 @@ function python_test_fit_a_line() { ...@@ -182,26 +182,26 @@ function python_test_fit_a_line() {
kill_server_process kill_server_process
# test web # test web
unsetproxy # maybe the proxy is used on iPipe, which makes web-test failed. #unsetproxy # maybe the proxy is used on iPipe, which makes web-test failed.
check_cmd "python -m paddle_serving_server_gpu.serve --model uci_housing_model --port 9393 --thread 2 --gpu_ids 0 --name uci > /dev/null &" #check_cmd "python -m paddle_serving_server_gpu.serve --model uci_housing_model --port 9393 --thread 2 --gpu_ids 0 --name uci > /dev/null &"
sleep 5 # wait for the server to start #sleep 5 # wait for the server to start
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" #check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction"
# check http code # check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` #http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction`
if [ ${http_code} -ne 200 ]; then #if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200" # echo "HTTP status code -ne 200"
exit 1 # exit 1
fi #fi
# test web batch # test web batch
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" #check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction"
# check http code # check http code
http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` #http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction`
if [ ${http_code} -ne 200 ]; then #if [ ${http_code} -ne 200 ]; then
echo "HTTP status code -ne 200" # echo "HTTP status code -ne 200"
exit 1 # exit 1
fi #fi
setproxy # recover proxy state #setproxy # recover proxy state
kill_server_process #kill_server_process
;; ;;
*) *)
echo "error type" echo "error type"
...@@ -499,6 +499,64 @@ function python_test_lac() { ...@@ -499,6 +499,64 @@ function python_test_lac() {
cd .. cd ..
} }
function java_run_test() {
# pwd: /Serving
local TYPE=$1
export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
unsetproxy
case $TYPE in
CPU)
# compile java sdk
cd java # pwd: /Serving/java
mvn compile > /dev/null
mvn install > /dev/null
# compile java sdk example
cd examples # pwd: /Serving/java/examples
mvn compile > /dev/null
mvn install > /dev/null
# fit_a_line (general, asyn_predict, batch_predict)
cd ../../python/examples/grpc_impl_example/fit_a_line # pwd: /Serving/python/examples/grpc_impl_example/fit_a_line
sh get_data.sh
check_cmd "python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --thread 4 --use_multilang > /dev/null &"
sleep 5 # wait for the server to start
cd ../../../java/examples # /Serving/java/examples
java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample fit_a_line
java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample asyn_predict
java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample batch_predict
kill_server_process
# imdb (model_ensemble)
cd ../../python/examples/grpc_impl_example/imdb # pwd: /Serving/python/examples/grpc_impl_example/imdb
sh get_data.sh > /dev/null
check_cmd "python test_multilang_ensemble_server.py > /dev/null &"
sleep 5 # wait for the server to start
cd ../../../java/examples # /Serving/java/examples
java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample model_ensemble
kill_server_process
# yolov4 (int32)
cd ../../python/examples/grpc_impl_example/yolov4 # pwd: /Serving/python/examples/grpc_impl_example/yolov4
python -m paddle_serving_app.package --get_model yolov4 > /dev/null
tar -xzf yolov4.tar.gz > /dev/null
check_cmd "python -m paddle_serving_server.serve --model yolov4_model --port 9393 --use_multilang --mem_optim > /dev/null &"
cd ../../../java/examples # /Serving/java/examples
java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample yolov4 src/main/resources/000000570688.jpg
kill_server_process
cd ../../ # pwd: /Serving
;;
GPU)
;;
*)
echo "error type"
exit 1
;;
esac
echo "java-sdk $TYPE part finished as expected."
setproxy
unset SERVING_BIN
}
function python_test_grpc_impl() { function python_test_grpc_impl() {
# pwd: /Serving/python/examples # pwd: /Serving/python/examples
cd grpc_impl_example # pwd: /Serving/python/examples/grpc_impl_example cd grpc_impl_example # pwd: /Serving/python/examples/grpc_impl_example
...@@ -537,7 +595,7 @@ function python_test_grpc_impl() { ...@@ -537,7 +595,7 @@ function python_test_grpc_impl() {
# test load server config and client config in Server side # test load server config and client config in Server side
cd criteo_ctr_with_cube # pwd: /Serving/python/examples/grpc_impl_example/criteo_ctr_with_cube cd criteo_ctr_with_cube # pwd: /Serving/python/examples/grpc_impl_example/criteo_ctr_with_cube
check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz" check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz > /dev/null"
check_cmd "tar xf ctr_cube_unittest.tar.gz" check_cmd "tar xf ctr_cube_unittest.tar.gz"
check_cmd "mv models/ctr_client_conf ./" check_cmd "mv models/ctr_client_conf ./"
check_cmd "mv models/ctr_serving_model_kv ./" check_cmd "mv models/ctr_serving_model_kv ./"
...@@ -978,6 +1036,7 @@ function main() { ...@@ -978,6 +1036,7 @@ function main() {
build_client $TYPE # pwd: /Serving build_client $TYPE # pwd: /Serving
build_server $TYPE # pwd: /Serving build_server $TYPE # pwd: /Serving
build_app $TYPE # pwd: /Serving build_app $TYPE # pwd: /Serving
java_run_test $TYPE # pwd: /Serving
python_run_test $TYPE # pwd: /Serving python_run_test $TYPE # pwd: /Serving
monitor_test $TYPE # pwd: /Serving monitor_test $TYPE # pwd: /Serving
echo "serving $TYPE part finished as expected." echo "serving $TYPE part finished as expected."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册