diff --git a/core/configure/proto/multi_lang_general_model_service.proto b/core/configure/proto/multi_lang_general_model_service.proto
index 2a8a8bc1532c19aa02a1998aa751aa7ba9d41570..b83450aed666b96de324050d53b10c56e059a8d5 100644
--- a/core/configure/proto/multi_lang_general_model_service.proto
+++ b/core/configure/proto/multi_lang_general_model_service.proto
@@ -14,6 +14,10 @@
syntax = "proto2";
+option java_multiple_files = true;
+option java_package = "io.paddle.serving.grpc";
+option java_outer_classname = "ServingProto";
+
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
diff --git a/doc/JAVA_SDK.md b/doc/JAVA_SDK.md
new file mode 100644
index 0000000000000000000000000000000000000000..4880e74bfee123b432b6b583a239d2d2ccbb45ac
--- /dev/null
+++ b/doc/JAVA_SDK.md
@@ -0,0 +1,109 @@
+# Paddle Serving Client Java SDK
+
+([简体中文](JAVA_SDK_CN.md)|English)
+
+Paddle Serving provides a Java SDK, which supports prediction on the client side in the Java language. This document shows how to use the Java SDK.
+
+## Getting started
+
+
+### Prerequisites
+
+```
+- Java 8 or higher
+- Apache Maven
+```
+
+The following table shows compatibilities between Paddle Serving Server and Java SDK.
+
+| Paddle Serving Server version | Java SDK version |
+| :---------------------------: | :--------------: |
+| 0.3.2 | 0.0.1 |
+
+### Install Java SDK
+
+You can download jar and install it to the local Maven repository:
+
+```shell
+wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar
+mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar
+```
+
+Or compile from the source code and install it to the local Maven repository:
+
+```shell
+cd Serving/java
+mvn compile
+mvn install
+```
+
+### Maven configure
+
+```text
+<dependency>
+    <groupId>io.paddle.serving.client</groupId>
+    <artifactId>paddle-serving-sdk-java</artifactId>
+    <version>0.0.1</version>
+</dependency>
+```
+
+
+
+## Example
+
+Here we will show how to use Java SDK for Boston house price prediction. Please refer to [examples](../java/examples) folder for more examples.
+
+### Get model
+
+```shell
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+tar -xzf uci_housing.tar.gz
+```
+
+### Start Python Server
+
+```shell
+python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang
+```
+
+#### Client side code example
+
+```java
+import io.paddle.serving.client.*;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import java.util.*;
+
+public class PaddleServingClientExample {
+ public static void main( String[] args ) {
+ float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+ 0.0582f, -0.0727f, -0.1583f, -0.0584f,
+ 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("x", npdata);
+ }};
+ List fetch = Arrays.asList("price");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return ;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ System.out.println("predict failed.");
+ return ;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return ;
+ }
+}
+```
diff --git a/doc/JAVA_SDK_CN.md b/doc/JAVA_SDK_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..f624a4403371f5b284f34cbf310fef64d59602d9
--- /dev/null
+++ b/doc/JAVA_SDK_CN.md
@@ -0,0 +1,109 @@
+# Paddle Serving Client Java SDK
+
+(简体中文|[English](JAVA_SDK.md))
+
+Paddle Serving 提供了 Java SDK,支持 Client 端用 Java 语言进行预测,本文档说明了如何使用 Java SDK。
+
+## 快速开始
+
+### 环境要求
+
+```
+- Java 8 or higher
+- Apache Maven
+```
+
+下表显示了 Paddle Serving Server 和 Java SDK 之间的兼容性
+
+| Paddle Serving Server version | Java SDK version |
+| :---------------------------: | :--------------: |
+| 0.3.2 | 0.0.1 |
+
+### 安装
+
+您可以直接下载 jar,安装到本地 Maven 库:
+
+```shell
+wget https://paddle-serving.bj.bcebos.com/jar/paddle-serving-sdk-java-0.0.1.jar
+mvn install:install-file -Dfile=$PWD/paddle-serving-sdk-java-0.0.1.jar -DgroupId=io.paddle.serving.client -DartifactId=paddle-serving-sdk-java -Dversion=0.0.1 -Dpackaging=jar
+```
+
+或者从源码进行编译,安装到本地 Maven 库:
+
+```shell
+cd Serving/java
+mvn compile
+mvn install
+```
+
+### Maven 配置
+
+```text
+<dependency>
+    <groupId>io.paddle.serving.client</groupId>
+    <artifactId>paddle-serving-sdk-java</artifactId>
+    <version>0.0.1</version>
+</dependency>
+```
+
+
+
+
+## 使用样例
+
+这里将展示如何使用 Java SDK 进行房价预测,更多例子详见 [examples](../java/examples) 文件夹。
+
+### 获取房价预测模型
+
+```shell
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+tar -xzf uci_housing.tar.gz
+```
+
+### 启动 Python 端 Server
+
+```shell
+python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --use_multilang
+```
+
+### Client 端代码示例
+
+```java
+import io.paddle.serving.client.*;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import java.util.*;
+
+public class PaddleServingClientExample {
+ public static void main( String[] args ) {
+ float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+ 0.0582f, -0.0727f, -0.1583f, -0.0584f,
+ 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("x", npdata);
+ }};
+ List fetch = Arrays.asList("price");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return ;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ System.out.println("predict failed.");
+ return ;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return ;
+ }
+}
+```
diff --git a/java/examples/pom.xml b/java/examples/pom.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b6c8bc424f5d528d74a4a45828fd9b5e7e5d008e
--- /dev/null
+++ b/java/examples/pom.xml
@@ -0,0 +1,88 @@
+
+
+
+ 4.0.0
+
+ io.paddle.serving.client
+ paddle-serving-sdk-java-examples
+ 0.0.1
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ 8
+
+ 3.8.1
+
+
+ maven-assembly-plugin
+
+
+
+ true
+ my.fully.qualified.class.Main
+
+
+
+ jar-with-dependencies
+
+
+
+
+ make-my-jar-with-dependencies
+ package
+
+ single
+
+
+
+
+
+
+
+
+ UTF-8
+ nd4j-native
+ 1.0.0-beta7
+ 1.0.0-beta7
+ 0.0.1
+ 1.7
+ 1.7
+
+
+
+
+ io.paddle.serving.client
+ paddle-serving-sdk-java
+ ${paddle.serving.client.version}
+
+
+ org.slf4j
+ slf4j-api
+ 1.7.30
+
+
+ org.nd4j
+ ${nd4j.backend}
+ ${nd4j.version}
+
+
+ junit
+ junit
+ 4.11
+ test
+
+
+ org.datavec
+ datavec-data-image
+ ${datavec.version}
+
+
+
+
diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java
new file mode 100644
index 0000000000000000000000000000000000000000..cdc11df130095d668734ae0a23adb12ef735b2ea
--- /dev/null
+++ b/java/examples/src/main/java/PaddleServingClientExample.java
@@ -0,0 +1,363 @@
+import io.paddle.serving.client.*;
+import java.io.File;
+import java.io.IOException;
+import java.net.URL;
+import org.nd4j.linalg.api.iter.NdIndexIterator;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.datavec.image.loader.NativeImageLoader;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
+import java.util.*;
+
+public class PaddleServingClientExample {
+ boolean fit_a_line() {
+ float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+ 0.0582f, -0.0727f, -0.1583f, -0.0584f,
+ 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("x", npdata);
+ }};
+ List fetch = Arrays.asList("price");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ boolean yolov4(String filename) {
+ // https://deeplearning4j.konduit.ai/
+ int height = 608;
+ int width = 608;
+ int channels = 3;
+ NativeImageLoader loader = new NativeImageLoader(height, width, channels);
+ INDArray BGRimage = null;
+ try {
+ BGRimage = loader.asMatrix(new File(filename));
+ } catch (java.io.IOException e) {
+ System.out.println("load image fail.");
+ return false;
+ }
+
+ // shape: (channels, height, width)
+ BGRimage = BGRimage.reshape(channels, height, width);
+ INDArray RGBimage = Nd4j.create(BGRimage.shape());
+
+ // BGR2RGB
+ CustomOp op = DynamicCustomOp.builder("reverse")
+ .addInputs(BGRimage)
+ .addOutputs(RGBimage)
+ .addIntegerArguments(0)
+ .build();
+ Nd4j.getExecutioner().exec(op);
+
+ // Div(255.0)
+ INDArray image = RGBimage.divi(255.0);
+
+ INDArray im_size = Nd4j.createFromArray(new int[]{height, width});
+ HashMap feed_data
+ = new HashMap() {{
+ put("image", image);
+ put("im_size", im_size);
+ }};
+ List fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+ succ = client.setRpcTimeoutMs(20000); // cpu
+ if (succ != true) {
+ System.out.println("set timeout failed.");
+ return false;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ boolean batch_predict() {
+ float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+ 0.0582f, -0.0727f, -0.1583f, -0.0584f,
+ 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("x", npdata);
+ }};
+ List> feed_batch
+ = new ArrayList>() {{
+ add(feed_data);
+ add(feed_data);
+ }};
+ List fetch = Arrays.asList("price");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ Map fetch_map = client.predict(feed_batch, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ boolean asyn_predict() {
+ float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+ 0.0582f, -0.0727f, -0.1583f, -0.0584f,
+ 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("x", npdata);
+ }};
+ List fetch = Arrays.asList("price");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ PredictFuture future = client.asyn_predict(feed_data, fetch);
+ Map fetch_map = future.get();
+ if (fetch_map == null) {
+ System.out.println("Get future reslut failed");
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ boolean model_ensemble() {
+ long[] data = {8, 233, 52, 601};
+ INDArray npdata = Nd4j.createFromArray(data);
+ HashMap feed_data
+ = new HashMap() {{
+ put("words", npdata);
+ }};
+ List fetch = Arrays.asList("prediction");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ Map> fetch_map
+ = client.ensemble_predict(feed_data, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry> entry : fetch_map.entrySet()) {
+ System.out.println("Model = " + entry.getKey());
+ HashMap tt = entry.getValue();
+ for (Map.Entry e : tt.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ }
+ return true;
+ }
+
+ boolean bert() {
+ float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+ long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ long[] segment_ids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ HashMap feed_data
+ = new HashMap() {{
+ put("input_mask", Nd4j.createFromArray(input_mask));
+ put("position_ids", Nd4j.createFromArray(position_ids));
+ put("input_ids", Nd4j.createFromArray(input_ids));
+ put("segment_ids", Nd4j.createFromArray(segment_ids));
+ }};
+ List fetch = Arrays.asList("pooled_output");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ boolean cube_local() {
+ long[] embedding_14 = {250644};
+ long[] embedding_2 = {890346};
+ long[] embedding_10 = {3939};
+ long[] embedding_17 = {421122};
+ long[] embedding_23 = {664215};
+ long[] embedding_6 = {704846};
+ float[] dense_input = {0.0f, 0.006633499170812604f, 0.03f, 0.0f,
+ 0.145078125f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+ long[] embedding_24 = {269955};
+ long[] embedding_12 = {295309};
+ long[] embedding_7 = {437731};
+ long[] embedding_3 = {990128};
+ long[] embedding_1 = {7753};
+ long[] embedding_4 = {286835};
+ long[] embedding_8 = {27346};
+ long[] embedding_9 = {636474};
+ long[] embedding_18 = {880474};
+ long[] embedding_16 = {681378};
+ long[] embedding_22 = {410878};
+ long[] embedding_13 = {255651};
+ long[] embedding_5 = {25207};
+ long[] embedding_11 = {10891};
+ long[] embedding_20 = {238459};
+ long[] embedding_21 = {26235};
+ long[] embedding_15 = {691460};
+ long[] embedding_25 = {544187};
+ long[] embedding_19 = {537425};
+ long[] embedding_0 = {737395};
+
+ HashMap feed_data
+ = new HashMap() {{
+ put("embedding_14.tmp_0", Nd4j.createFromArray(embedding_14));
+ put("embedding_2.tmp_0", Nd4j.createFromArray(embedding_2));
+ put("embedding_10.tmp_0", Nd4j.createFromArray(embedding_10));
+ put("embedding_17.tmp_0", Nd4j.createFromArray(embedding_17));
+ put("embedding_23.tmp_0", Nd4j.createFromArray(embedding_23));
+ put("embedding_6.tmp_0", Nd4j.createFromArray(embedding_6));
+ put("dense_input", Nd4j.createFromArray(dense_input));
+ put("embedding_24.tmp_0", Nd4j.createFromArray(embedding_24));
+ put("embedding_12.tmp_0", Nd4j.createFromArray(embedding_12));
+ put("embedding_7.tmp_0", Nd4j.createFromArray(embedding_7));
+ put("embedding_3.tmp_0", Nd4j.createFromArray(embedding_3));
+ put("embedding_1.tmp_0", Nd4j.createFromArray(embedding_1));
+ put("embedding_4.tmp_0", Nd4j.createFromArray(embedding_4));
+ put("embedding_8.tmp_0", Nd4j.createFromArray(embedding_8));
+ put("embedding_9.tmp_0", Nd4j.createFromArray(embedding_9));
+ put("embedding_18.tmp_0", Nd4j.createFromArray(embedding_18));
+ put("embedding_16.tmp_0", Nd4j.createFromArray(embedding_16));
+ put("embedding_22.tmp_0", Nd4j.createFromArray(embedding_22));
+ put("embedding_13.tmp_0", Nd4j.createFromArray(embedding_13));
+ put("embedding_5.tmp_0", Nd4j.createFromArray(embedding_5));
+ put("embedding_11.tmp_0", Nd4j.createFromArray(embedding_11));
+ put("embedding_20.tmp_0", Nd4j.createFromArray(embedding_20));
+ put("embedding_21.tmp_0", Nd4j.createFromArray(embedding_21));
+ put("embedding_15.tmp_0", Nd4j.createFromArray(embedding_15));
+ put("embedding_25.tmp_0", Nd4j.createFromArray(embedding_25));
+ put("embedding_19.tmp_0", Nd4j.createFromArray(embedding_19));
+ put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
+ }};
+ List fetch = Arrays.asList("prob");
+
+ Client client = new Client();
+ String target = "localhost:9393";
+ boolean succ = client.connect(target);
+ if (succ != true) {
+ System.out.println("connect failed.");
+ return false;
+ }
+
+ Map fetch_map = client.predict(feed_data, fetch);
+ if (fetch_map == null) {
+ return false;
+ }
+
+ for (Map.Entry e : fetch_map.entrySet()) {
+ System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
+ }
+ return true;
+ }
+
+ public static void main( String[] args ) {
+ // DL4J(Deep Learning for Java)Document:
+ // https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md
+ PaddleServingClientExample e = new PaddleServingClientExample();
+ boolean succ = false;
+
+ if (args.length < 1) {
+ System.out.println("Usage: java -cp PaddleServingClientExample .");
+ System.out.println(": fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4");
+ return;
+ }
+ String testType = args[0];
+ System.out.format("[Example] %s\n", testType);
+ if ("fit_a_line".equals(testType)) {
+ succ = e.fit_a_line();
+ } else if ("bert".equals(testType)) {
+ succ = e.bert();
+ } else if ("model_ensemble".equals(testType)) {
+ succ = e.model_ensemble();
+ } else if ("asyn_predict".equals(testType)) {
+ succ = e.asyn_predict();
+ } else if ("batch_predict".equals(testType)) {
+ succ = e.batch_predict();
+ } else if ("cube_local".equals(testType)) {
+ succ = e.cube_local();
+ } else if ("cube_quant".equals(testType)) {
+ succ = e.cube_local();
+ } else if ("yolov4".equals(testType)) {
+ if (args.length < 2) {
+ System.out.println("Usage: java -cp PaddleServingClientExample yolov4 .");
+ return;
+ }
+ succ = e.yolov4(args[1]);
+ } else {
+ System.out.format("test-type(%s) not match.\n", testType);
+ return;
+ }
+
+ if (succ == true) {
+ System.out.println("[Example] succ.");
+ } else {
+ System.out.println("[Example] fail.");
+ }
+ }
+}
diff --git a/java/examples/src/main/resources/000000570688.jpg b/java/examples/src/main/resources/000000570688.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54
Binary files /dev/null and b/java/examples/src/main/resources/000000570688.jpg differ
diff --git a/java/pom.xml b/java/pom.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d7e9ea7a097ea1ea2f41f930773d4a5d72a6d515
--- /dev/null
+++ b/java/pom.xml
@@ -0,0 +1,267 @@
+
+
+
+ 4.0.0
+
+ io.paddle.serving.client
+ paddle-serving-sdk-java
+ 0.0.1
+ jar
+
+ paddle-serving-sdk-java
+ Java SDK for Paddle Serving Client.
+ https://github.com/PaddlePaddle/Serving
+
+
+
+ Apache License, Version 2.0
+ http://www.apache.org/licenses/LICENSE-2.0.txt
+ repo
+
+
+
+
+
+ PaddlePaddle Author
+ guru4elephant@gmail.com
+ PaddlePaddle
+ https://github.com/PaddlePaddle/Serving
+
+
+
+
+ scm:git:https://github.com/PaddlePaddle/Serving.git
+ scm:git:https://github.com/PaddlePaddle/Serving.git
+ https://github.com/PaddlePaddle/Serving
+
+
+
+ UTF-8
+ 1.27.2
+ 3.11.0
+ 3.11.0
+ nd4j-native
+ 1.0.0-beta7
+ 1.8
+ 1.8
+
+
+
+
+
+ io.grpc
+ grpc-bom
+ ${grpc.version}
+ pom
+ import
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 1.6
+
+
+ io.grpc
+ grpc-netty-shaded
+ runtime
+
+
+ io.grpc
+ grpc-protobuf
+
+
+ io.grpc
+ grpc-stub
+
+
+ javax.annotation
+ javax.annotation-api
+ 1.2
+ provided
+
+
+ io.grpc
+ grpc-testing
+ test
+
+
+ com.google.protobuf
+ protobuf-java-util
+ ${protobuf.version}
+ runtime
+
+
+ com.google.errorprone
+ error_prone_annotations
+ 2.3.4
+
+
+ org.junit.jupiter
+ junit-jupiter
+ 5.5.2
+ test
+
+
+ org.apache.commons
+ commons-text
+ 1.6
+
+
+ org.apache.commons
+ commons-collections4
+ 4.4
+
+
+ org.json
+ json
+ 20190722
+
+
+ org.slf4j
+ slf4j-api
+ 1.7.30
+
+
+ org.apache.logging.log4j
+ log4j-slf4j-impl
+ 2.12.1
+
+
+ org.nd4j
+ ${nd4j.backend}
+ ${nd4j.version}
+
+
+
+
+
+ release
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+ 3.1.0
+
+
+ attach-sources
+
+ jar-no-fork
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 3.1.1
+
+ ${java.home}/bin/javadoc
+
+
+
+ attach-javadocs
+
+ jar
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 1.6
+
+
+ sign-artifacts
+ verify
+
+ sign
+
+
+
+
+
+
+
+
+
+
+
+
+ kr.motd.maven
+ os-maven-plugin
+ 1.6.2
+
+
+
+
+ org.sonatype.plugins
+ nexus-staging-maven-plugin
+ 1.6.8
+ true
+
+ ossrh
+ https://oss.sonatype.org/
+ true
+
+
+
+ org.apache.maven.plugins
+ maven-release-plugin
+ 2.5.3
+
+ true
+ false
+ release
+ deploy
+
+
+
+ org.xolstice.maven.plugins
+ protobuf-maven-plugin
+ 0.6.1
+
+ com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier}
+
+ grpc-java
+ io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}
+
+
+
+
+
+ compile
+ compile-custom
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+ 3.0.0-M2
+
+
+ enforce
+
+
+
+
+
+
+ enforce
+
+
+
+
+
+
+
+
diff --git a/java/src/main/java/io/paddle/serving/client/Client.java b/java/src/main/java/io/paddle/serving/client/Client.java
new file mode 100644
index 0000000000000000000000000000000000000000..1e09e0c23c89dd4f0d70e0c93269b2185a69807f
--- /dev/null
+++ b/java/src/main/java/io/paddle/serving/client/Client.java
@@ -0,0 +1,471 @@
+package io.paddle.serving.client;
+
+import java.util.*;
+import java.util.function.Function;
+import java.lang.management.ManagementFactory;
+import java.lang.management.RuntimeMXBean;
+
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import io.grpc.StatusRuntimeException;
+import com.google.protobuf.ByteString;
+
+import com.google.common.util.concurrent.ListenableFuture;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.iter.NdIndexIterator;
+import org.nd4j.linalg.factory.Nd4j;
+
+import io.paddle.serving.grpc.*;
+import io.paddle.serving.configure.*;
+import io.paddle.serving.client.PredictFuture;
+
+class Profiler {
+ int pid_;
+ String print_head_ = null;
+ List time_record_ = null;
+ boolean enable_ = false;
+
+ Profiler() {
+ RuntimeMXBean runtimeMXBean = ManagementFactory.getRuntimeMXBean();
+ pid_ = Integer.valueOf(runtimeMXBean.getName().split("@")[0]).intValue();
+ print_head_ = "\nPROFILE\tpid:" + pid_ + "\t";
+ time_record_ = new ArrayList();
+ time_record_.add(print_head_);
+ }
+
+ void record(String name) {
+ if (enable_) {
+ long ctime = System.currentTimeMillis() * 1000;
+ time_record_.add(name + ":" + String.valueOf(ctime) + " ");
+ }
+ }
+
+ void printProfile() {
+ if (enable_) {
+ String profile_str = String.join("", time_record_);
+ time_record_ = new ArrayList();
+ time_record_.add(print_head_);
+ }
+ }
+
+ void enable(boolean flag) {
+ enable_ = flag;
+ }
+}
+
+public class Client {
+ private ManagedChannel channel_;
+ private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
+ private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceFutureStub futureStub_;
+ private double rpcTimeoutS_;
+ private List feedNames_;
+ private Map feedTypes_;
+ private Map> feedShapes_;
+ private List fetchNames_;
+ private Map fetchTypes_;
+ private Set lodTensorSet_;
+ private Map feedTensorLen_;
+ private Profiler profiler_;
+
+ public Client() {
+ channel_ = null;
+ blockingStub_ = null;
+ futureStub_ = null;
+ rpcTimeoutS_ = 2;
+
+ feedNames_ = null;
+ feedTypes_ = null;
+ feedShapes_ = null;
+ fetchNames_ = null;
+ fetchTypes_ = null;
+ lodTensorSet_ = null;
+ feedTensorLen_ = null;
+
+ profiler_ = new Profiler();
+ boolean is_profile = false;
+ String FLAGS_profile_client = System.getenv("FLAGS_profile_client");
+ if (FLAGS_profile_client != null && FLAGS_profile_client.equals("1")) {
+ is_profile = true;
+ }
+ profiler_.enable(is_profile);
+ }
+
+ public boolean setRpcTimeoutMs(int rpc_timeout) {
+ if (futureStub_ == null || blockingStub_ == null) {
+ System.out.println("set timeout must be set after connect.");
+ return false;
+ }
+ rpcTimeoutS_ = rpc_timeout / 1000.0;
+ SetTimeoutRequest timeout_req = SetTimeoutRequest.newBuilder()
+ .setTimeoutMs(rpc_timeout)
+ .build();
+ SimpleResponse resp;
+ try {
+ resp = blockingStub_.setTimeout(timeout_req);
+ } catch (StatusRuntimeException e) {
+ System.out.format("Set RPC timeout failed: %s\n", e.toString());
+ return false;
+ }
+ return resp.getErrCode() == 0;
+ }
+
+ public boolean connect(String target) {
+ // TODO: target must be NameResolver-compliant URI
+ // https://grpc.github.io/grpc-java/javadoc/io/grpc/ManagedChannelBuilder.html
+ try {
+ channel_ = ManagedChannelBuilder.forTarget(target)
+ .defaultLoadBalancingPolicy("round_robin")
+ .maxInboundMessageSize(Integer.MAX_VALUE)
+ .usePlaintext()
+ .build();
+ blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
+ futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_);
+ } catch (Exception e) {
+ System.out.format("Connect failed: %s\n", e.toString());
+ return false;
+ }
+ GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
+ GetClientConfigResponse resp;
+ try {
+ resp = blockingStub_.getClientConfig(get_client_config_req);
+ } catch (Exception e) {
+ System.out.format("Get Client config failed: %s\n", e.toString());
+ return false;
+ }
+ String model_config_str = resp.getClientConfigStr();
+ _parseModelConfig(model_config_str);
+ return true;
+ }
+
+ private void _parseModelConfig(String model_config_str) {
+ GeneralModelConfig.Builder model_conf_builder = GeneralModelConfig.newBuilder();
+ try {
+ com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder);
+ } catch (com.google.protobuf.TextFormat.ParseException e) {
+ System.out.format("Parse client config failed: %s\n", e.toString());
+ }
+ GeneralModelConfig model_conf = model_conf_builder.build();
+
+ feedNames_ = new ArrayList();
+ fetchNames_ = new ArrayList();
+ feedTypes_ = new HashMap();
+ feedShapes_ = new HashMap>();
+ fetchTypes_ = new HashMap();
+ lodTensorSet_ = new HashSet();
+ feedTensorLen_ = new HashMap();
+
+ List feed_var_list = model_conf.getFeedVarList();
+ for (FeedVar feed_var : feed_var_list) {
+ feedNames_.add(feed_var.getAliasName());
+ }
+ List fetch_var_list = model_conf.getFetchVarList();
+ for (FetchVar fetch_var : fetch_var_list) {
+ fetchNames_.add(fetch_var.getAliasName());
+ }
+
+ for (int i = 0; i < feed_var_list.size(); ++i) {
+ FeedVar feed_var = feed_var_list.get(i);
+ String var_name = feed_var.getAliasName();
+ feedTypes_.put(var_name, feed_var.getFeedType());
+ feedShapes_.put(var_name, feed_var.getShapeList());
+ if (feed_var.getIsLodTensor()) {
+ lodTensorSet_.add(var_name);
+ } else {
+ int counter = 1;
+ for (int dim : feedShapes_.get(var_name)) {
+ counter *= dim;
+ }
+ feedTensorLen_.put(var_name, counter);
+ }
+ }
+
+ for (int i = 0; i < fetch_var_list.size(); i++) {
+ FetchVar fetch_var = fetch_var_list.get(i);
+ String var_name = fetch_var.getAliasName();
+ fetchTypes_.put(var_name, fetch_var.getFetchType());
+ if (fetch_var.getIsLodTensor()) {
+ lodTensorSet_.add(var_name);
+ }
+ }
+ }
+
+ private InferenceRequest _packInferenceRequest(
+ List> feed_batch,
+ Iterable fetch) throws IllegalArgumentException {
+ List feed_var_names = new ArrayList();
+ feed_var_names.addAll(feed_batch.get(0).keySet());
+
+ InferenceRequest.Builder req_builder = InferenceRequest.newBuilder()
+ .addAllFeedVarNames(feed_var_names)
+ .addAllFetchVarNames(fetch)
+ .setIsPython(false);
+ for (HashMap feed_data: feed_batch) {
+ FeedInst.Builder inst_builder = FeedInst.newBuilder();
+ for (String name: feed_var_names) {
+ Tensor.Builder tensor_builder = Tensor.newBuilder();
+ INDArray variable = feed_data.get(name);
+ long[] flattened_shape = {-1};
+ INDArray flattened_list = variable.reshape(flattened_shape);
+ int v_type = feedTypes_.get(name);
+ NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
+ if (v_type == 0) { // int64
+ while (iter.hasNext()) {
+ long[] next_index = iter.next();
+ long x = flattened_list.getLong(next_index);
+ tensor_builder.addInt64Data(x);
+ }
+ } else if (v_type == 1) { // float32
+ while (iter.hasNext()) {
+ long[] next_index = iter.next();
+ float x = flattened_list.getFloat(next_index);
+ tensor_builder.addFloatData(x);
+ }
+ } else if (v_type == 2) { // int32
+ while (iter.hasNext()) {
+ long[] next_index = iter.next();
+ // the interface of INDArray is strange:
+ // https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html
+ int[] int_next_index = new int[next_index.length];
+ for(int i = 0; i < next_index.length; i++) {
+ int_next_index[i] = (int)next_index[i];
+ }
+ int x = flattened_list.getInt(int_next_index);
+ tensor_builder.addIntData(x);
+ }
+ } else {
+ throw new IllegalArgumentException("error tensor value type.");
+ }
+ tensor_builder.addAllShape(feedShapes_.get(name));
+ inst_builder.addTensorArray(tensor_builder.build());
+ }
+ req_builder.addInsts(inst_builder.build());
+ }
+ return req_builder.build();
+ }
+
+ private Map>
+ _unpackInferenceResponse(
+ InferenceResponse resp,
+ Iterable fetch,
+ Boolean need_variant_tag) throws IllegalArgumentException {
+ return Client._staticUnpackInferenceResponse(
+ resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
+ }
+
+    /**
+     * Decodes a server InferenceResponse into a nested map:
+     * {engine_name -> {fetch_var_name -> INDArray}}.
+     *
+     * Each fetched tensor is decoded by type code (0=int64, 1=float32,
+     * 2=int32), reshaped to the server-reported shape, and, when the
+     * variable is a lod tensor, an extra "<name>.lod" entry holds the
+     * lod offsets. Returns null when the response err_code is non-zero.
+     * Generic type parameters restored (stripped from this patch).
+     */
+    private static Map<String, HashMap<String, INDArray>>
+        _staticUnpackInferenceResponse(
+            InferenceResponse resp,
+            Iterable<String> fetch,
+            Map<String, Integer> fetchTypes,
+            Set<String> lodTensorSet,
+            Boolean need_variant_tag) throws IllegalArgumentException {
+        if (resp.getErrCode() != 0) {
+            return null;
+        }
+        String tag = resp.getTag();
+        HashMap<String, HashMap<String, INDArray>> multi_result_map
+            = new HashMap<String, HashMap<String, INDArray>>();
+        for (ModelOutput model_result: resp.getOutputsList()) {
+            String engine_name = model_result.getEngineName();
+            FetchInst inst = model_result.getInsts(0);
+            HashMap<String, INDArray> result_map
+                = new HashMap<String, INDArray>();
+            int index = 0;
+            for (String name: fetch) {
+                Tensor variable = inst.getTensorArray(index);
+                int v_type = fetchTypes.get(name);
+                INDArray data = null;
+                if (v_type == 0) { // int64
+                    List<Long> list = variable.getInt64DataList();
+                    long[] array = new long[list.size()];
+                    for (int i = 0; i < list.size(); i++) {
+                        array[i] = list.get(i);
+                    }
+                    data = Nd4j.createFromArray(array);
+                } else if (v_type == 1) { // float32
+                    List<Float> list = variable.getFloatDataList();
+                    float[] array = new float[list.size()];
+                    for (int i = 0; i < list.size(); i++) {
+                        array[i] = list.get(i);
+                    }
+                    data = Nd4j.createFromArray(array);
+                } else if (v_type == 2) { // int32
+                    List<Integer> list = variable.getIntDataList();
+                    int[] array = new int[list.size()];
+                    for (int i = 0; i < list.size(); i++) {
+                        array[i] = list.get(i);
+                    }
+                    data = Nd4j.createFromArray(array);
+                } else {
+                    throw new IllegalArgumentException("error tensor value type.");
+                }
+                // Reshape the flat data to the server-reported shape.
+                // (typo fixed: shape_lsit -> shape_list)
+                List<Integer> shape_list = variable.getShapeList();
+                int[] shape_array = new int[shape_list.size()];
+                for (int i = 0; i < shape_list.size(); ++i) {
+                    shape_array[i] = shape_list.get(i);
+                }
+                data = data.reshape(shape_array);
+
+                // put data to result_map
+                result_map.put(name, data);
+
+                // lod offsets for variable-length (lod) tensors
+                if (lodTensorSet.contains(name)) {
+                    List<Integer> list = variable.getLodList();
+                    int[] array = new int[list.size()];
+                    for (int i = 0; i < list.size(); i++) {
+                        array[i] = list.get(i);
+                    }
+                    result_map.put(name + ".lod", Nd4j.createFromArray(array));
+                }
+                index += 1;
+            }
+            multi_result_map.put(engine_name, result_map);
+        }
+
+        // TODO: tag(ABtest not support now)
+        return multi_result_map;
+    }
+
+    /** Single-feed predict; variant tag disabled. Generics restored. */
+    public Map<String, INDArray> predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch) {
+        return predict(feed, fetch, false);
+    }
+
+ public Map> ensemble_predict(
+ HashMap feed,
+ Iterable fetch) {
+ return ensemble_predict(feed, fetch, false);
+ }
+
+    /** Single-feed asynchronous predict; variant tag disabled. Generics restored. */
+    public PredictFuture asyn_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch) {
+        return asyn_predict(feed, fetch, false);
+    }
+
+    /** Wraps a single feed dict in a one-element batch. Generics restored. */
+    public Map<String, INDArray> predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            Boolean need_variant_tag) {
+        List<HashMap<String, INDArray>> feed_batch
+            = new ArrayList<HashMap<String, INDArray>>();
+        feed_batch.add(feed);
+        return predict(feed_batch, fetch, need_variant_tag);
+    }
+
+ public Map> ensemble_predict(
+ HashMap feed,
+ Iterable fetch,
+ Boolean need_variant_tag) {
+ List> feed_batch
+ = new ArrayList>();
+ feed_batch.add(feed);
+ return ensemble_predict(feed_batch, fetch, need_variant_tag);
+ }
+
+    /** Wraps a single feed dict in a one-element batch. Generics restored. */
+    public PredictFuture asyn_predict(
+            HashMap<String, INDArray> feed,
+            Iterable<String> fetch,
+            Boolean need_variant_tag) {
+        List<HashMap<String, INDArray>> feed_batch
+            = new ArrayList<HashMap<String, INDArray>>();
+        feed_batch.add(feed);
+        return asyn_predict(feed_batch, fetch, need_variant_tag);
+    }
+
+    /** Batch predict; variant tag disabled. Generics restored. */
+    public Map<String, INDArray> predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch) {
+        return predict(feed_batch, fetch, false);
+    }
+
+ public Map> ensemble_predict(
+ List> feed_batch,
+ Iterable fetch) {
+ return ensemble_predict(feed_batch, fetch, false);
+ }
+
+    /** Batch asynchronous predict; variant tag disabled. Generics restored. */
+    public PredictFuture asyn_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch) {
+        return asyn_predict(feed_batch, fetch, false);
+    }
+
+    /**
+     * Synchronous batch predict against a single-model service.
+     * Records client-side profiling around preprocess, RPC and postprocess.
+     * Returns {fetch_var -> INDArray}; null on RPC failure or when the
+     * server runs an ensemble (use ensemble_predict for that case).
+     * Generic type parameters restored (stripped from this patch).
+     */
+    public Map<String, INDArray> predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            Boolean need_variant_tag) {
+        try {
+            profiler_.record("java_prepro_0");
+            InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+            profiler_.record("java_prepro_1");
+
+            profiler_.record("java_client_infer_0");
+            InferenceResponse resp = blockingStub_.inference(req);
+            profiler_.record("java_client_infer_1");
+
+            profiler_.record("java_postpro_0");
+            Map<String, HashMap<String, INDArray>> ensemble_result
+                = _unpackInferenceResponse(resp, fetch, need_variant_tag);
+            List<Map.Entry<String, HashMap<String, INDArray>>> list
+                = new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
+                    ensemble_result.entrySet());
+            if (list.size() != 1) {
+                System.out.format("predict failed: please use ensemble_predict impl.\n");
+                return null;
+            }
+            profiler_.record("java_postpro_1");
+            profiler_.printProfile();
+
+            return list.get(0).getValue();
+        } catch (StatusRuntimeException e) {
+            System.out.format("predict failed: %s\n", e.toString());
+            return null;
+        }
+    }
+
+ public Map> ensemble_predict(
+ List> feed_batch,
+ Iterable fetch,
+ Boolean need_variant_tag) {
+ try {
+ profiler_.record("java_prepro_0");
+ InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+ profiler_.record("java_prepro_1");
+
+ profiler_.record("java_client_infer_0");
+ InferenceResponse resp = blockingStub_.inference(req);
+ profiler_.record("java_client_infer_1");
+
+ profiler_.record("java_postpro_0");
+ Map> ensemble_result
+ = _unpackInferenceResponse(resp, fetch, need_variant_tag);
+ profiler_.record("java_postpro_1");
+ profiler_.printProfile();
+
+ return ensemble_result;
+ } catch (StatusRuntimeException e) {
+ System.out.format("predict failed: %s\n", e.toString());
+ return null;
+ }
+ }
+
+    /**
+     * Asynchronous batch predict. Fires the gRPC future and returns a
+     * PredictFuture whose get()/ensemble_get() unpack the response with
+     * this client's fetch metadata. Generics restored (stripped from patch).
+     */
+    public PredictFuture asyn_predict(
+            List<HashMap<String, INDArray>> feed_batch,
+            Iterable<String> fetch,
+            Boolean need_variant_tag) {
+        InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
+        ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
+        PredictFuture predict_future = new PredictFuture(future,
+            (InferenceResponse resp) -> {
+                return Client._staticUnpackInferenceResponse(
+                    resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
+            }
+        );
+        return predict_future;
+    }
+}
diff --git a/java/src/main/java/io/paddle/serving/client/PredictFuture.java b/java/src/main/java/io/paddle/serving/client/PredictFuture.java
new file mode 100644
index 0000000000000000000000000000000000000000..28156d965e76db889358be00ab8c05381e0f89d8
--- /dev/null
+++ b/java/src/main/java/io/paddle/serving/client/PredictFuture.java
@@ -0,0 +1,54 @@
+package io.paddle.serving.client;
+
+import java.util.*;
+import java.util.function.Function;
+import io.grpc.StatusRuntimeException;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import io.paddle.serving.client.Client;
+import io.paddle.serving.grpc.*;
+
+/**
+ * Handle for an asynchronous inference call. Wraps the gRPC
+ * ListenableFuture together with the callback that unpacks the raw
+ * InferenceResponse into {engine_name -> {fetch_var -> INDArray}}.
+ * Generic type parameters restored (they were stripped from this patch).
+ */
+public class PredictFuture {
+    private ListenableFuture<InferenceResponse> callFuture_;
+    private Function<InferenceResponse,
+        Map<String, HashMap<String, INDArray>>> callBackFunc_;
+
+    PredictFuture(ListenableFuture<InferenceResponse> call_future,
+            Function<InferenceResponse,
+                Map<String, HashMap<String, INDArray>>> call_back_func) {
+        callFuture_ = call_future;
+        callBackFunc_ = call_back_func;
+    }
+
+    /**
+     * Blocks for the result of a single-model prediction.
+     * Returns null on RPC failure, or when the server ran an ensemble
+     * (use ensemble_get in that case).
+     */
+    public Map<String, INDArray> get() {
+        InferenceResponse resp = null;
+        try {
+            resp = callFuture_.get();
+        } catch (Exception e) {
+            System.out.format("predict failed: %s\n", e.toString());
+            return null;
+        }
+        Map<String, HashMap<String, INDArray>> ensemble_result
+            = callBackFunc_.apply(resp);
+        List<Map.Entry<String, HashMap<String, INDArray>>> list
+            = new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
+                ensemble_result.entrySet());
+        if (list.size() != 1) {
+            // Message fixed: the ensemble accessor is ensemble_get, not get_ensemble.
+            System.out.format("predict failed: please use ensemble_get impl.\n");
+            return null;
+        }
+        return list.get(0).getValue();
+    }
+
+    /**
+     * Blocks for the result of an ensemble prediction:
+     * {engine_name -> {fetch_var -> INDArray}}. Returns null on RPC failure.
+     */
+    public Map<String, HashMap<String, INDArray>> ensemble_get() {
+        InferenceResponse resp = null;
+        try {
+            resp = callFuture_.get();
+        } catch (Exception e) {
+            System.out.format("predict failed: %s\n", e.toString());
+            return null;
+        }
+        return callBackFunc_.apply(resp);
+    }
+}
diff --git a/java/src/main/proto/general_model_config.proto b/java/src/main/proto/general_model_config.proto
new file mode 100644
index 0000000000000000000000000000000000000000..03cff3f8c1ab4a369f132d64d7e4f2c871ebb077
--- /dev/null
+++ b/java/src/main/proto/general_model_config.proto
@@ -0,0 +1,40 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+option java_multiple_files = true;
+option java_package = "io.paddle.serving.configure";
+option java_outer_classname = "ConfigureProto";
+
+package paddle.serving.configure;
+
+message FeedVar {
+ optional string name = 1;
+ optional string alias_name = 2;
+ optional bool is_lod_tensor = 3 [ default = false ];
+ optional int32 feed_type = 4 [ default = 0 ];
+ repeated int32 shape = 5;
+}
+message FetchVar {
+ optional string name = 1;
+ optional string alias_name = 2;
+ optional bool is_lod_tensor = 3 [ default = false ];
+ optional int32 fetch_type = 4 [ default = 0 ];
+ repeated int32 shape = 5;
+}
+message GeneralModelConfig {
+ repeated FeedVar feed_var = 1;
+ repeated FetchVar fetch_var = 2;
+};
diff --git a/java/src/main/proto/multi_lang_general_model_service.proto b/java/src/main/proto/multi_lang_general_model_service.proto
new file mode 100644
index 0000000000000000000000000000000000000000..b83450aed666b96de324050d53b10c56e059a8d5
--- /dev/null
+++ b/java/src/main/proto/multi_lang_general_model_service.proto
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+option java_multiple_files = true;
+option java_package = "io.paddle.serving.grpc";
+option java_outer_classname = "ServingProto";
+
+message Tensor {
+ optional bytes data = 1;
+ repeated int32 int_data = 2;
+ repeated int64 int64_data = 3;
+ repeated float float_data = 4;
+ optional int32 elem_type = 5;
+ repeated int32 shape = 6;
+ repeated int32 lod = 7; // only for fetch tensor currently
+};
+
+message FeedInst { repeated Tensor tensor_array = 1; };
+
+message FetchInst { repeated Tensor tensor_array = 1; };
+
+message InferenceRequest {
+ repeated FeedInst insts = 1;
+ repeated string feed_var_names = 2;
+ repeated string fetch_var_names = 3;
+ required bool is_python = 4 [ default = false ];
+};
+
+message InferenceResponse {
+ repeated ModelOutput outputs = 1;
+ optional string tag = 2;
+ required int32 err_code = 3;
+};
+
+message ModelOutput {
+ repeated FetchInst insts = 1;
+ optional string engine_name = 2;
+}
+
+message SetTimeoutRequest { required int32 timeout_ms = 1; }
+
+message SimpleResponse { required int32 err_code = 1; }
+
+message GetClientConfigRequest {}
+
+message GetClientConfigResponse { required string client_config_str = 1; }
+
+service MultiLangGeneralModelService {
+ rpc Inference(InferenceRequest) returns (InferenceResponse) {}
+ rpc SetTimeout(SetTimeoutRequest) returns (SimpleResponse) {}
+ rpc GetClientConfig(GetClientConfigRequest)
+ returns (GetClientConfigResponse) {}
+};
diff --git a/java/src/main/resources/log4j2.xml b/java/src/main/resources/log4j2.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e13b79d3f92acca50cafde874b501513dbdb292f
--- /dev/null
+++ b/java/src/main/resources/log4j2.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/python/examples/criteo_ctr_with_cube/test_server_quant.py b/python/examples/criteo_ctr_with_cube/test_server_quant.py
index fc278f755126cdeb204644cbc91838b1b038379e..38a3fe67da803d1c84337d64e3421d8295ac5767 100755
--- a/python/examples/criteo_ctr_with_cube/test_server_quant.py
+++ b/python/examples/criteo_ctr_with_cube/test_server_quant.py
@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1])
-server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
+server.prepare_server(
+ workdir="work_dir1",
+ port=9292,
+ device="cpu",
+ cube_conf="./cube/conf/cube.conf")
server.run_server()
diff --git a/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py b/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py
index feca75b077d737a614bdfd955b7bf0d82ed08529..2fd9308454b4caa862e7d83ddadb48279bba7167 100755
--- a/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py
+++ b/python/examples/grpc_impl_example/criteo_ctr_with_cube/test_server_quant.py
@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1], sys.argv[2])
-server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
+server.prepare_server(
+ workdir="work_dir1",
+ port=9292,
+ device="cpu",
+ cube_conf="./cube/conf/cube.conf")
server.run_server()
diff --git a/python/examples/grpc_impl_example/imdb/get_data.sh b/python/examples/grpc_impl_example/imdb/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..81d8d5d3b018f133c41e211d1501cf3cd9a3d8a4
--- /dev/null
+++ b/python/examples/grpc_impl_example/imdb/get_data.sh
@@ -0,0 +1,4 @@
+wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz
+tar -zxvf text_classification_data.tar.gz
+tar -zxvf imdb_model.tar.gz
diff --git a/python/examples/grpc_impl_example/imdb/imdb_reader.py b/python/examples/grpc_impl_example/imdb/imdb_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ef3e163a50b0dc244ac2653df1e38d7f91699b
--- /dev/null
+++ b/python/examples/grpc_impl_example/imdb/imdb_reader.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+import sys
+import os
+import paddle
+import re
+import paddle.fluid.incubate.data_generator as dg
+
+py_version = sys.version_info[0]
+
+
+class IMDBDataset(dg.MultiSlotDataGenerator):
+ def load_resource(self, dictfile):
+ self._vocab = {}
+ wid = 0
+ if py_version == 2:
+ with open(dictfile) as f:
+ for line in f:
+ self._vocab[line.strip()] = wid
+ wid += 1
+ else:
+ with open(dictfile, encoding="utf-8") as f:
+ for line in f:
+ self._vocab[line.strip()] = wid
+ wid += 1
+ self._unk_id = len(self._vocab)
+ self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
+ self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])
+
+ def get_words_only(self, line):
+ sent = line.lower().replace("
", " ").strip()
+ words = [x for x in self._pattern.split(sent) if x and x != " "]
+ feas = [
+ self._vocab[x] if x in self._vocab else self._unk_id for x in words
+ ]
+ return feas
+
+ def get_words_and_label(self, line):
+ send = '|'.join(line.split('|')[:-1]).lower().replace("
",
+ " ").strip()
+ label = [int(line.split('|')[-1])]
+
+ words = [x for x in self._pattern.split(send) if x and x != " "]
+ feas = [
+ self._vocab[x] if x in self._vocab else self._unk_id for x in words
+ ]
+ return feas, label
+
+ def infer_reader(self, infer_filelist, batch, buf_size):
+ def local_iter():
+ for fname in infer_filelist:
+ with open(fname, "r") as fin:
+ for line in fin:
+ feas, label = self.get_words_and_label(line)
+ yield feas, label
+
+ import paddle
+ batch_iter = paddle.batch(
+ paddle.reader.shuffle(
+ local_iter, buf_size=buf_size),
+ batch_size=batch)
+ return batch_iter
+
+ def generate_sample(self, line):
+ def memory_iter():
+ for i in range(1000):
+ yield self.return_value
+
+ def data_iter():
+ feas, label = self.get_words_and_label(line)
+ yield ("words", feas), ("label", label)
+
+ return data_iter
+
+
+if __name__ == "__main__":
+ # Stream samples from stdin using the vocab file "imdb.vocab"
+ # (run_from_stdin comes from the data-generator base class).
+ imdb = IMDBDataset()
+ imdb.load_resource("imdb.vocab")
+ imdb.run_from_stdin()
diff --git a/python/examples/imdb/test_multilang_ensemble_client.py b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py
similarity index 95%
rename from python/examples/imdb/test_multilang_ensemble_client.py
rename to python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py
index 6686d4c8c38d6a17cb9c5701abf7d76773031772..43034e49bde4a477c160c5a0d158ea541d633a4d 100644
--- a/python/examples/imdb/test_multilang_ensemble_client.py
+++ b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_client.py
@@ -34,4 +34,6 @@ for i in range(3):
fetch = ["prediction"]
fetch_maps = client.predict(feed=feed, fetch=fetch)
for model, fetch_map in fetch_maps.items():
+ if model == "serving_status_code":
+ continue
print("step: {}, model: {}, res: {}".format(i, model, fetch_map))
diff --git a/python/examples/imdb/test_multilang_ensemble_server.py b/python/examples/grpc_impl_example/imdb/test_multilang_ensemble_server.py
similarity index 100%
rename from python/examples/imdb/test_multilang_ensemble_server.py
rename to python/examples/grpc_impl_example/imdb/test_multilang_ensemble_server.py
diff --git a/python/examples/grpc_impl_example/yolov4/000000570688.jpg b/python/examples/grpc_impl_example/yolov4/000000570688.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54
Binary files /dev/null and b/python/examples/grpc_impl_example/yolov4/000000570688.jpg differ
diff --git a/python/examples/grpc_impl_example/yolov4/README.md b/python/examples/grpc_impl_example/yolov4/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a04215dcf349b0e589819db16d53b3435bd904ff
--- /dev/null
+++ b/python/examples/grpc_impl_example/yolov4/README.md
@@ -0,0 +1,23 @@
+# Yolov4 Detection Service
+
+([简体中文](README_CN.md)|English)
+
+## Get Model
+
+```
+python -m paddle_serving_app.package --get_model yolov4
+tar -xzvf yolov4.tar.gz
+```
+
+## Start RPC Service
+
+```
+python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang
+```
+
+## Prediction
+
+```
+python test_client.py 000000570688.jpg
+```
+After the prediction is completed, a json file to save the prediction result and a picture with the detection result box will be generated in the `./output` folder.
diff --git a/python/examples/grpc_impl_example/yolov4/README_CN.md b/python/examples/grpc_impl_example/yolov4/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..de7a85b59ccdf831337083b8d6047bfe41525220
--- /dev/null
+++ b/python/examples/grpc_impl_example/yolov4/README_CN.md
@@ -0,0 +1,24 @@
+# Yolov4 检测服务
+
+(简体中文|[English](README.md))
+
+## 获取模型
+
+```
+python -m paddle_serving_app.package --get_model yolov4
+tar -xzvf yolov4.tar.gz
+```
+
+## 启动RPC服务
+
+```
+python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang
+```
+
+## 预测
+
+```
+python test_client.py 000000570688.jpg
+```
+
+预测完成会在`./output`文件夹下生成保存预测结果的json文件以及标出检测结果框的图片。
diff --git a/python/examples/grpc_impl_example/yolov4/label_list.txt b/python/examples/grpc_impl_example/yolov4/label_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..941cb4e1392266f6a6c09b1fdc5f79503b2e5df6
--- /dev/null
+++ b/python/examples/grpc_impl_example/yolov4/label_list.txt
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+dining table
+toilet
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/python/examples/grpc_impl_example/yolov4/test_client.py b/python/examples/grpc_impl_example/yolov4/test_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55763880f7852f0297d7e6c7f44f8c3a206dc60
--- /dev/null
+++ b/python/examples/grpc_impl_example/yolov4/test_client.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import numpy as np
+from paddle_serving_client import MultiLangClient as Client
+from paddle_serving_app.reader import *
+import cv2
+
+# Preprocess pipeline: read image file, BGR->RGB, resize to the model's
+# 608x608 input, scale pixels to [0,1], then HWC->CHW layout.
+preprocess = Sequential([
+ File2Image(), BGR2RGB(), Resize(
+ (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
+ (2, 0, 1))
+])
+
+# Draws detection boxes (labels from label_list.txt) into ./output.
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
+client = Client()
+client.connect(['127.0.0.1:9393'])
+# client.set_rpc_timeout_ms(10000)
+
+# sys.argv[1] is the input image path.
+im = preprocess(sys.argv[1])
+fetch_map = client.predict(
+ feed={
+ "image": im,
+ "im_size": np.array(list(im.shape[1:])),
+ },
+ fetch=["save_infer_model/scale_0.tmp_0"])
+# Drop the client-added status entry before postprocessing the raw boxes.
+fetch_map.pop("serving_status_code")
+fetch_map["image"] = sys.argv[1]
+postprocess(fetch_map)
diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py
index 455bcf62cd039dde69736ec514892856eabd3088..cf669c54f3492fc739bedcfacc49537a5ecc545f 100644
--- a/python/paddle_serving_client/__init__.py
+++ b/python/paddle_serving_client/__init__.py
@@ -399,6 +399,7 @@ class MultiLangClient(object):
self.channel_ = None
self.stub_ = None
self.rpc_timeout_s_ = 2
+ self.profile_ = _Profiler()
def add_variant(self, tag, cluster, variant_weight):
# TODO
@@ -520,7 +521,7 @@ class MultiLangClient(object):
tensor.float_data.extend(
var.reshape(-1).astype('float32').tolist())
elif v_type == 2:
- tensor.int32_data.extend(
+ tensor.int_data.extend(
var.reshape(-1).astype('int32').tolist())
else:
raise Exception("error tensor value type.")
@@ -530,7 +531,7 @@ class MultiLangClient(object):
elif v_type == 1:
tensor.float_data.extend(self._flatten_list(var))
elif v_type == 2:
- tensor.int32_data.extend(self._flatten_list(var))
+ tensor.int_data.extend(self._flatten_list(var))
else:
raise Exception("error tensor value type.")
else:
@@ -582,6 +583,7 @@ class MultiLangClient(object):
ret = list(multi_result_map.values())[0]
else:
ret = multi_result_map
+
ret["serving_status_code"] = 0
return ret if not need_variant_tag else [ret, tag]
@@ -601,18 +603,30 @@ class MultiLangClient(object):
need_variant_tag=False,
asyn=False,
is_python=True):
- req = self._pack_inference_request(feed, fetch, is_python=is_python)
if not asyn:
try:
+ self.profile_.record('py_prepro_0')
+ req = self._pack_inference_request(
+ feed, fetch, is_python=is_python)
+ self.profile_.record('py_prepro_1')
+
+ self.profile_.record('py_client_infer_0')
resp = self.stub_.Inference(req, timeout=self.rpc_timeout_s_)
- return self._unpack_inference_response(
+ self.profile_.record('py_client_infer_1')
+
+ self.profile_.record('py_postpro_0')
+ ret = self._unpack_inference_response(
resp,
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag)
+ self.profile_.record('py_postpro_1')
+ self.profile_.print_profile()
+ return ret
except grpc.RpcError as e:
return {"serving_status_code": e.code()}
else:
+ req = self._pack_inference_request(feed, fetch, is_python=is_python)
call_future = self.stub_.Inference.future(
req, timeout=self.rpc_timeout_s_)
return MultiLangPredictFuture(
diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py
index 1e5fd16ed6c153a28cd72422ca3ef7b9177cb079..678c0583d1e132791a1199e315ea380a4ae3108b 100644
--- a/python/paddle_serving_server/__init__.py
+++ b/python/paddle_serving_server/__init__.py
@@ -524,7 +524,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32
- data = np.array(list(var.int32_data), dtype="int32")
+ data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
@@ -540,6 +540,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
results, tag = ret
resp.tag = tag
resp.err_code = 0
+
if not self.is_multi_model_:
results = {'general_infer_0': results}
for model_name, model_result in results.items():
@@ -558,8 +559,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 2: # int32
- tensor.int32_data.extend(model_result[name].reshape(-1)
- .tolist())
+ tensor.int_data.extend(model_result[name].reshape(-1)
+ .tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py
index df04cb7840bbacd90ccb7e3c66147a6856b23e02..0261003a7863d11fb342d1572b124d1cbb533a2b 100644
--- a/python/paddle_serving_server_gpu/__init__.py
+++ b/python/paddle_serving_server_gpu/__init__.py
@@ -571,7 +571,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:
- data = np.array(list(var.int32_data), dtype="int32")
+ data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
@@ -587,6 +587,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
results, tag = ret
resp.tag = tag
resp.err_code = 0
+
if not self.is_multi_model_:
results = {'general_infer_0': results}
for model_name, model_result in results.items():
@@ -605,8 +606,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 2: # int32
- tensor.int32_data.extend(model_result[name].reshape(-1)
- .tolist())
+ tensor.int_data.extend(model_result[name].reshape(-1)
+ .tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
diff --git a/tools/Dockerfile.ci b/tools/Dockerfile.ci
index 8709075f6cf8f985e346999e76f6b273d7664193..fc52733bb214af918af34069e71b05b2eaa8511e 100644
--- a/tools/Dockerfile.ci
+++ b/tools/Dockerfile.ci
@@ -1,39 +1,50 @@
FROM centos:7.3.1611
+
RUN yum -y install wget >/dev/null \
&& yum -y install gcc gcc-c++ make glibc-static which >/dev/null \
&& yum -y install git openssl-devel curl-devel bzip2-devel python-devel >/dev/null \
&& yum -y install libSM-1.2.2-2.el7.x86_64 --setopt=protected_multilib=false \
&& yum -y install libXrender-0.9.10-1.el7.x86_64 --setopt=protected_multilib=false \
- && yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false \
- && wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \
+ && yum -y install libXext-1.3.3-3.el7.x86_64 --setopt=protected_multilib=false
+
+RUN wget https://cmake.org/files/v3.2/cmake-3.2.0-Linux-x86_64.tar.gz >/dev/null \
&& tar xzf cmake-3.2.0-Linux-x86_64.tar.gz \
&& mv cmake-3.2.0-Linux-x86_64 /usr/local/cmake3.2.0 \
&& echo 'export PATH=/usr/local/cmake3.2.0/bin:$PATH' >> /root/.bashrc \
- && rm cmake-3.2.0-Linux-x86_64.tar.gz \
- && wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
+ && rm cmake-3.2.0-Linux-x86_64.tar.gz
+
+RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
&& tar xzf go1.14.linux-amd64.tar.gz \
&& mv go /usr/local/go \
&& echo 'export GOROOT=/usr/local/go' >> /root/.bashrc \
&& echo 'export PATH=/usr/local/go/bin:$PATH' >> /root/.bashrc \
- && rm go1.14.linux-amd64.tar.gz \
- && yum -y install python-devel sqlite-devel >/dev/null \
+ && rm go1.14.linux-amd64.tar.gz
+
+RUN yum -y install python-devel sqlite-devel >/dev/null \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \
&& pip install google protobuf setuptools wheel flask >/dev/null \
- && rm get-pip.py \
- && wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
+ && rm get-pip.py
+
+RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
&& yum -y install bzip2 >/dev/null \
&& tar -jxf patchelf-0.10.tar.bz2 \
&& cd patchelf-0.10 \
&& ./configure --prefix=/usr \
&& make >/dev/null && make install >/dev/null \
&& cd .. \
- && rm -rf patchelf-0.10* \
- && yum install -y python3 python3-devel \
- && pip3 install google protobuf setuptools wheel flask \
- && yum -y update >/dev/null \
+ && rm -rf patchelf-0.10*
+
+RUN yum install -y python3 python3-devel \
+ && pip3 install google protobuf setuptools wheel flask
+
+RUN yum -y update >/dev/null \
&& yum -y install dnf >/dev/null \
&& yum -y install dnf-plugins-core >/dev/null \
&& dnf copr enable alonid/llvm-3.8.0 -y \
&& dnf install llvm-3.8.0 clang-3.8.0 compiler-rt-3.8.0 -y \
&& echo 'export PATH=/opt/llvm-3.8.0/bin:$PATH' >> /root/.bashrc
+
+RUN yum install -y java \
+ && wget http://repos.fedorapeople.org/repos/dchen/apache-maven/epel-apache-maven.repo -O /etc/yum.repos.d/epel-apache-maven.repo \
+ && yum install -y apache-maven
diff --git a/tools/serving_build.sh b/tools/serving_build.sh
index 175f084d3e29ded9005d7a1e7c39da3e001978c8..1ecae3e7f255438eb3a735c75bc797bf736bacde 100644
--- a/tools/serving_build.sh
+++ b/tools/serving_build.sh
@@ -499,6 +499,64 @@ function python_test_lac() {
cd ..
}
+# Builds the Java SDK + examples and runs them against freshly started
+# serving processes (fit_a_line, imdb ensemble, yolov4 int32 case).
+function java_run_test() {
+    # pwd: /Serving
+    local TYPE=$1
+    export SERVING_BIN=${SERVING_WORKDIR}/build-server-${TYPE}/core/general-server/serving
+    unsetproxy
+    case $TYPE in
+        CPU)
+            # compile java sdk
+            cd java # pwd: /Serving/java
+            mvn compile > /dev/null
+            mvn install > /dev/null
+            # compile java sdk example
+            cd examples # pwd: /Serving/java/examples
+            mvn compile > /dev/null
+            mvn install > /dev/null
+
+            # fit_a_line (general, asyn_predict, batch_predict)
+            cd ../../python/examples/grpc_impl_example/fit_a_line # pwd: /Serving/python/examples/grpc_impl_example/fit_a_line
+            sh get_data.sh
+            check_cmd "python -m paddle_serving_server.serve --model uci_housing_model --port 9393 --thread 4 --use_multilang > /dev/null &"
+            sleep 5 # wait for the server to start
+            cd ../../../java/examples # /Serving/java/examples
+            java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample fit_a_line
+            java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample asyn_predict
+            java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample batch_predict
+            kill_server_process
+
+            # imdb (model_ensemble)
+            cd ../../python/examples/grpc_impl_example/imdb # pwd: /Serving/python/examples/grpc_impl_example/imdb
+            sh get_data.sh > /dev/null
+            check_cmd "python test_multilang_ensemble_server.py > /dev/null &"
+            sleep 5 # wait for the server to start
+            cd ../../../java/examples # /Serving/java/examples
+            java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample model_ensemble
+            kill_server_process
+
+            # yolov4 (int32)
+            cd ../../python/examples/grpc_impl_example/yolov4 # pwd: /Serving/python/examples/grpc_impl_example/yolov4
+            python -m paddle_serving_app.package --get_model yolov4 > /dev/null
+            tar -xzf yolov4.tar.gz > /dev/null
+            check_cmd "python -m paddle_serving_server.serve --model yolov4_model --port 9393 --use_multilang --mem_optim > /dev/null &"
+            sleep 5 # wait for the server to start (was missing; client raced the server)
+            cd ../../../java/examples # /Serving/java/examples
+            java -cp target/paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample yolov4 src/main/resources/000000570688.jpg
+            kill_server_process
+            cd ../../ # pwd: /Serving
+            ;;
+        GPU)
+            ;;
+        *)
+            echo "error type"
+            exit 1
+            ;;
+    esac
+    echo "java-sdk $TYPE part finished as expected."
+    setproxy
+    unset SERVING_BIN
+}
+
function python_test_grpc_impl() {
# pwd: /Serving/python/examples
cd grpc_impl_example # pwd: /Serving/python/examples/grpc_impl_example
@@ -537,7 +595,7 @@ function python_test_grpc_impl() {
# test load server config and client config in Server side
cd criteo_ctr_with_cube # pwd: /Serving/python/examples/grpc_impl_example/criteo_ctr_with_cube
- check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz"
+ check_cmd "wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz > /dev/null"
check_cmd "tar xf ctr_cube_unittest.tar.gz"
check_cmd "mv models/ctr_client_conf ./"
check_cmd "mv models/ctr_serving_model_kv ./"
@@ -978,6 +1036,7 @@ function main() {
build_client $TYPE # pwd: /Serving
build_server $TYPE # pwd: /Serving
build_app $TYPE # pwd: /Serving
+ java_run_test $TYPE # pwd: /Serving
python_run_test $TYPE # pwd: /Serving
monitor_test $TYPE # pwd: /Serving
echo "serving $TYPE part finished as expected."