diff --git a/core/configure/proto/multi_lang_general_model_service.proto b/core/configure/proto/multi_lang_general_model_service.proto
index 2a8a8bc1532c19aa02a1998aa751aa7ba9d41570..b83450aed666b96de324050d53b10c56e059a8d5 100644
--- a/core/configure/proto/multi_lang_general_model_service.proto
+++ b/core/configure/proto/multi_lang_general_model_service.proto
@@ -14,6 +14,10 @@ syntax = "proto2";
+option java_multiple_files = true;
+option java_package = "io.paddle.serving.grpc";
+option java_outer_classname = "ServingProto";
+
 message Tensor {
   optional bytes data = 1;
   repeated int32 int_data = 2;
diff --git a/doc/JAVA_SDK.md b/doc/JAVA_SDK.md
new file mode 100644
index 0000000000000000000000000000000000000000..4880e74bfee123b432b6b583a239d2d2ccbb45ac
--- /dev/null
+++ b/doc/JAVA_SDK.md
@@ -0,0 +1,109 @@
+# Paddle Serving Client Java SDK
+
+([简体中文](JAVA_SDK_CN.md)|English)
+
+Paddle Serving provides a Java SDK, which supports making predictions on the client side in Java. This document shows how to use the Java SDK.
+
+## Getting started
+
+
+### Prerequisites
+
+```
+- Java 8 or higher
+- Apache Maven
+```
+
+The following table shows the compatibility between Paddle Serving Server and the Java SDK.
+
+| Paddle Serving Server version | Java SDK version |
+| :---------------------------: | :--------------: |
+| 0.3.2 | 0.0.1 |
+
+### Install the Java SDK
+
+You can download the jar and install it into your local Maven repository:
+
+```shell
+wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar
+mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar
+```
+
+Or compile it from source and install it into your local Maven repository:
+
+```shell
+cd Serving/java
+mvn compile
+mvn install
+```
+
+### Maven configuration
+
+```text
+<dependency>
+  <groupId>io.paddle.serving.client</groupId>
+  <artifactId>paddle-serving-sdk-java</artifactId>
+  <version>0.0.1</version>
+</dependency>
+```
+
+
+
+## Example
+
+Here we show how to use the Java SDK for Boston house price prediction. Please refer to the [examples](../java/examples) folder for more examples.
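+
+Besides the synchronous `predict` call used in this example, the SDK also provides batch prediction and an asynchronous `asyn_predict` that returns a `PredictFuture`. A minimal asynchronous sketch, reusing the `client`, `feed_data`, and `fetch` objects from the client code below:
+
+```java
+// Asynchronous variant: asyn_predict returns immediately and
+// PredictFuture.get() blocks until the response arrives (null on failure).
+PredictFuture future = client.asyn_predict(feed_data, fetch);
+Map<String, INDArray> fetch_map = future.get();
+if (fetch_map == null) {
+    System.out.println("predict failed.");
+}
+```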
+ +### Get model + +```shell +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz +tar -xzf uci_housing.tar.gz +``` + +### Start Python Server + +```shell +python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang +``` + +#### Client side code example + +```java +import io.paddle.serving.client.*; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import java.util.*; + +public class PaddleServingClientExample { + public static void main( String[] args ) { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return ; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + System.out.println("predict failed."); + return ; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return ; + } +} +``` diff --git a/doc/JAVA_SDK_CN.md b/doc/JAVA_SDK_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..f624a4403371f5b284f34cbf310fef64d59602d9 --- /dev/null +++ b/doc/JAVA_SDK_CN.md @@ -0,0 +1,109 @@ +# Paddle Serving Client Java SDK + +(简体中文|[English](JAVA_SDK.md)) + +Paddle Serving 提供了 Java SDK,支持 Client 端用 Java 语言进行预测,本文档说明了如何使用 Java SDK。 + +## 快速开始 + +### 环境要求 + +``` +- Java 8 or higher +- Apache Maven +``` + +下表显示了 Paddle Serving Server 和 Java SDK 之间的兼容性 + +| Paddle Serving Server version | Java SDK version | +| :---------------------------: | :--------------: | +| 0.3.2 | 0.0.1 | + +### 安装 + +您可以直接下载 jar,安装到本地 Maven 库: + +```shell +wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar +mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar +``` + +或者从源码进行编译,安装到本地 Maven 库: + +```shell +cd Serving/java +mvn compile +mvn install +``` + +### Maven 配置 + +```text + + io.paddle.serving.client + paddle-serving-sdk-java + 0.0.1 + +``` + + + + +## 使用样例 + +这里将展示如何使用 Java SDK 进行房价预测,更多例子详见 [examples](../java/examples) 文件夹。 + +### 获取房价预测模型 + +```shell +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz +tar -xzf uci_housing.tar.gz +``` + +### 启动 Python 端 Server + +```shell +python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang +``` + +### Client 端代码示例 + +```java +import io.paddle.serving.client.*; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import java.util.*; + +public class PaddleServingClientExample { + public static void main( String[] args ) { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + 
return ; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + System.out.println("predict failed."); + return ; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return ; + } +} +``` diff --git a/java/examples/pom.xml b/java/examples/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..b6c8bc424f5d528d74a4a45828fd9b5e7e5d008e --- /dev/null +++ b/java/examples/pom.xml @@ -0,0 +1,88 @@ + + + + 4.0.0 + + io.paddle.serving.client + paddle-serving-sdk-java-examples + 0.0.1 + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 8 + 8 + + 3.8.1 + + + maven-assembly-plugin + + + + true + my.fully.qualified.class.Main + + + + jar-with-dependencies + + + + + make-my-jar-with-dependencies + package + + single + + + + + + + + + UTF-8 + nd4j-native + 1.0.0-beta7 + 1.0.0-beta7 + 0.0.1 + 1.7 + 1.7 + + + + + io.paddle.serving.client + paddle-serving-sdk-java + ${paddle.serving.client.version} + + + org.slf4j + slf4j-api + 1.7.30 + + + org.nd4j + ${nd4j.backend} + ${nd4j.version} + + + junit + junit + 4.11 + test + + + org.datavec + datavec-data-image + ${datavec.version} + + + + diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java new file mode 100644 index 0000000000000000000000000000000000000000..cdc11df130095d668734ae0a23adb12ef735b2ea --- /dev/null +++ b/java/examples/src/main/java/PaddleServingClientExample.java @@ -0,0 +1,363 @@ +import io.paddle.serving.client.*; +import java.io.File; +import java.io.IOException; +import java.net.URL; +import org.nd4j.linalg.api.iter.NdIndexIterator; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.datavec.image.loader.NativeImageLoader; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import java.util.*; + +public class PaddleServingClientExample { + boolean fit_a_line() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean yolov4(String filename) { + // https://deeplearning4j.konduit.ai/ + int height = 608; + int width = 608; + int channels = 3; + NativeImageLoader loader = new NativeImageLoader(height, width, channels); + INDArray BGRimage = null; + try { + BGRimage = loader.asMatrix(new File(filename)); + } catch (java.io.IOException e) { + System.out.println("load image fail."); + return false; + } + + // shape: (channels, height, width) + BGRimage = BGRimage.reshape(channels, height, width); + INDArray RGBimage = Nd4j.create(BGRimage.shape()); + + // BGR2RGB + CustomOp op = DynamicCustomOp.builder("reverse") + .addInputs(BGRimage) + .addOutputs(RGBimage) + .addIntegerArguments(0) + .build(); + Nd4j.getExecutioner().exec(op); + + // Div(255.0) + INDArray image 
= RGBimage.divi(255.0); + + INDArray im_size = Nd4j.createFromArray(new int[]{height, width}); + HashMap feed_data + = new HashMap() {{ + put("image", image); + put("im_size", im_size); + }}; + List fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + succ = client.setRpcTimeoutMs(20000); // cpu + if (succ != true) { + System.out.println("set timeout failed."); + return false; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean batch_predict() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List> feed_batch + = new ArrayList>() {{ + add(feed_data); + add(feed_data); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_batch, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean asyn_predict() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + PredictFuture future = client.asyn_predict(feed_data, fetch); + Map fetch_map = future.get(); + if (fetch_map == null) { + System.out.println("Get future reslut failed"); + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean model_ensemble() { + long[] data = {8, 233, 52, 601}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("words", npdata); + }}; + List fetch = Arrays.asList("prediction"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map> fetch_map + = client.ensemble_predict(feed_data, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry> entry : fetch_map.entrySet()) { + System.out.println("Model = " + entry.getKey()); + HashMap tt = entry.getValue(); + for (Map.Entry e : tt.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + } + return true; + } + + boolean bert() { + float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long[] segment_ids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + HashMap feed_data + = new HashMap() {{ + put("input_mask", Nd4j.createFromArray(input_mask)); + put("position_ids", Nd4j.createFromArray(position_ids)); + put("input_ids", Nd4j.createFromArray(input_ids)); + put("segment_ids", Nd4j.createFromArray(segment_ids)); + }}; + List fetch = Arrays.asList("pooled_output"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean cube_local() { + long[] embedding_14 = {250644}; + long[] embedding_2 = {890346}; + long[] embedding_10 = {3939}; + long[] embedding_17 = {421122}; + long[] embedding_23 = {664215}; + long[] embedding_6 = {704846}; + float[] dense_input = {0.0f, 0.006633499170812604f, 0.03f, 0.0f, + 0.145078125f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + long[] embedding_24 = {269955}; + long[] embedding_12 = {295309}; + long[] embedding_7 = {437731}; + long[] embedding_3 = {990128}; + long[] embedding_1 = {7753}; + long[] embedding_4 = {286835}; + long[] embedding_8 = {27346}; + long[] embedding_9 = {636474}; + long[] embedding_18 = {880474}; + long[] embedding_16 = {681378}; + long[] embedding_22 = {410878}; + long[] embedding_13 = {255651}; + long[] embedding_5 = {25207}; + long[] embedding_11 = {10891}; + long[] embedding_20 = {238459}; + long[] embedding_21 = {26235}; + long[] embedding_15 = 
{691460}; + long[] embedding_25 = {544187}; + long[] embedding_19 = {537425}; + long[] embedding_0 = {737395}; + + HashMap feed_data + = new HashMap() {{ + put("embedding_14.tmp_0", Nd4j.createFromArray(embedding_14)); + put("embedding_2.tmp_0", Nd4j.createFromArray(embedding_2)); + put("embedding_10.tmp_0", Nd4j.createFromArray(embedding_10)); + put("embedding_17.tmp_0", Nd4j.createFromArray(embedding_17)); + put("embedding_23.tmp_0", Nd4j.createFromArray(embedding_23)); + put("embedding_6.tmp_0", Nd4j.createFromArray(embedding_6)); + put("dense_input", Nd4j.createFromArray(dense_input)); + put("embedding_24.tmp_0", Nd4j.createFromArray(embedding_24)); + put("embedding_12.tmp_0", Nd4j.createFromArray(embedding_12)); + put("embedding_7.tmp_0", Nd4j.createFromArray(embedding_7)); + put("embedding_3.tmp_0", Nd4j.createFromArray(embedding_3)); + put("embedding_1.tmp_0", Nd4j.createFromArray(embedding_1)); + put("embedding_4.tmp_0", Nd4j.createFromArray(embedding_4)); + put("embedding_8.tmp_0", Nd4j.createFromArray(embedding_8)); + put("embedding_9.tmp_0", Nd4j.createFromArray(embedding_9)); + put("embedding_18.tmp_0", Nd4j.createFromArray(embedding_18)); + put("embedding_16.tmp_0", Nd4j.createFromArray(embedding_16)); + put("embedding_22.tmp_0", Nd4j.createFromArray(embedding_22)); + put("embedding_13.tmp_0", Nd4j.createFromArray(embedding_13)); + put("embedding_5.tmp_0", Nd4j.createFromArray(embedding_5)); + put("embedding_11.tmp_0", Nd4j.createFromArray(embedding_11)); + put("embedding_20.tmp_0", Nd4j.createFromArray(embedding_20)); + put("embedding_21.tmp_0", Nd4j.createFromArray(embedding_21)); + put("embedding_15.tmp_0", Nd4j.createFromArray(embedding_15)); + put("embedding_25.tmp_0", Nd4j.createFromArray(embedding_25)); + put("embedding_19.tmp_0", Nd4j.createFromArray(embedding_19)); + put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0)); + }}; + List fetch = Arrays.asList("prob"); + + Client client = new Client(); + String target = "localhost:9393"; + boolean succ = client.connect(target); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_data, fetch); + if (fetch_map == null) { + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + public static void main( String[] args ) { + // DL4J(Deep Learning for Java)Document: + // https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md + PaddleServingClientExample e = new PaddleServingClientExample(); + boolean succ = false; + + if (args.length < 1) { + System.out.println("Usage: java -cp PaddleServingClientExample ."); + System.out.println(": fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4"); + return; + } + String testType = args[0]; + System.out.format("[Example] %s\n", testType); + if ("fit_a_line".equals(testType)) { + succ = e.fit_a_line(); + } else if ("bert".equals(testType)) { + succ = e.bert(); + } else if ("model_ensemble".equals(testType)) { + succ = e.model_ensemble(); + } else if ("asyn_predict".equals(testType)) { + succ = e.asyn_predict(); + } else if ("batch_predict".equals(testType)) { + succ = e.batch_predict(); + } else if ("cube_local".equals(testType)) { + succ = e.cube_local(); + } else if ("cube_quant".equals(testType)) { + succ = e.cube_local(); + } else if ("yolov4".equals(testType)) { + if (args.length < 2) { + System.out.println("Usage: java -cp PaddleServingClientExample yolov4 
."); + return; + } + succ = e.yolov4(args[1]); + } else { + System.out.format("test-type(%s) not match.\n", testType); + return; + } + + if (succ == true) { + System.out.println("[Example] succ."); + } else { + System.out.println("[Example] fail."); + } + } +} diff --git a/java/examples/src/main/resources/000000570688.jpg b/java/examples/src/main/resources/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/java/examples/src/main/resources/000000570688.jpg differ diff --git a/java/pom.xml b/java/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..d7e9ea7a097ea1ea2f41f930773d4a5d72a6d515 --- /dev/null +++ b/java/pom.xml @@ -0,0 +1,267 @@ + + + + 4.0.0 + + io.paddle.serving.client + paddle-serving-sdk-java + 0.0.1 + jar + + paddle-serving-sdk-java + Java SDK for Paddle Sering Client. + https://github.com/PaddlePaddle/Serving + + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + + PaddlePaddle Author + guru4elephant@gmail.com + PaddlePaddle + https://github.com/PaddlePaddle/Serving + + + + + scm:git:https://github.com/PaddlePaddle/Serving.git + scm:git:https://github.com/PaddlePaddle/Serving.git + https://github.com/PaddlePaddle/Serving + + + + UTF-8 + 1.27.2 + 3.11.0 + 3.11.0 + nd4j-native + 1.0.0-beta7 + 1.8 + 1.8 + + + + + + io.grpc + grpc-bom + ${grpc.version} + pom + import + + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.6 + + + io.grpc + grpc-netty-shaded + runtime + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + javax.annotation + javax.annotation-api + 1.2 + provided + + + io.grpc + grpc-testing + test + + + com.google.protobuf + protobuf-java-util + ${protobuf.version} + runtime + + + com.google.errorprone + error_prone_annotations + 2.3.4 + + + org.junit.jupiter + junit-jupiter + 5.5.2 + test + + + org.apache.commons + commons-text + 1.6 + + + org.apache.commons + commons-collections4 + 4.4 + + + org.json + json + 20190722 + + + org.slf4j + slf4j-api + 1.7.30 + + + org.apache.logging.log4j + log4j-slf4j-impl + 2.12.1 + + + org.nd4j + ${nd4j.backend} + ${nd4j.version} + + + + + + release + + + + org.apache.maven.plugins + maven-source-plugin + 3.1.0 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.1.1 + + ${java.home}/bin/javadoc + + + + attach-javadocs + + jar + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.6 + + + sign-artifacts + verify + + sign + + + + + + + + + + + + + kr.motd.maven + os-maven-plugin + 1.6.2 + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.6.8 + true + + ossrh + https://oss.sonatype.org/ + true + + + + org.apache.maven.plugins + maven-release-plugin + 2.5.3 + + true + false + release + deploy + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} + + grpc-java + io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} + + + + + + compile + compile-custom + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.0.0-M2 + + + enforce + + + + + + + enforce + + + + + + + + diff --git a/java/src/main/java/io/paddle/serving/client/Client.java b/java/src/main/java/io/paddle/serving/client/Client.java new file mode 100644 index 0000000000000000000000000000000000000000..1e09e0c23c89dd4f0d70e0c93269b2185a69807f --- /dev/null +++ 
b/java/src/main/java/io/paddle/serving/client/Client.java @@ -0,0 +1,471 @@ +package io.paddle.serving.client; + +import java.util.*; +import java.util.function.Function; +import java.lang.management.ManagementFactory; +import java.lang.management.RuntimeMXBean; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.StatusRuntimeException; +import com.google.protobuf.ByteString; + +import com.google.common.util.concurrent.ListenableFuture; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.iter.NdIndexIterator; +import org.nd4j.linalg.factory.Nd4j; + +import io.paddle.serving.grpc.*; +import io.paddle.serving.configure.*; +import io.paddle.serving.client.PredictFuture; + +class Profiler { + int pid_; + String print_head_ = null; + List time_record_ = null; + boolean enable_ = false; + + Profiler() { + RuntimeMXBean runtimeMXBean = ManagementFactory.getRuntimeMXBean(); + pid_ = Integer.valueOf(runtimeMXBean.getName().split("@")[0]).intValue(); + print_head_ = "\nPROFILE\tpid:" + pid_ + "\t"; + time_record_ = new ArrayList(); + time_record_.add(print_head_); + } + + void record(String name) { + if (enable_) { + long ctime = System.currentTimeMillis() * 1000; + time_record_.add(name + ":" + String.valueOf(ctime) + " "); + } + } + + void printProfile() { + if (enable_) { + String profile_str = String.join("", time_record_); + time_record_ = new ArrayList(); + time_record_.add(print_head_); + } + } + + void enable(boolean flag) { + enable_ = flag; + } +} + +public class Client { + private ManagedChannel channel_; + private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_; + private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceFutureStub futureStub_; + private double rpcTimeoutS_; + private List feedNames_; + private Map feedTypes_; + private Map> feedShapes_; + private List fetchNames_; + private Map fetchTypes_; + private Set lodTensorSet_; + private Map feedTensorLen_; + private Profiler profiler_; + + public Client() { + channel_ = null; + blockingStub_ = null; + futureStub_ = null; + rpcTimeoutS_ = 2; + + feedNames_ = null; + feedTypes_ = null; + feedShapes_ = null; + fetchNames_ = null; + fetchTypes_ = null; + lodTensorSet_ = null; + feedTensorLen_ = null; + + profiler_ = new Profiler(); + boolean is_profile = false; + String FLAGS_profile_client = System.getenv("FLAGS_profile_client"); + if (FLAGS_profile_client != null && FLAGS_profile_client.equals("1")) { + is_profile = true; + } + profiler_.enable(is_profile); + } + + public boolean setRpcTimeoutMs(int rpc_timeout) { + if (futureStub_ == null || blockingStub_ == null) { + System.out.println("set timeout must be set after connect."); + return false; + } + rpcTimeoutS_ = rpc_timeout / 1000.0; + SetTimeoutRequest timeout_req = SetTimeoutRequest.newBuilder() + .setTimeoutMs(rpc_timeout) + .build(); + SimpleResponse resp; + try { + resp = blockingStub_.setTimeout(timeout_req); + } catch (StatusRuntimeException e) { + System.out.format("Set RPC timeout failed: %s\n", e.toString()); + return false; + } + return resp.getErrCode() == 0; + } + + public boolean connect(String target) { + // TODO: target must be NameResolver-compliant URI + // https://grpc.github.io/grpc-java/javadoc/io/grpc/ManagedChannelBuilder.html + try { + channel_ = ManagedChannelBuilder.forTarget(target) + .defaultLoadBalancingPolicy("round_robin") + .maxInboundMessageSize(Integer.MAX_VALUE) + .usePlaintext() + .build(); + blockingStub_ = 
MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_); + futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_); + } catch (Exception e) { + System.out.format("Connect failed: %s\n", e.toString()); + return false; + } + GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); + GetClientConfigResponse resp; + try { + resp = blockingStub_.getClientConfig(get_client_config_req); + } catch (Exception e) { + System.out.format("Get Client config failed: %s\n", e.toString()); + return false; + } + String model_config_str = resp.getClientConfigStr(); + _parseModelConfig(model_config_str); + return true; + } + + private void _parseModelConfig(String model_config_str) { + GeneralModelConfig.Builder model_conf_builder = GeneralModelConfig.newBuilder(); + try { + com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder); + } catch (com.google.protobuf.TextFormat.ParseException e) { + System.out.format("Parse client config failed: %s\n", e.toString()); + } + GeneralModelConfig model_conf = model_conf_builder.build(); + + feedNames_ = new ArrayList(); + fetchNames_ = new ArrayList(); + feedTypes_ = new HashMap(); + feedShapes_ = new HashMap>(); + fetchTypes_ = new HashMap(); + lodTensorSet_ = new HashSet(); + feedTensorLen_ = new HashMap(); + + List feed_var_list = model_conf.getFeedVarList(); + for (FeedVar feed_var : feed_var_list) { + feedNames_.add(feed_var.getAliasName()); + } + List fetch_var_list = model_conf.getFetchVarList(); + for (FetchVar fetch_var : fetch_var_list) { + fetchNames_.add(fetch_var.getAliasName()); + } + + for (int i = 0; i < feed_var_list.size(); ++i) { + FeedVar feed_var = feed_var_list.get(i); + String var_name = feed_var.getAliasName(); + feedTypes_.put(var_name, feed_var.getFeedType()); + feedShapes_.put(var_name, feed_var.getShapeList()); + if (feed_var.getIsLodTensor()) { + lodTensorSet_.add(var_name); + } else { + int counter = 1; + for (int dim : feedShapes_.get(var_name)) { + counter *= dim; + } + feedTensorLen_.put(var_name, counter); + } + } + + for (int i = 0; i < fetch_var_list.size(); i++) { + FetchVar fetch_var = fetch_var_list.get(i); + String var_name = fetch_var.getAliasName(); + fetchTypes_.put(var_name, fetch_var.getFetchType()); + if (fetch_var.getIsLodTensor()) { + lodTensorSet_.add(var_name); + } + } + } + + private InferenceRequest _packInferenceRequest( + List> feed_batch, + Iterable fetch) throws IllegalArgumentException { + List feed_var_names = new ArrayList(); + feed_var_names.addAll(feed_batch.get(0).keySet()); + + InferenceRequest.Builder req_builder = InferenceRequest.newBuilder() + .addAllFeedVarNames(feed_var_names) + .addAllFetchVarNames(fetch) + .setIsPython(false); + for (HashMap feed_data: feed_batch) { + FeedInst.Builder inst_builder = FeedInst.newBuilder(); + for (String name: feed_var_names) { + Tensor.Builder tensor_builder = Tensor.newBuilder(); + INDArray variable = feed_data.get(name); + long[] flattened_shape = {-1}; + INDArray flattened_list = variable.reshape(flattened_shape); + int v_type = feedTypes_.get(name); + NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); + if (v_type == 0) { // int64 + while (iter.hasNext()) { + long[] next_index = iter.next(); + long x = flattened_list.getLong(next_index); + tensor_builder.addInt64Data(x); + } + } else if (v_type == 1) { // float32 + while (iter.hasNext()) { + long[] next_index = iter.next(); + float x = flattened_list.getFloat(next_index); + tensor_builder.addFloatData(x); + } + } 
else if (v_type == 2) { // int32 + while (iter.hasNext()) { + long[] next_index = iter.next(); + // the interface of INDArray is strange: + // https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html + int[] int_next_index = new int[next_index.length]; + for(int i = 0; i < next_index.length; i++) { + int_next_index[i] = (int)next_index[i]; + } + int x = flattened_list.getInt(int_next_index); + tensor_builder.addIntData(x); + } + } else { + throw new IllegalArgumentException("error tensor value type."); + } + tensor_builder.addAllShape(feedShapes_.get(name)); + inst_builder.addTensorArray(tensor_builder.build()); + } + req_builder.addInsts(inst_builder.build()); + } + return req_builder.build(); + } + + private Map> + _unpackInferenceResponse( + InferenceResponse resp, + Iterable fetch, + Boolean need_variant_tag) throws IllegalArgumentException { + return Client._staticUnpackInferenceResponse( + resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); + } + + private static Map> + _staticUnpackInferenceResponse( + InferenceResponse resp, + Iterable fetch, + Map fetchTypes, + Set lodTensorSet, + Boolean need_variant_tag) throws IllegalArgumentException { + if (resp.getErrCode() != 0) { + return null; + } + String tag = resp.getTag(); + HashMap> multi_result_map + = new HashMap>(); + for (ModelOutput model_result: resp.getOutputsList()) { + String engine_name = model_result.getEngineName(); + FetchInst inst = model_result.getInsts(0); + HashMap result_map + = new HashMap(); + int index = 0; + for (String name: fetch) { + Tensor variable = inst.getTensorArray(index); + int v_type = fetchTypes.get(name); + INDArray data = null; + if (v_type == 0) { // int64 + List list = variable.getInt64DataList(); + long[] array = new long[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + data = Nd4j.createFromArray(array); + } else if (v_type == 1) { // float32 + List list = variable.getFloatDataList(); + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + data = Nd4j.createFromArray(array); + } else if (v_type == 2) { // int32 + List list = variable.getIntDataList(); + int[] array = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + data = Nd4j.createFromArray(array); + } else { + throw new IllegalArgumentException("error tensor value type."); + } + // shape + List shape_lsit = variable.getShapeList(); + int[] shape_array = new int[shape_lsit.size()]; + for (int i = 0; i < shape_lsit.size(); ++i) { + shape_array[i] = shape_lsit.get(i); + } + data = data.reshape(shape_array); + + // put data to result_map + result_map.put(name, data); + + // lod + if (lodTensorSet.contains(name)) { + List list = variable.getLodList(); + int[] array = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + result_map.put(name + ".lod", Nd4j.createFromArray(array)); + } + index += 1; + } + multi_result_map.put(engine_name, result_map); + } + + // TODO: tag(ABtest not support now) + return multi_result_map; + } + + public Map predict( + HashMap feed, + Iterable fetch) { + return predict(feed, fetch, false); + } + + public Map> ensemble_predict( + HashMap feed, + Iterable fetch) { + return ensemble_predict(feed, fetch, false); + } + + public PredictFuture asyn_predict( + HashMap feed, + Iterable fetch) { + return asyn_predict(feed, fetch, false); + } + + public Map predict( + HashMap feed, + Iterable fetch, + Boolean 
need_variant_tag) { + List> feed_batch + = new ArrayList>(); + feed_batch.add(feed); + return predict(feed_batch, fetch, need_variant_tag); + } + + public Map> ensemble_predict( + HashMap feed, + Iterable fetch, + Boolean need_variant_tag) { + List> feed_batch + = new ArrayList>(); + feed_batch.add(feed); + return ensemble_predict(feed_batch, fetch, need_variant_tag); + } + + public PredictFuture asyn_predict( + HashMap feed, + Iterable fetch, + Boolean need_variant_tag) { + List> feed_batch + = new ArrayList>(); + feed_batch.add(feed); + return asyn_predict(feed_batch, fetch, need_variant_tag); + } + + public Map predict( + List> feed_batch, + Iterable fetch) { + return predict(feed_batch, fetch, false); + } + + public Map> ensemble_predict( + List> feed_batch, + Iterable fetch) { + return ensemble_predict(feed_batch, fetch, false); + } + + public PredictFuture asyn_predict( + List> feed_batch, + Iterable fetch) { + return asyn_predict(feed_batch, fetch, false); + } + + public Map predict( + List> feed_batch, + Iterable fetch, + Boolean need_variant_tag) { + try { + profiler_.record("java_prepro_0"); + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); + profiler_.record("java_prepro_1"); + + profiler_.record("java_client_infer_0"); + InferenceResponse resp = blockingStub_.inference(req); + profiler_.record("java_client_infer_1"); + + profiler_.record("java_postpro_0"); + Map> ensemble_result + = _unpackInferenceResponse(resp, fetch, need_variant_tag); + List>> list + = new ArrayList>>( + ensemble_result.entrySet()); + if (list.size() != 1) { + System.out.format("predict failed: please use ensemble_predict impl.\n"); + return null; + } + profiler_.record("java_postpro_1"); + profiler_.printProfile(); + + return list.get(0).getValue(); + } catch (StatusRuntimeException e) { + System.out.format("predict failed: %s\n", e.toString()); + return null; + } + } + + public Map> ensemble_predict( + List> feed_batch, + Iterable fetch, + Boolean need_variant_tag) { + try { + profiler_.record("java_prepro_0"); + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); + profiler_.record("java_prepro_1"); + + profiler_.record("java_client_infer_0"); + InferenceResponse resp = blockingStub_.inference(req); + profiler_.record("java_client_infer_1"); + + profiler_.record("java_postpro_0"); + Map> ensemble_result + = _unpackInferenceResponse(resp, fetch, need_variant_tag); + profiler_.record("java_postpro_1"); + profiler_.printProfile(); + + return ensemble_result; + } catch (StatusRuntimeException e) { + System.out.format("predict failed: %s\n", e.toString()); + return null; + } + } + + public PredictFuture asyn_predict( + List> feed_batch, + Iterable fetch, + Boolean need_variant_tag) { + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); + ListenableFuture future = futureStub_.inference(req); + PredictFuture predict_future = new PredictFuture(future, + (InferenceResponse resp) -> { + return Client._staticUnpackInferenceResponse( + resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); + } + ); + return predict_future; + } +} diff --git a/java/src/main/java/io/paddle/serving/client/PredictFuture.java b/java/src/main/java/io/paddle/serving/client/PredictFuture.java new file mode 100644 index 0000000000000000000000000000000000000000..28156d965e76db889358be00ab8c05381e0f89d8 --- /dev/null +++ b/java/src/main/java/io/paddle/serving/client/PredictFuture.java @@ -0,0 +1,54 @@ +package io.paddle.serving.client; + +import java.util.*; +import java.util.function.Function; 
+import io.grpc.StatusRuntimeException; +import com.google.common.util.concurrent.ListenableFuture; +import org.nd4j.linalg.api.ndarray.INDArray; + +import io.paddle.serving.client.Client; +import io.paddle.serving.grpc.*; + +public class PredictFuture { + private ListenableFuture callFuture_; + private Function>> callBackFunc_; + + PredictFuture(ListenableFuture call_future, + Function>> call_back_func) { + callFuture_ = call_future; + callBackFunc_ = call_back_func; + } + + public Map get() { + InferenceResponse resp = null; + try { + resp = callFuture_.get(); + } catch (Exception e) { + System.out.format("predict failed: %s\n", e.toString()); + return null; + } + Map> ensemble_result + = callBackFunc_.apply(resp); + List>> list + = new ArrayList>>( + ensemble_result.entrySet()); + if (list.size() != 1) { + System.out.format("predict failed: please use get_ensemble impl.\n"); + return null; + } + return list.get(0).getValue(); + } + + public Map> ensemble_get() { + InferenceResponse resp = null; + try { + resp = callFuture_.get(); + } catch (Exception e) { + System.out.format("predict failed: %s\n", e.toString()); + return null; + } + return callBackFunc_.apply(resp); + } +} diff --git a/java/src/main/proto/general_model_config.proto b/java/src/main/proto/general_model_config.proto new file mode 100644 index 0000000000000000000000000000000000000000..03cff3f8c1ab4a369f132d64d7e4f2c871ebb077 --- /dev/null +++ b/java/src/main/proto/general_model_config.proto @@ -0,0 +1,40 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +option java_multiple_files = true; +option java_package = "io.paddle.serving.configure"; +option java_outer_classname = "ConfigureProto"; + +package paddle.serving.configure; + +message FeedVar { + optional string name = 1; + optional string alias_name = 2; + optional bool is_lod_tensor = 3 [ default = false ]; + optional int32 feed_type = 4 [ default = 0 ]; + repeated int32 shape = 5; +} +message FetchVar { + optional string name = 1; + optional string alias_name = 2; + optional bool is_lod_tensor = 3 [ default = false ]; + optional int32 fetch_type = 4 [ default = 0 ]; + repeated int32 shape = 5; +} +message GeneralModelConfig { + repeated FeedVar feed_var = 1; + repeated FetchVar fetch_var = 2; +}; diff --git a/java/src/main/proto/multi_lang_general_model_service.proto b/java/src/main/proto/multi_lang_general_model_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..b83450aed666b96de324050d53b10c56e059a8d5 --- /dev/null +++ b/java/src/main/proto/multi_lang_general_model_service.proto @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +option java_multiple_files = true; +option java_package = "io.paddle.serving.grpc"; +option java_outer_classname = "ServingProto"; + +message Tensor { + optional bytes data = 1; + repeated int32 int_data = 2; + repeated int64 int64_data = 3; + repeated float float_data = 4; + optional int32 elem_type = 5; + repeated int32 shape = 6; + repeated int32 lod = 7; // only for fetch tensor currently +}; + +message FeedInst { repeated Tensor tensor_array = 1; }; + +message FetchInst { repeated Tensor tensor_array = 1; }; + +message InferenceRequest { + repeated FeedInst insts = 1; + repeated string feed_var_names = 2; + repeated string fetch_var_names = 3; + required bool is_python = 4 [ default = false ]; +}; + +message InferenceResponse { + repeated ModelOutput outputs = 1; + optional string tag = 2; + required int32 err_code = 3; +}; + +message ModelOutput { + repeated FetchInst insts = 1; + optional string engine_name = 2; +} + +message SetTimeoutRequest { required int32 timeout_ms = 1; } + +message SimpleResponse { required int32 err_code = 1; } + +message GetClientConfigRequest {} + +message GetClientConfigResponse { required string client_config_str = 1; } + +service MultiLangGeneralModelService { + rpc Inference(InferenceRequest) returns (InferenceResponse) {} + rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {} + rpc GetClientConfig(GetClientConfigRequest) + returns (GetClientConfigResponse) {} +}; diff --git a/java/src/main/resources/log4j2.xml b/java/src/main/resources/log4j2.xml new file mode 100644 index 0000000000000000000000000000000000000000..e13b79d3f92acca50cafde874b501513dbdb292f --- /dev/null +++ b/java/src/main/resources/log4j2.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/python/examples/criteo_ctr_with_cube/README.md b/python/examples/criteo_ctr_with_cube/README.md index 6dc81f3a6b0f98017e0c5b45234f8f348c5f75ce..493b3d72c1fff9275c2a99cfee45efd4bef1af4c 100755 --- a/python/examples/criteo_ctr_with_cube/README.md +++ b/python/examples/criteo_ctr_with_cube/README.md @@ -45,7 +45,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data CPU :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz -Model :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/ctr_criteo_with_cube/network_conf.py) +Model :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py) server core/thread num : 4/8 diff --git a/python/examples/criteo_ctr_with_cube/README_CN.md b/python/examples/criteo_ctr_with_cube/README_CN.md index 97d5629170f3a65dabb104c3764a55ba08051bc5..7a0eb43c203aafeb38b64d249954cdabf7bf7a38 100644 --- a/python/examples/criteo_ctr_with_cube/README_CN.md +++ b/python/examples/criteo_ctr_with_cube/README_CN.md @@ -43,7 +43,7 @@ python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data 设备 :Intel(R) Xeon(R) CPU 6148 @ 2.40GHz -模型 :[Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/ctr_criteo_with_cube/network_conf.py) +模型 :[Criteo 
CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py) server core/thread num : 4/8 diff --git a/python/examples/criteo_ctr_with_cube/test_server_quant.py b/python/examples/criteo_ctr_with_cube/test_server_quant.py index fc278f755126cdeb204644cbc91838b1b038379e..38a3fe67da803d1c84337d64e3421d8295ac5767 100755 --- a/python/examples/criteo_ctr_with_cube/test_server_quant.py +++ b/python/examples/criteo_ctr_with_cube/test_server_quant.py @@ -33,5 +33,9 @@ server = Server() server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_num_threads(4) server.load_model_config(sys.argv[1]) -server.prepare_server(workdir="work_dir1", port=9292, device="cpu") +server.prepare_server( + workdir="work_dir1", + port=9292, + device="cpu", + cube_conf="./cube/conf/cube.conf") server.run_server() diff --git a/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py b/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py index feca75b077d737a614bdfd955b7bf0d82ed08529..2fd9308454b4caa862e7d83ddadb48279bba7167 100755 --- a/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py +++ b/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py @@ -33,5 +33,9 @@ server = Server() server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_num_threads(4) server.load_model_config(sys.argv[1], sys.argv[2]) -server.prepare_server(workdir="work_dir1", port=9292, device="cpu") +server.prepare_server( + workdir="work_dir1", + port=9292, + device="cpu", + cube_conf="./cube/conf/cube.conf") server.run_server() diff --git a/python/examples/grpc_impl_example/imdb/get_data.sh b/python/examples/grpc_impl_example/imdb/get_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..81d8d5d3b018f133c41e211d1501cf3cd9a3d8a4 --- /dev/null +++ b/python/examples/grpc_impl_example/imdb/get_data.sh @@ -0,0 +1,4 @@ +wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz +tar -zxvf text_classification_data.tar.gz +tar -zxvf imdb_model.tar.gz diff --git a/python/examples/grpc_impl_example/imdb/imdb_reader.py b/python/examples/grpc_impl_example/imdb/imdb_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ef3e163a50b0dc244ac2653df1e38d7f91699b --- /dev/null +++ b/python/examples/grpc_impl_example/imdb/imdb_reader.py @@ -0,0 +1,92 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# pylint: disable=doc-string-missing + +import sys +import os +import paddle +import re +import paddle.fluid.incubate.data_generator as dg + +py_version = sys.version_info[0] + + +class IMDBDataset(dg.MultiSlotDataGenerator): + def load_resource(self, dictfile): + self._vocab = {} + wid = 0 + if py_version == 2: + with open(dictfile) as f: + for line in f: + self._vocab[line.strip()] = wid + wid += 1 + else: + with open(dictfile, encoding="utf-8") as f: + for line in f: + self._vocab[line.strip()] = wid + wid += 1 + self._unk_id = len(self._vocab) + self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))') + self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0]) + + def get_words_only(self, line): + sent = line.lower().replace("
", " ").strip() + words = [x for x in self._pattern.split(sent) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + return feas + + def get_words_and_label(self, line): + send = '|'.join(line.split('|')[:-1]).lower().replace("
", + " ").strip() + label = [int(line.split('|')[-1])] + + words = [x for x in self._pattern.split(send) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + return feas, label + + def infer_reader(self, infer_filelist, batch, buf_size): + def local_iter(): + for fname in infer_filelist: + with open(fname, "r") as fin: + for line in fin: + feas, label = self.get_words_and_label(line) + yield feas, label + + import paddle + batch_iter = paddle.batch( + paddle.reader.shuffle( + local_iter, buf_size=buf_size), + batch_size=batch) + return batch_iter + + def generate_sample(self, line): + def memory_iter(): + for i in range(1000): + yield self.return_value + + def data_iter(): + feas, label = self.get_words_and_label(line) + yield ("words", feas), ("label", label) + + return data_iter + + +if __name__ == "__main__": + imdb = IMDBDataset() + imdb.load_resource("imdb.vocab") + imdb.run_from_stdin() diff --git a/python/examples/imdb/test_multilang_ensemble_client.py b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py similarity index 95% rename from python/examples/imdb/test_multilang_ensemble_client.py rename to python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py index 6686d4c8c38d6a17cb9c5701abf7d76773031772..43034e49bde4a477c160c5a0d158ea541d633a4d 100644 --- a/python/examples/imdb/test_multilang_ensemble_client.py +++ b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py @@ -34,4 +34,6 @@ for i in range(3): fetch = ["prediction"] fetch_maps = client.predict(feed=feed, fetch=fetch) for model, fetch_map in fetch_maps.items(): + if model == "serving_status_code": + continue print("step: {}, model: {}, res: {}".format(i, model, fetch_map)) diff --git a/python/examples/imdb/test_multilang_ensemble_server.py b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_server.py similarity index 100% rename from python/examples/imdb/test_multilang_ensemble_server.py rename to python/examples/grpc_impl_example/imdb/test_multilang_ensemble_server.py diff --git a/python/examples/grpc_impl_example/yolov4/000000570688.jpg b/python/examples/grpc_impl_example/yolov4/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/python/examples/grpc_impl_example/yolov4/000000570688.jpg differ diff --git a/python/examples/grpc_impl_example/yolov4/README.md b/python/examples/grpc_impl_example/yolov4/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a04215dcf349b0e589819db16d53b3435bd904ff --- /dev/null +++ b/python/examples/grpc_impl_example/yolov4/README.md @@ -0,0 +1,23 @@ +# Yolov4 Detection Service + +([简体中文](README_CN.md)|English) + +## Get Model + +``` +python -m paddle_serving_app.package --get_model yolov4 +tar -xzvf yolov4.tar.gz +``` + +## Start RPC Service + +``` +python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang +``` + +## Prediction + +``` +python test_client.py 000000570688.jpg +``` +After the prediction is completed, a json file to save the prediction result and a picture with the detection result box will be generated in the `./outpu folder. 
diff --git a/python/examples/grpc_impl_example/yolov4/README_CN.md b/python/examples/grpc_impl_example/yolov4/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..de7a85b59ccdf831337083b8d6047bfe41525220 --- /dev/null +++ b/python/examples/grpc_impl_example/yolov4/README_CN.md @@ -0,0 +1,24 @@ +# Yolov4 检测服务 + +(简体中文|[English](README.md)) + +## 获取模型 + +``` +python -m paddle_serving_app.package --get_model yolov4 +tar -xzvf yolov4.tar.gz +``` + +## 启动RPC服务 + +``` +python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang +``` + +## 预测 + +``` +python test_client.py 000000570688.jpg +``` + +预测完成会在`./output`文件夹下生成保存预测结果的json文件以及标出检测结果框的图片。 diff --git a/python/examples/grpc_impl_example/yolov4/label_list.txt b/python/examples/grpc_impl_example/yolov4/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..941cb4e1392266f6a6c09b1fdc5f79503b2e5df6 --- /dev/null +++ b/python/examples/grpc_impl_example/yolov4/label_list.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/python/examples/grpc_impl_example/yolov4/test_client.py b/python/examples/grpc_impl_example/yolov4/test_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a55763880f7852f0297d7e6c7f44f8c3a206dc60 --- /dev/null +++ b/python/examples/grpc_impl_example/yolov4/test_client.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import numpy as np +from paddle_serving_client import MultiLangClient as Client +from paddle_serving_app.reader import * +import cv2 + +preprocess = Sequential([ + File2Image(), BGR2RGB(), Resize( + (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( + (2, 0, 1)) +]) + +postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) +client = Client() +client.connect(['127.0.0.1:9393']) +# client.set_rpc_timeout_ms(10000) + +im = preprocess(sys.argv[1]) +fetch_map = client.predict( + feed={ + "image": im, + "im_size": np.array(list(im.shape[1:])), + }, + fetch=["save_infer_model/scale_0.tmp_0"]) +fetch_map.pop("serving_status_code") +fetch_map["image"] = sys.argv[1] +postprocess(fetch_map) diff --git a/python/examples/ocr/README.md b/python/examples/ocr/README.md index 3535ed80eb27291aa4da4bb2683923c9e4082acf..ca9bbabdf57ce95763b25fec3751a85e4c8f9401 100644 --- a/python/examples/ocr/README.md +++ b/python/examples/ocr/README.md @@ -1,5 +1,7 @@ # OCR +(English|[简体中文](./README_CN.md)) + ## Get Model ``` python -m paddle_serving_app.package --get_model ocr_rec @@ -8,38 +10,78 @@ python -m paddle_serving_app.package --get_model ocr_det tar -xzvf ocr_det.tar.gz ``` -## RPC Service +## Get Dataset (Optional) +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar +tar xf test_imgs.tar +``` + +## Web Service ### Start Service -For the following two code block, please check your devices and pick one -for GPU device ``` -python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0 python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +python ocr_web_server.py ``` -for CPU device + +### Client Prediction ``` -python -m paddle_serving_server.serve --model ocr_rec_model --port 9292 -python -m paddle_serving_server.serve --model ocr_det_model --port 9293 +python ocr_web_client.py ``` +If you want a faster web service, please try Web Debugger Service -### Client Prediction +## Web Debugger Service +``` +python ocr_debugger_server.py +``` +## Web Debugger Client Prediction ``` -python ocr_rpc_client.py +python ocr_web_client.py ``` -## Web Service +## Benchmark -### Start Service +CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40 + +GPU: Nvidia Tesla V100 * 1 + +Dataset: RCTW 500 sample images + +| engine | client read image(ms) | client-server tras time(ms) | server read image(ms) | det pre(ms) | det infer(ms) | det post(ms) | rec pre(ms) | rec infer(ms) | rec post(ms) | server-client trans time(ms) | server side time consumption(ms) | server side overhead(ms) | total time(ms) | +|------------------------------|----------------|----------------------------|------------------|--------------------|------------------|--------------------|--------------------|------------------|--------------------|--------------------------|--------------------|--------------|---------------| +| Serving web service | 8.69 | 13.41 | 109.97 | 2.82 | 87.76 | 4.29 | 3.98 | 78.51 | 3.66 | 4.12 | 181.02 | 136.49 | 317.51 | +| Serving Debugger web service | 8.73 | 16.42 | 115.27 | 2.93 | 20.63 | 3.97 | 4.48 | 13.84 | 3.60 | 6.91 | 49.45 | 147.33 | 196.78 | + +## Appendix: Det or Rec only +if you are going to detect images not recognize it or directly recognize the words from images. We also provide Det and Rec server for you. 
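+
+The single-model servers listed below expose the same HTTP interface as the combined OCR service, so `ocr_web_client.py` can be reused as-is. For reference, a minimal client sketch: it assumes the default web service endpoint `http://127.0.0.1:9292/ocr/prediction` (service name `ocr` and port 9292 as set in the server scripts) and a placeholder image path, and it sends the base64-encoded image that the servers' `preprocess` reads from the `image` field:
+
+```python
+import base64
+import json
+import requests
+
+img_path = "test_imgs/1.jpg"  # placeholder: any image from the test set
+with open(img_path, "rb") as f:
+    image = base64.b64encode(f.read()).decode("utf8")
+
+# The fetch list is overridden by the server-side preprocess, so its value is not critical here.
+data = {"feed": [{"image": image}], "fetch": ["dt_boxes"]}
+r = requests.post(
+    url="http://127.0.0.1:9292/ocr/prediction",
+    headers={"Content-Type": "application/json"},
+    data=json.dumps(data))
+print(r.json())
+```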
+ +### Det Server ``` -python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 -python ocr_web_server.py +python det_web_server.py +#or +python det_debugger_server.py ``` -### Client Prediction +### Det Client + +``` +# also use ocr_web_client.py +python ocr_web_client.py +``` + +### Rec Server + +``` +python rec_web_server.py +#or +python rec_debugger_server.py +``` + +### Rec Client + ``` -sh ocr_web_client.sh +python rec_web_client.py ``` diff --git a/python/examples/ocr/README_CN.md b/python/examples/ocr/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..65bc066a43a34d1a35cb4236473c571106c5f61b --- /dev/null +++ b/python/examples/ocr/README_CN.md @@ -0,0 +1,93 @@ +# OCR 服务 + +([English](./README.md)|简体中文) + +## 获取模型 +``` +python -m paddle_serving_app.package --get_model ocr_rec +tar -xzvf ocr_rec.tar.gz +python -m paddle_serving_app.package --get_model ocr_det +tar -xzvf ocr_det.tar.gz +``` +## 获取数据集(可选) +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar +tar xf test_imgs.tar +``` + +### 客户端预测 + +``` +python ocr_rpc_client.py +``` + +## Web Service服务 + +### 启动服务 + +``` +python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 +python ocr_web_server.py +``` + +### 启动客户端 +``` +python ocr_web_client.py +``` + +如果用户需要更快的执行速度,请尝试Debugger版Web服务 +## 启动Debugger版Web服务 +``` +python ocr_debugger_server.py +``` + +## 启动客户端 +``` +python ocr_web_client.py +``` + +## 性能指标 + +CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40 + +GPU: Nvidia Tesla V100单卡 + +数据集:RCTW 500张测试数据集 + +| engine | 客户端读图(ms) | 客户端发送请求到服务端(ms) | 服务端读图(ms) | 检测预处理耗时(ms) | 检测模型耗时(ms) | 检测后处理耗时(ms) | 识别预处理耗时(ms) | 识别模型耗时(ms) | 识别后处理耗时(ms) | 服务端回传客户端时间(ms) | 服务端整体耗时(ms) | 空跑耗时(ms) | 整体耗时(ms) | +|------------------------------|----------------|----------------------------|------------------|--------------------|------------------|--------------------|--------------------|------------------|--------------------|--------------------------|--------------------|--------------|---------------| +| Serving web service | 8.69 | 13.41 | 109.97 | 2.82 | 87.76 | 4.29 | 3.98 | 78.51 | 3.66 | 4.12 | 181.02 | 136.49 | 317.51 | +| Serving Debugger web service | 8.73 | 16.42 | 115.27 | 2.93 | 20.63 | 3.97 | 4.48 | 13.84 | 3.60 | 6.91 | 49.45 | 147.33 | 196.78 | + + +## 附录: 检测/识别单服务启动 +如果您想单独启动检测或者识别服务,我们也提供了启动单服务的代码 + +### 启动检测服务 + +``` +python det_web_server.py +#or +python det_debugger_server.py +``` + +### 检测服务客户端 + +``` +# also use ocr_web_client.py +python ocr_web_client.py +``` + +### 启动识别服务 + +``` +python rec_web_server.py +#or +python rec_debugger_server.py +``` + +### 识别服务客户端 + +``` +python rec_web_client.py +``` diff --git a/python/examples/ocr/det_debugger_server.py b/python/examples/ocr/det_debugger_server.py new file mode 100644 index 0000000000000000000000000000000000000000..acfccdb6d24a7e1ba490705dd147f21dbf921d31 --- /dev/null +++ b/python/examples/ocr/det_debugger_server.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re +import base64 + + +class OCRService(WebService): + def init_det(self): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.filter_func = FilterBoxes(10, 10) + self.post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + def preprocess(self, feed=[], fetch=[]): + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + self.ori_h, self.ori_w, _ = im.shape + det_img = self.det_preprocess(im) + _, self.new_h, self.new_w = det_img.shape + return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"] + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + det_out = fetch_map["concat_1.tmp_0"] + ratio_list = [ + float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w + ] + dt_boxes_list = self.post_func(det_out, [ratio_list]) + dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) + return {"dt_boxes": dt_boxes.tolist()} + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_det_model") +ocr_service.set_gpus("0") +ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) +ocr_service.init_det() +ocr_service.run_debugger_service() +ocr_service.run_web_service() diff --git a/python/examples/ocr/det_web_server.py b/python/examples/ocr/det_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..dd69be0c70eb0f4dd627aa47ad33045a204f78c0 --- /dev/null +++ b/python/examples/ocr/det_web_server.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
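+# Detection-only web service: decodes a base64-encoded image from the request,
+# runs the det model through the regular RPC service started by run_rpc_service,
+# and returns the filtered text boxes as "dt_boxes".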
+ +from paddle_serving_client import Client +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re +import base64 + + +class OCRService(WebService): + def init_det(self): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.filter_func = FilterBoxes(10, 10) + self.post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + + def preprocess(self, feed=[], fetch=[]): + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + self.ori_h, self.ori_w, _ = im.shape + det_img = self.det_preprocess(im) + _, self.new_h, self.new_w = det_img.shape + print(det_img) + return {"image": det_img}, ["concat_1.tmp_0"] + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + det_out = fetch_map["concat_1.tmp_0"] + ratio_list = [ + float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w + ] + dt_boxes_list = self.post_func(det_out, [ratio_list]) + dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w]) + return {"dt_boxes": dt_boxes.tolist()} + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_det_model") +ocr_service.set_gpus("0") +ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) +ocr_service.init_det() +ocr_service.run_rpc_service() +ocr_service.run_web_service() diff --git a/python/examples/ocr/imgs/1.jpg b/python/examples/ocr/imgs/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..08010177fed2ee8c3709912c06c0b161ba546313 Binary files /dev/null and b/python/examples/ocr/imgs/1.jpg differ diff --git a/python/examples/ocr/ocr_debugger_server.py b/python/examples/ocr/ocr_debugger_server.py new file mode 100644 index 0000000000000000000000000000000000000000..93e2d7a3d1dc64451774ecf790c2ebd3b39f1d91 --- /dev/null +++ b/python/examples/ocr/ocr_debugger_server.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
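+# End-to-end OCR web service in a single process: the det model is loaded through
+# the Debugger local predictor, its boxes are cropped and normalized into a batch,
+# the rec model is served via run_debugger_service, and postprocess returns the
+# recognized strings under "res".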
+ +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes +from paddle_serving_server_gpu.web_service import WebService +from paddle_serving_app.local_predict import Debugger +import time +import re +import base64 + + +class OCRService(WebService): + def init_det_debugger(self, det_model_config): + self.det_preprocess = Sequential([ + ResizeByFactor(32, 960), Div(255), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( + (2, 0, 1)) + ]) + self.det_client = Debugger() + self.det_client.load_model_config( + det_model_config, gpu=True, profile=False) + self.ocr_reader = OCRReader() + + def preprocess(self, feed=[], fetch=[]): + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + ori_h, ori_w, _ = im.shape + det_img = self.det_preprocess(im) + _, new_h, new_w = det_img.shape + det_img = det_img[np.newaxis, :] + det_img = det_img.copy() + det_out = self.det_client.predict( + feed={"image": det_img}, fetch=["concat_1.tmp_0"]) + filter_func = FilterBoxes(10, 10) + post_func = DBPostProcess({ + "thresh": 0.3, + "box_thresh": 0.5, + "max_candidates": 1000, + "unclip_ratio": 1.5, + "min_size": 3 + }) + sorted_boxes = SortedBoxes() + ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] + dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) + dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) + dt_boxes = sorted_boxes(dt_boxes) + get_rotate_crop_image = GetRotateCropImage() + img_list = [] + max_wh_ratio = 0 + for i, dtbox in enumerate(dt_boxes): + boximg = get_rotate_crop_image(im, dt_boxes[i]) + img_list.append(boximg) + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + if len(img_list) == 0: + return [], [] + _, w, h = self.ocr_reader.resize_norm_img(img_list[0], + max_wh_ratio).shape + imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') + for id, img in enumerate(img_list): + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) + imgs[id] = norm_img + feed = {"image": imgs.copy()} + fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + res = {"res": res_lst} + return res + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_rec_model") +ocr_service.prepare_server(workdir="workdir", port=9292) +ocr_service.init_det_debugger(det_model_config="ocr_det_model") +ocr_service.run_debugger_service(gpu=True) +ocr_service.run_web_service() diff --git a/python/examples/ocr/ocr_rpc_client.py b/python/examples/ocr/ocr_rpc_client.py deleted file mode 100644 index 212d46c2b226f91bcb0582e76e31ca2acdc8b948..0000000000000000000000000000000000000000 --- a/python/examples/ocr/ocr_rpc_client.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle_serving_client import Client -from paddle_serving_app.reader import OCRReader -import cv2 -import sys -import numpy as np -import os -from paddle_serving_client import Client -from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor -from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes -import time -import re - - -def sorted_boxes(dt_boxes): - """ - Sort text boxes in order from top to bottom, left to right - args: - dt_boxes(array):detected text boxes with shape [4, 2] - return: - sorted boxes(array) with shape [4, 2] - """ - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - - for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): - tmp = _boxes[i] - _boxes[i] = _boxes[i + 1] - _boxes[i + 1] = tmp - return _boxes - - -def get_rotate_crop_image(img, points): - #img = cv2.imread(img) - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - img_crop_width = int(np.linalg.norm(points[0] - points[1])) - img_crop_height = int(np.linalg.norm(points[0] - points[3])) - pts_std = np.float32([[0, 0], [img_crop_width, 0], \ - [img_crop_width, img_crop_height], [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img_crop, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - - -def read_det_box_file(filename): - with open(filename, 'r') as f: - line = f.readline() - a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int( - line.split(' ')[2]) - dt_boxes = np.zeros((a, b, c)).astype(np.float32) - line = f.readline() - for i in range(a): - for j in range(b): - line = f.readline() - dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float( - line.split(' ')[0]), float(line.split(' ')[1]) - line = f.readline() - - -def resize_norm_img(img, max_wh_ratio): - import math - imgC, imgH, imgW = 3, 32, 320 - imgW = int(32 * max_wh_ratio) - h = img.shape[0] - w = img.shape[1] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image - return padding_im - 
- -def main(): - client1 = Client() - client1.load_client_config("ocr_det_client/serving_client_conf.prototxt") - client1.connect(["127.0.0.1:9293"]) - - client2 = Client() - client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt") - client2.connect(["127.0.0.1:9292"]) - - read_image_file = File2Image() - preprocess = Sequential([ - ResizeByFactor(32, 960), Div(255), - Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( - (2, 0, 1)) - ]) - post_func = DBPostProcess({ - "thresh": 0.3, - "box_thresh": 0.5, - "max_candidates": 1000, - "unclip_ratio": 1.5, - "min_size": 3 - }) - - filter_func = FilterBoxes(10, 10) - ocr_reader = OCRReader() - files = [ - "./imgs/{}".format(f) for f in os.listdir('./imgs') - if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f) - ] - #files = ["2.jpg"]*30 - #files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)] - time_all = 0 - time_det_all = 0 - time_rec_all = 0 - for name in files: - #print(name) - im = read_image_file(name) - ori_h, ori_w, _ = im.shape - time1 = time.time() - img = preprocess(im) - _, new_h, new_w = img.shape - ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] - #print(new_h, new_w, ori_h, ori_w) - time_before_det = time.time() - outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"]) - time_after_det = time.time() - time_det_all += (time_after_det - time_before_det) - #print(outputs) - dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list]) - dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) - dt_boxes = sorted_boxes(dt_boxes) - feed_list = [] - img_list = [] - max_wh_ratio = 0 - for i, dtbox in enumerate(dt_boxes): - boximg = get_rotate_crop_image(im, dt_boxes[i]) - img_list.append(boximg) - h, w = boximg.shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - for img in img_list: - norm_img = resize_norm_img(img, max_wh_ratio) - #norm_img = norm_img[np.newaxis, :] - feed = {"image": norm_img} - feed_list.append(feed) - #fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] - fetch = ["ctc_greedy_decoder_0.tmp_0"] - time_before_rec = time.time() - if len(feed_list) == 0: - continue - fetch_map = client2.predict(feed=feed_list, fetch=fetch) - time_after_rec = time.time() - time_rec_all += (time_after_rec - time_before_rec) - rec_res = ocr_reader.postprocess(fetch_map) - #for res in rec_res: - # print(res[0].encode("utf-8")) - time2 = time.time() - time_all += (time2 - time1) - print("rpc+det time: {}".format(time_all / len(files))) - - -if __name__ == '__main__': - main() diff --git a/python/examples/ocr/ocr_web_client.py b/python/examples/ocr/ocr_web_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3d25e288dd93014cf9c3f84edc01d42c013ba2d9 --- /dev/null +++ b/python/examples/ocr/ocr_web_client.py @@ -0,0 +1,39 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
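+# HTTP client for the OCR web services: base64-encodes every image under imgs/,
+# posts it to the /ocr/prediction endpoint, and prints the JSON response.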
+# -*- coding: utf-8 -*- + +import requests +import json +import cv2 +import base64 +import os, sys +import time + + +def cv2_to_base64(image): + #data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(image).decode( + 'utf8') #data.tostring()).decode('utf8') + + +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:9292/ocr/prediction" +test_img_dir = "imgs/" +for img_file in os.listdir(test_img_dir): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data1 = file.read() + image = cv2_to_base64(image_data1) + data = {"feed": [{"image": image}], "fetch": ["res"]} + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + print(r.json()) diff --git a/python/examples/ocr/ocr_web_client.sh b/python/examples/ocr/ocr_web_client.sh deleted file mode 100644 index 5f4f1d7d1fb00dc63b3235533850f56f998a647f..0000000000000000000000000000000000000000 --- a/python/examples/ocr/ocr_web_client.sh +++ /dev/null @@ -1 +0,0 @@ - curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction diff --git a/python/examples/ocr/ocr_web_server.py b/python/examples/ocr/ocr_web_server.py index b55027d84252f8590f1e62839ad8cbd25e56c8fe..d017f6b9b560dc82158641b9f3a9f80137b40716 100644 --- a/python/examples/ocr/ocr_web_server.py +++ b/python/examples/ocr/ocr_web_server.py @@ -21,10 +21,11 @@ import os from paddle_serving_client import Client from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor from paddle_serving_app.reader import Div, Normalize, Transpose -from paddle_serving_app.reader import DBPostProcess, FilterBoxes +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes from paddle_serving_server_gpu.web_service import WebService import time import re +import base64 class OCRService(WebService): @@ -37,74 +38,16 @@ class OCRService(WebService): self.det_client = Client() self.det_client.load_client_config(det_client_config) self.det_client.connect(["127.0.0.1:{}".format(det_port)]) + self.ocr_reader = OCRReader() def preprocess(self, feed=[], fetch=[]): - img_url = feed[0]["image"] - #print(feed, img_url) - read_from_url = URL2Image() - im = read_from_url(img_url) + data = base64.b64decode(feed[0]["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) ori_h, ori_w, _ = im.shape det_img = self.det_preprocess(im) - #print("det_img", det_img, det_img.shape) det_out = self.det_client.predict( feed={"image": det_img}, fetch=["concat_1.tmp_0"]) - - #print("det_out", det_out) - def sorted_boxes(dt_boxes): - num_boxes = dt_boxes.shape[0] - sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) - _boxes = list(sorted_boxes) - for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): - tmp = _boxes[i] - _boxes[i] = _boxes[i + 1] - _boxes[i + 1] = tmp - return _boxes - - def get_rotate_crop_image(img, points): - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - img_crop_width = int(np.linalg.norm(points[0] - points[1])) - img_crop_height = int(np.linalg.norm(points[0] - points[3])) - pts_std = 
np.float32([[0, 0], [img_crop_width, 0], \ - [img_crop_width, img_crop_height], [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img_crop, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - - def resize_norm_img(img, max_wh_ratio): - import math - imgC, imgH, imgW = 3, 32, 320 - imgW = int(32 * max_wh_ratio) - h = img.shape[0] - w = img.shape[1] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image - return padding_im - _, new_h, new_w = det_img.shape filter_func = FilterBoxes(10, 10) post_func = DBPostProcess({ @@ -114,10 +57,12 @@ class OCRService(WebService): "unclip_ratio": 1.5, "min_size": 3 }) + sorted_boxes = SortedBoxes() ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) dt_boxes = sorted_boxes(dt_boxes) + get_rotate_crop_image = GetRotateCropImage() feed_list = [] img_list = [] max_wh_ratio = 0 @@ -128,29 +73,25 @@ class OCRService(WebService): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for img in img_list: - norm_img = resize_norm_img(img, max_wh_ratio) + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) feed = {"image": norm_img} feed_list.append(feed) - fetch = ["ctc_greedy_decoder_0.tmp_0"] - #print("feed_list", feed_list) + fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] return feed_list, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): - #print(fetch_map) - ocr_reader = OCRReader() - rec_res = ocr_reader.postprocess(fetch_map) + rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) res_lst = [] for res in rec_res: res_lst.append(res[0]) - fetch_map["res"] = res_lst - del fetch_map["ctc_greedy_decoder_0.tmp_0"] - del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"] - return fetch_map + res = {"res": res_lst} + return res ocr_service = OCRService(name="ocr") ocr_service.load_model_config("ocr_rec_model") -ocr_service.prepare_server(workdir="workdir", port=9292) +ocr_service.set_gpus("0") +ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) ocr_service.init_det_client( det_port=9293, det_client_config="ocr_det_client/serving_client_conf.prototxt") diff --git a/python/examples/ocr/rec_debugger_server.py b/python/examples/ocr/rec_debugger_server.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe67aafee5c8dcae269cd4ad6f6100ed514f0b7 --- /dev/null +++ b/python/examples/ocr/rec_debugger_server.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re +import base64 + + +class OCRService(WebService): + def init_rec(self): + self.ocr_reader = OCRReader() + + def preprocess(self, feed=[], fetch=[]): + img_list = [] + for feed_data in feed: + data = base64.b64decode(feed_data["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + img_list.append(im) + max_wh_ratio = 0 + for i, boximg in enumerate(img_list): + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + _, w, h = self.ocr_reader.resize_norm_img(img_list[0], + max_wh_ratio).shape + imgs = np.zeros((len(img_list), 3, w, h)).astype('float32') + for i, img in enumerate(img_list): + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) + imgs[i] = norm_img + feed = {"image": imgs.copy()} + fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + return feed, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + res = {"res": res_lst} + return res + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_rec_model") +ocr_service.set_gpus("0") +ocr_service.init_rec() +ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) +ocr_service.run_debugger_service() +ocr_service.run_web_service() diff --git a/python/examples/ocr/rec_img/ch_doc3.jpg b/python/examples/ocr/rec_img/ch_doc3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0c2053643c6211b9c2017e305c5fa05bba0cc66 Binary files /dev/null and b/python/examples/ocr/rec_img/ch_doc3.jpg differ diff --git a/python/examples/ocr/rec_web_client.py b/python/examples/ocr/rec_web_client.py new file mode 100644 index 0000000000000000000000000000000000000000..312a2148886d6f084a1c077d84e907cb28c0652a --- /dev/null +++ b/python/examples/ocr/rec_web_client.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
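+# HTTP client for the recognition-only services: base64-encodes the pre-cropped
+# text line images under rec_img/ and sends each one three times in a single
+# request to exercise batch prediction.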
+# -*- coding: utf-8 -*- + +import requests +import json +import cv2 +import base64 +import os, sys +import time + + +def cv2_to_base64(image): + #data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(image).decode( + 'utf8') #data.tostring()).decode('utf8') + + +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:9292/ocr/prediction" +test_img_dir = "rec_img/" + +for img_file in os.listdir(test_img_dir): + with open(os.path.join(test_img_dir, img_file), 'rb') as file: + image_data1 = file.read() + image = cv2_to_base64(image_data1) + #data = {"feed": [{"image": image}], "fetch": ["res"]} + data = {"feed": [{"image": image}] * 3, "fetch": ["res"]} + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + print(r.json()) diff --git a/python/examples/ocr/rec_web_server.py b/python/examples/ocr/rec_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..684c313d4d50cfe00c576c81aad05a810525dcce --- /dev/null +++ b/python/examples/ocr/rec_web_server.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import OCRReader +import cv2 +import sys +import numpy as np +import os +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor +from paddle_serving_app.reader import Div, Normalize, Transpose +from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes +from paddle_serving_server_gpu.web_service import WebService +import time +import re +import base64 + + +class OCRService(WebService): + def init_rec(self): + self.ocr_reader = OCRReader() + + def preprocess(self, feed=[], fetch=[]): + # TODO: to handle batch rec images + img_list = [] + for feed_data in feed: + data = base64.b64decode(feed_data["image"].encode('utf8')) + data = np.fromstring(data, np.uint8) + im = cv2.imdecode(data, cv2.IMREAD_COLOR) + img_list.append(im) + feed_list = [] + max_wh_ratio = 0 + for i, boximg in enumerate(img_list): + h, w = boximg.shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for img in img_list: + norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) + feed = {"image": norm_img} + feed_list.append(feed) + fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] + return feed_list, fetch + + def postprocess(self, feed={}, fetch=[], fetch_map=None): + rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) + res_lst = [] + for res in rec_res: + res_lst.append(res[0]) + res = {"res": res_lst} + return res + + +ocr_service = OCRService(name="ocr") +ocr_service.load_model_config("ocr_rec_model") +ocr_service.set_gpus("0") +ocr_service.init_rec() +ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) +ocr_service.run_rpc_service() +ocr_service.run_web_service() diff --git a/python/paddle_serving_app/local_predict.py 
b/python/paddle_serving_app/local_predict.py index 93039c6fdd467357b589bbb2889f3c2d3208b538..afe6d474b5382a2fe74f95adf2fed34faa28937b 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -70,9 +70,10 @@ class Debugger(object): config.enable_use_gpu(100, 0) if profile: config.enable_profile() + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.set_cpu_math_library_num_threads(cpu_num) config.switch_ir_optim(False) - + config.switch_use_feed_fetch_ops(False) self.predictor = create_paddle_predictor(config) def predict(self, feed=None, fetch=None): @@ -113,20 +114,30 @@ class Debugger(object): "Fetch names should not be empty or out of saved fetch list.") return {} - inputs = [] - for name in self.feed_names_: + input_names = self.predictor.get_input_names() + for name in input_names: if isinstance(feed[name], list): feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[ name]) - if self.feed_types_[name] == 0: - feed[name] = feed[name].astype("int64") - else: - feed[name] = feed[name].astype("float32") - inputs.append(PaddleTensor(feed[name][np.newaxis, :])) - - outputs = self.predictor.run(inputs) + if self.feed_types_[name] == 0: + feed[name] = feed[name].astype("int64") + else: + feed[name] = feed[name].astype("float32") + input_tensor = self.predictor.get_input_tensor(name) + input_tensor.copy_from_cpu(feed[name]) + output_tensors = [] + output_names = self.predictor.get_output_names() + for output_name in output_names: + output_tensor = self.predictor.get_output_tensor(output_name) + output_tensors.append(output_tensor) + outputs = [] + self.predictor.zero_copy_run() + for output_tensor in output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) fetch_map = {} - for name in fetch: - fetch_map[name] = outputs[self.fetch_names_to_idx_[ - name]].as_ndarray() + for i, name in enumerate(fetch): + fetch_map[name] = outputs[i] + if len(output_tensors[i].lod()) > 0: + fetch_map[name + ".lod"] = output_tensors[i].lod()[0] return fetch_map diff --git a/python/paddle_serving_app/reader/__init__.py b/python/paddle_serving_app/reader/__init__.py index e15a93084cbd437531129b48b51fe852ce17d19b..93e2cd76102d93f52955060055afda34f9576ed8 100644 --- a/python/paddle_serving_app/reader/__init__.py +++ b/python/paddle_serving_app/reader/__init__.py @@ -15,7 +15,7 @@ from .chinese_bert_reader import ChineseBertReader from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import RCNNPostprocess, SegPostprocess, PadStride -from .image_reader import DBPostProcess, FilterBoxes +from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes from .lac_reader import LACReader from .senta_reader import SentaReader from .imdb_reader import IMDBDataset diff --git a/python/paddle_serving_app/reader/image_reader.py b/python/paddle_serving_app/reader/image_reader.py index 4f747df1f74800cf692bb22171466bffb7c598b0..50c0753c27f845e784676b54ae7e029bec2a4ec4 100644 --- a/python/paddle_serving_app/reader/image_reader.py +++ b/python/paddle_serving_app/reader/image_reader.py @@ -797,6 +797,59 @@ class Transpose(object): return format_string +class SortedBoxes(object): + """ + Sorted bounding boxes from Detection + """ + + def __init__(self): + pass + + def __call__(self, dt_boxes): + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: 
(x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + +class GetRotateCropImage(object): + """ + Rotate and Crop image from OCR Det output + """ + + def __init__(self): + pass + + def __call__(self, img, points): + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + img_crop_width = int(np.linalg.norm(points[0] - points[1])) + img_crop_height = int(np.linalg.norm(points[0] - points[3])) + pts_std = np.float32([[0, 0], [img_crop_width, 0], \ + [img_crop_width, img_crop_height], [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img_crop, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + class ImageReader(): def __init__(self, image_shape=[3, 224, 224], diff --git a/python/paddle_serving_app/reader/ocr_reader.py b/python/paddle_serving_app/reader/ocr_reader.py index 72a2918f89a8ccc913894f3f46fab08f51cf9460..68ee72d51a6ed7e36b57186c6ea1b8d9fdb147a9 100644 --- a/python/paddle_serving_app/reader/ocr_reader.py +++ b/python/paddle_serving_app/reader/ocr_reader.py @@ -120,29 +120,21 @@ class CharacterOps(object): class OCRReader(object): - def __init__(self): - args = self.parse_args() - image_shape = [int(v) for v in args.rec_image_shape.split(",")] + def __init__(self, + algorithm="CRNN", + image_shape=[3, 32, 320], + char_type="ch", + batch_num=1, + char_dict_path="./ppocr_keys_v1.txt"): self.rec_image_shape = image_shape - self.character_type = args.rec_char_type - self.rec_batch_num = args.rec_batch_num + self.character_type = char_type + self.rec_batch_num = batch_num char_ops_params = {} - char_ops_params["character_type"] = args.rec_char_type - char_ops_params["character_dict_path"] = args.rec_char_dict_path + char_ops_params["character_type"] = char_type + char_ops_params["character_dict_path"] = char_dict_path char_ops_params['loss_type'] = 'ctc' self.char_ops = CharacterOps(char_ops_params) - def parse_args(self): - parser = argparse.ArgumentParser() - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=1) - parser.add_argument( - "--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt") - return parser.parse_args() - def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape if self.character_type == "ch": @@ -154,15 +146,14 @@ class OCRReader(object): resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) - - seq = Sequential([ - Resize(imgH, resized_w), Transpose((2, 0, 1)), Div(255), - Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], True) - ]) - resized_image = seq(img) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = 
resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image + padding_im[:, :, 0:resized_w] = resized_image return padding_im def preprocess(self, img_list): @@ -191,11 +182,17 @@ class OCRReader(object): for rno in range(len(rec_idx_lod) - 1): beg = rec_idx_lod[rno] end = rec_idx_lod[rno + 1] - rec_idx_tmp = rec_idx_batch[beg:end, 0] + if isinstance(rec_idx_batch, list): + rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]] + else: #nd array + rec_idx_tmp = rec_idx_batch[beg:end, 0] preds_text = self.char_ops.decode(rec_idx_tmp) if with_score: beg = predict_lod[rno] end = predict_lod[rno + 1] + if isinstance(outputs["softmax_0.tmp_0"], list): + outputs["softmax_0.tmp_0"] = np.array(outputs[ + "softmax_0.tmp_0"]).astype(np.float32) probs = outputs["softmax_0.tmp_0"][beg:end, :] ind = np.argmax(probs, axis=1) blank = probs.shape[1] diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 455bcf62cd039dde69736ec514892856eabd3088..cf669c54f3492fc739bedcfacc49537a5ecc545f 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -399,6 +399,7 @@ class MultiLangClient(object): self.channel_ = None self.stub_ = None self.rpc_timeout_s_ = 2 + self.profile_ = _Profiler() def add_variant(self, tag, cluster, variant_weight): # TODO @@ -520,7 +521,7 @@ class MultiLangClient(object): tensor.float_data.extend( var.reshape(-1).astype('float32').tolist()) elif v_type == 2: - tensor.int32_data.extend( + tensor.int_data.extend( var.reshape(-1).astype('int32').tolist()) else: raise Exception("error tensor value type.") @@ -530,7 +531,7 @@ class MultiLangClient(object): elif v_type == 1: tensor.float_data.extend(self._flatten_list(var)) elif v_type == 2: - tensor.int32_data.extend(self._flatten_list(var)) + tensor.int_data.extend(self._flatten_list(var)) else: raise Exception("error tensor value type.") else: @@ -582,6 +583,7 @@ class MultiLangClient(object): ret = list(multi_result_map.values())[0] else: ret = multi_result_map + ret["serving_status_code"] = 0 return ret if not need_variant_tag else [ret, tag] @@ -601,18 +603,30 @@ class MultiLangClient(object): need_variant_tag=False, asyn=False, is_python=True): - req = self._pack_inference_request(feed, fetch, is_python=is_python) if not asyn: try: + self.profile_.record('py_prepro_0') + req = self._pack_inference_request( + feed, fetch, is_python=is_python) + self.profile_.record('py_prepro_1') + + self.profile_.record('py_client_infer_0') resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_) - return self._unpack_inference_response( + self.profile_.record('py_client_infer_1') + + self.profile_.record('py_postpro_0') + ret = self._unpack_inference_response( resp, fetch, is_python=is_python, need_variant_tag=need_variant_tag) + self.profile_.record('py_postpro_1') + self.profile_.print_profile() + return ret except grpc.RpcError as e: return {"serving_status_code": e.code()} else: + req = self._pack_inference_request(feed, fetch, is_python=is_python) call_future = self.stub_.Inference.future( req, timeout=self.rpc_timeout_s_) return MultiLangPredictFuture( diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py index 1e5fd16ed6c153a28cd72422ca3ef7b9177cb079..678c0583d1e132791a1199e315ea380a4ae3108b 100644 --- a/python/paddle_serving_server/__init__.py +++ 
b/python/paddle_serving_server/__init__.py @@ -524,7 +524,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. elif v_type == 1: # float32 data = np.array(list(var.float_data), dtype="float32") elif v_type == 2: # int32 - data = np.array(list(var.int32_data), dtype="int32") + data = np.array(list(var.int_data), dtype="int32") else: raise Exception("error type.") data.shape = list(feed_inst.tensor_array[idx].shape) @@ -540,6 +540,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. results, tag = ret resp.tag = tag resp.err_code = 0 + if not self.is_multi_model_: results = {'general_infer_0': results} for model_name, model_result in results.items(): @@ -558,8 +559,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. tensor.float_data.extend(model_result[name].reshape(-1) .tolist()) elif v_type == 2: # int32 - tensor.int32_data.extend(model_result[name].reshape(-1) - .tolist()) + tensor.int_data.extend(model_result[name].reshape(-1) + .tolist()) else: raise Exception("error type.") tensor.shape.extend(list(model_result[name].shape)) diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py index df04cb7840bbacd90ccb7e3c66147a6856b23e02..0261003a7863d11fb342d1572b124d1cbb533a2b 100644 --- a/python/paddle_serving_server_gpu/__init__.py +++ b/python/paddle_serving_server_gpu/__init__.py @@ -571,7 +571,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. elif v_type == 1: # float32 data = np.array(list(var.float_data), dtype="float32") elif v_type == 2: - data = np.array(list(var.int32_data), dtype="int32") + data = np.array(list(var.int_data), dtype="int32") else: raise Exception("error type.") data.shape = list(feed_inst.tensor_array[idx].shape) @@ -587,6 +587,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. results, tag = ret resp.tag = tag resp.err_code = 0 + if not self.is_multi_model_: results = {'general_infer_0': results} for model_name, model_result in results.items(): @@ -605,8 +606,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc. tensor.float_data.extend(model_result[name].reshape(-1) .tolist()) elif v_type == 2: # int32 - tensor.int32_data.extend(model_result[name].reshape(-1) - .tolist()) + tensor.int_data.extend(model_result[name].reshape(-1) + .tolist()) else: raise Exception("error type.") tensor.shape.extend(list(model_result[name].shape)) diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index fecc61ffa1f8637fb214cc748fb14c7ce30731ab..6750de86f1750f2ab9dc36eca9d4307f7821e2d8 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -127,9 +127,9 @@ class WebService(object): request.json["fetch"]) if isinstance(feed, dict) and "fetch" in feed: del feed["fetch"] + if len(feed) == 0: + raise ValueError("empty input") fetch_map = self.client.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key].tolist() result = self.postprocess( feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map) result = {"result": result} @@ -164,6 +164,33 @@ class WebService(object): self.app_instance = app_instance + # TODO: maybe change another API name: maybe run_local_predictor? 
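+    # run_debugger_service builds a Flask app that serves /<name>/prediction and
+    # answers requests in-process through the Debugger local predictor created by
+    # _launch_local_predictor below, instead of forwarding to a separate RPC server.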
+ def run_debugger_service(self, gpu=False): + import socket + localIP = socket.gethostbyname(socket.gethostname()) + print("web service address:") + print("http://{}:{}/{}/prediction".format(localIP, self.port, + self.name)) + app_instance = Flask(__name__) + + @app_instance.before_first_request + def init(): + self._launch_local_predictor(gpu) + + service_name = "/" + self.name + "/prediction" + + @app_instance.route(service_name, methods=["POST"]) + def run(): + return self.get_prediction(request) + + self.app_instance = app_instance + + def _launch_local_predictor(self, gpu): + from paddle_serving_app.local_predict import Debugger + self.client = Debugger() + self.client.load_model_config( + "{}".format(self.model_config), gpu=gpu, profile=False) + def run_web_service(self): self.app_instance.run(host="0.0.0.0", port=self.port, diff --git a/tools/Dockerfile.ci b/tools/Dockerfile.ci index 8709075f6cf8f985e346999e76f6b273d7664193..fc52733bb214af918af34069e71b05b2eaa8511e 100644 --- a/tools/Dockerfile.ci +++ b/tools/Dockerfile.ci @@ -1,39 +1,50 @@ FROM centos:7.3.1611 + RUN yum -y install wget >/dev/null \ && yum -y install gcc gcc-c++ make glibc-static which >/dev/null \ && yum -y install git openssl-devel curl-devel bzip2-devel python-devel >/dev/null \ && yum -y install libSM-1.2.2-2.el7.x86_64 --setopt=protected_multilib=false \ && yum -y install libXrender-0.9.10-1.el7.x86_64 --setopt=protected_multilib=false \ - && yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false \ - && wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \ + && yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false + +RUN wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \ && tar xzf cmake-3.2.0-Linux-x86_64.tar.gz \ && mv cmake-3.2.0-Linux-x86_64 /usr/local/cmake3.2.0 \ && echo 'export PATH=/usr/local/cmake3.2.0/bin:$PATH' >> /root/.bashrc \ - && rm cmake-3.2.0-Linux-x86_64.tar.gz \ - && wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ + && rm cmake-3.2.0-Linux-x86_64.tar.gz + +RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ && tar xzf go1.14.linux-amd64.tar.gz \ && mv go /usr/local/go \ && echo 'export GOROOT=/usr/local/go' >> /root/.bashrc \ && echo 'export PATH=/usr/local/go/bin:$PATH' >> /root/.bashrc \ - && rm go1.14.linux-amd64.tar.gz \ - && yum -y install python-devel sqlite-devel >/dev/null \ + && rm go1.14.linux-amd64.tar.gz + +RUN yum -y install python-devel sqlite-devel >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && python get-pip.py >/dev/null \ && pip install google protobuf setuptools wheel flask >/dev/null \ - && rm get-pip.py \ - && wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \ + && rm get-pip.py + +RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \ && yum -y install bzip2 >/dev/null \ && tar -jxf patchelf-0.10.tar.bz2 \ && cd patchelf-0.10 \ && ./configure --prefix=/usr \ && make >/dev/null && make install >/dev/null \ && cd .. 
\ - && rm -rf patchelf-0.10* \ - && yum install -y python3 python3-devel \ - && pip3 install google protobuf setuptools wheel flask \ - && yum -y update >/dev/null \ + && rm -rf patchelf-0.10* + +RUN yum install -y python3 python3-devel \ + && pip3 install google protobuf setuptools wheel flask + +RUN yum -y update >/dev/null \ && yum -y install dnf >/dev/null \ && yum -y install dnf-plugins-core >/dev/null \ && dnf copr enable alonid/llvm-3.8.0 -y \ && dnf install llvm-3.8.0 clang-3.8.0 compiler-rt-3.8.0 -y \ && echo 'export PATH=/opt/llvm-3.8.0/bin:$PATH' >> /root/.bashrc + +RUN yum install -y java \ + && wget http://repos.fedorapeople.org/repos/dchen/apache-maven/epel-apache-maven.repo -O /etc/yum.repos.d/epel-apache-maven.repo \ + && yum install -y apache-maven diff --git a/tools/serving_build.sh b/tools/serving_build.sh index 175f084d3e29ded9005d7a1e7c39da3e001978c8..d1f11ff78d4d032ef62162b2d2d914d186fda634 100644 --- a/tools/serving_build.sh +++ b/tools/serving_build.sh @@ -182,26 +182,26 @@ function python_test_fit_a_line() { kill_server_process # test web - unsetproxy # maybe the proxy is used on iPipe, which makes web-test failed. - check_cmd "python -m paddle_serving_server_gpu.serve --model uci_housing_model --port 9393 --thread 2 --gpu_ids 0 --name uci > /dev/null &" - sleep 5 # wait for the server to start - check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" + #unsetproxy # maybe the proxy is used on iPipe, which makes web-test failed. + #check_cmd "python -m paddle_serving_server_gpu.serve --model uci_housing_model --port 9393 --thread 2 --gpu_ids 0 --name uci > /dev/null &" + #sleep 5 # wait for the server to start + #check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" # check http code - http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` - if [ ${http_code} -ne 200 ]; then - echo "HTTP status code -ne 200" - exit 1 - fi + #http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` + #if [ ${http_code} -ne 200 ]; then + # echo "HTTP status code -ne 200" + # exit 1 + #fi # test web batch - check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" + #check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {\"x\": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, 
-0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], \"fetch\":[\"price\"]}' http://127.0.0.1:9393/uci/prediction" # check http code - http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` - if [ ${http_code} -ne 200 ]; then - echo "HTTP status code -ne 200" - exit 1 - fi - setproxy # recover proxy state - kill_server_process + #http_code=`curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}, {"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' -s -w "%{http_code}" -o /dev/null http://127.0.0.1:9393/uci/prediction` + #if [ ${http_code} -ne 200 ]; then + # echo "HTTP status code -ne 200" + # exit 1 + #fi + #setproxy # recover proxy state + #kill_server_process ;; *) echo "error type" @@ -499,6 +499,64 @@ function python_test_lac() { cd .. } +function java_run_test() { + # pwd: /Serving + local TYPE=$1 + export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving + unsetproxy + case $TYPE in + CPU) + # compile java sdk + cd java # pwd: /Serving/java + mvn compile > /dev/null + mvn install > /dev/null + # compile java sdk example + cd examples # pwd: /Serving/java/examples + mvn compile > /dev/null + mvn install > /dev/null + + # fit_a_line (general, asyn_predict, batch_predict) + cd ../../python/examples/grpc_impl_example/fit_a_line # pwd: /Serving/python/examples/grpc_impl_example/fit_a_line + sh get_data.sh + check_cmd "python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --thread 4 --use_multilang > /dev/null &" + sleep 5 # wait for the server to start + cd ../../../java/examples # /Serving/java/examples + java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample fit_a_line + java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample asyn_predict + java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample batch_predict + kill_server_process + + # imdb (model_ensemble) + cd ../../python/examples/grpc_impl_example/imdb # pwd: /Serving/python/examples/grpc_impl_example/imdb + sh get_data.sh > /dev/null + check_cmd "python test_multilang_ensemble_server.py > /dev/null &" + sleep 5 # wait for the server to start + cd ../../../java/examples # /Serving/java/examples + java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample model_ensemble + kill_server_process + + # yolov4 (int32) + cd ../../python/examples/grpc_impl_example/yolov4 # pwd: /Serving/python/examples/grpc_impl_example/yolov4 + python -m paddle_serving_app.package --get_model yolov4 > /dev/null + tar -xzf yolov4.tar.gz > /dev/null + check_cmd "python -m paddle_serving_server.serve --model yolov4_model --port 9393 --use_multilang --mem_optim > /dev/null &" + cd ../../../java/examples # /Serving/java/examples + java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample yolov4 
src/main/resources/000000570688.jpg + kill_server_process + cd ../../ # pwd: /Serving + ;; + GPU) + ;; + *) + echo "error type" + exit 1 + ;; + esac + echo "java-sdk $TYPE part finished as expected." + setproxy + unset SERVING_BIN +} + function python_test_grpc_impl() { # pwd: /Serving/python/examples cd grpc_impl_example # pwd: /Serving/python/examples/grpc_impl_example @@ -537,7 +595,7 @@ function python_test_grpc_impl() { # test load server config and client config in Server side cd criteo_ctr_with_cube # pwd: /Serving/python/examples/grpc_impl_example/criteo_ctr_with_cube - check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz" + check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz > /dev/null" check_cmd "tar xf ctr_cube_unittest.tar.gz" check_cmd "mv models/ctr_client_conf ./" check_cmd "mv models/ctr_serving_model_kv ./" @@ -978,6 +1036,7 @@ function main() { build_client $TYPE # pwd: /Serving build_server $TYPE # pwd: /Serving build_app $TYPE # pwd: /Serving + java_run_test $TYPE # pwd: /Serving python_run_test $TYPE # pwd: /Serving monitor_test $TYPE # pwd: /Serving echo "serving $TYPE part finished as expected."