diff --git a/CMakeLists.txt b/CMakeLists.txt
index 59f565014b59f1393243a892f81f2069edd6eb9e..5917129f28f2d533f993a92a06685dda87822d05 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -147,6 +147,7 @@ endif()
 
 # for lite, both server and mobile framework.
 option(WITH_LITE "Enable lite framework" OFF)
+option(WITH_JAVA "Compile PaddlePaddle with Java JNI lib" OFF)
 option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
 option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
 option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt
index 6c867084212ff4db317a86c2c63211fd6490aec0..b636d5c690184a30d12871b7001dd6d195c06865 100644
--- a/paddle/fluid/lite/CMakeLists.txt
+++ b/paddle/fluid/lite/CMakeLists.txt
@@ -110,7 +110,7 @@ file(WRITE ${__lite_cc_files} "") # clean
 # LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 # HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 function(lite_cc_library TARGET)
-  set(options STATIC static SHARED shared)
+  set(options SHARED shared STATIC static MODULE module)
   set(oneValueArgs "")
   set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS
     HVY_DEPS ARGS)
@@ -126,8 +126,12 @@ function(lite_cc_library TARGET)
     LIGHT_DEPS ${args_LIGHT_DEPS}
     HVY_DEPS ${args_HVY_DEPS}
     )
-  if (${args_SHARED} OR ${args_shared})
+
+  if (args_SHARED OR args_shared)
     cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS} SHARED)
+  elseif (args_MODULE OR args_module)
+    add_library(${TARGET} MODULE ${args_SRCS})
+    add_dependencies(${TARGET} ${deps} ${args_DEPS})
   else()
     cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
   endif()
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index 15285707c994d77574e5db7f81f6fb258987318c..d982f101759ca9a63934413954a8813525b18adf 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -125,6 +125,10 @@ if (WITH_TESTING)
     add_dependencies(test_paddle_api_lite extern_lite_download_lite_naive_model_tar_gz)
 endif()
 
+if (WITH_JAVA AND LITE_WITH_ARM)
+    add_subdirectory(android/jni)
+endif()
+
 #lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
 #X86_DEPS operator
 #DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
diff --git a/paddle/fluid/lite/api/android/jni/.gitignore b/paddle/fluid/lite/api/android/jni/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1299d2738c0d3321a46024d31e24049bef9ace9a
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/.gitignore
@@ -0,0 +1,3 @@
+/PaddleListTest.class
+/PaddleLite.class
+/bin/
diff --git a/paddle/fluid/lite/api/android/jni/CMakeLists.txt b/paddle/fluid/lite/api/android/jni/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..12c1acb00d1127fa2ecac0a1130a13363da173d2
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/CMakeLists.txt
@@ -0,0 +1,50 @@
+if ((NOT WITH_LITE) OR (NOT WITH_JAVA))
+    return()
+endif()
+
+include(UseJava)
+find_package(Java REQUIRED)
+
+# We are only interested in finding jni.h: we do not care about extended JVM
+# functionality or the AWT library.
+set(JAVA_AWT_LIBRARY NotNeeded)
+set(JAVA_JVM_LIBRARY NotNeeded)
+set(JAVA_INCLUDE_PATH2 NotNeeded)
+set(JAVA_AWT_INCLUDE_PATH NotNeeded)
+find_package(JNI REQUIRED)
+
+# Generate PaddlePredictor.jar
+include_directories(${JNI_INCLUDE_DIRS})
+add_jar(PaddlePredictor src/com/baidu/paddle/lite/PaddlePredictor.java)
+get_target_property(_jarFile PaddlePredictor JAR_FILE)
+get_target_property(_classDir PaddlePredictor CLASSDIR)
+set(_stubDir "${CMAKE_CURRENT_BINARY_DIR}")
+
+# Generate paddle_lite_jni.h
+add_custom_target(
+    paddle_lite_jni_header ALL
+    COMMAND ${Java_JAVAH_EXECUTABLE} -verbose
+            -classpath ${_classDir}
+            -o paddle_lite_jni.h
+            -jni
+            com.baidu.paddle.lite.PaddlePredictor
+    DEPENDS PaddlePredictor
+)
+
+# Generate paddle_lite_jni.so
+include_directories(${JNI_INCLUDE_DIRS} ${_classDir} ${_stubDir})
+lite_cc_library(paddle_lite_jni MODULE SRCS paddle_lite_jni.cc
+                DEPS light_api_lite cxx_api_lite
+                     paddle_api_full paddle_api_lite paddle_api_light op_registry_lite
+                     ${ops_lite} ${lite_kernel_deps}
+                ARM_DEPS ${arm_kernels})
+if (APPLE)
+    # macOS only accepts JNI libs whose names end in .jnilib or .dylib
+    set_target_properties(paddle_lite_jni PROPERTIES SUFFIX ".jnilib")
+elseif (WIN32)
+    # Windows only accepts JNI libs whose names end in .dll
+    set_target_properties(paddle_lite_jni PROPERTIES SUFFIX ".dll")
+endif (APPLE)
+target_link_libraries(paddle_lite_jni light_api_lite cxx_api_lite
+                      paddle_api_full paddle_api_lite paddle_api_light op_registry_lite
+                      ${ops_lite} ${arm_kernels} ${lite_kernel_deps})
diff --git a/paddle/fluid/lite/api/android/jni/paddle_lite_jni.cc b/paddle/fluid/lite/api/android/jni/paddle_lite_jni.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a959ffc96a46ed1a7bbd06d0bd2c3de8e0d95535
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/paddle_lite_jni.cc
@@ -0,0 +1,256 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/lite/api/android/jni/paddle_lite_jni.h"
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/lite/kernels/arm/activation_compute.h"
+#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
+#include "paddle/fluid/lite/kernels/arm/concat_compute.h"
+#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
+#include "paddle/fluid/lite/kernels/arm/dropout_compute.h"
+#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h"
+#include "paddle/fluid/lite/kernels/arm/fc_compute.h"
+#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
+#include "paddle/fluid/lite/kernels/arm/pool_compute.h"
+#include "paddle/fluid/lite/kernels/arm/scale_compute.h"
+#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
+#include "paddle/fluid/lite/kernels/arm/split_compute.h"
+#include "paddle/fluid/lite/kernels/arm/transpose_compute.h"
+
+#include "paddle/fluid/lite/api/light_api.h"
+#include "paddle/fluid/lite/api/paddle_api.h"
+#include "paddle/fluid/lite/api/paddle_lite_factory_helper.h"
+#include "paddle/fluid/lite/api/paddle_use_kernels.h"
+#include "paddle/fluid/lite/api/paddle_use_ops.h"
+#include "paddle/fluid/lite/api/paddle_use_passes.h"
+
+#define ARM_KERNEL_POINTER(kernel_class_name__)                    \
+  std::unique_ptr<paddle::lite::kernels::arm::kernel_class_name__> \
+      p##kernel_class_name__(                                      \
+          new paddle::lite::kernels::arm::kernel_class_name__);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+using paddle::lite_api::MobileConfig;
+using paddle::lite_api::PaddlePredictor;
+using paddle::lite_api::Tensor;
+
+static std::shared_ptr<PaddlePredictor> predictor;
+
+/**
+ * Not sure why, but we have to instantiate a pointer to each kernel class
+ * first. Otherwise a null-pointer error is thrown during kernel registration
+ * (KernelRegistor).
+ */
+static void use_arm_kernels() {
+  ARM_KERNEL_POINTER(BatchNormCompute);
+  ARM_KERNEL_POINTER(ConvCompute);
+  ARM_KERNEL_POINTER(ConcatCompute);
+  ARM_KERNEL_POINTER(ElementwiseAddCompute);
+  ARM_KERNEL_POINTER(DropoutCompute);
+  ARM_KERNEL_POINTER(FcCompute);
+  ARM_KERNEL_POINTER(MulCompute);
+  ARM_KERNEL_POINTER(PoolCompute);
+  ARM_KERNEL_POINTER(ReluCompute);
+  ARM_KERNEL_POINTER(ScaleCompute);
+  ARM_KERNEL_POINTER(SoftmaxCompute);
+  ARM_KERNEL_POINTER(SplitCompute);
+  ARM_KERNEL_POINTER(TransposeCompute);
+  ARM_KERNEL_POINTER(Transpose2Compute);
+}
+
+inline std::string jstring_to_cpp_string(JNIEnv *env, jstring jstr) {
+  // In Java, a unicode char is encoded with 2 bytes (UTF-16), so a jstring
+  // contains UTF-16 characters, while a C++ std::string is essentially a
+  // string of bytes, not characters. To pass a jstring from JNI to C++, we
+  // have to convert the UTF-16 data to bytes.
+  if (!jstr) {
+    return "";
+  }
+  const jclass stringClass = env->GetObjectClass(jstr);
+  const jmethodID getBytes =
+      env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B");
+  const jbyteArray stringJbytes = (jbyteArray)env->CallObjectMethod(
+      jstr, getBytes, env->NewStringUTF("UTF-8"));
+
+  size_t length = (size_t)env->GetArrayLength(stringJbytes);
+  jbyte *pBytes = env->GetByteArrayElements(stringJbytes, NULL);
+
+  std::string ret = std::string(reinterpret_cast<char *>(pBytes), length);
+  env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT);
+
+  env->DeleteLocalRef(stringJbytes);
+  env->DeleteLocalRef(stringClass);
+  return ret;
+}
+
+inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf,
+                                            int64_t len) {
+  jfloatArray result = env->NewFloatArray(len);
+  env->SetFloatArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline jintArray cpp_array_to_jintarray(JNIEnv *env, const int *buf,
+                                        int64_t len) {
+  jintArray result = env->NewIntArray(len);
+  env->SetIntArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, const int8_t *buf,
+                                          int64_t len) {
+  jbyteArray result = env->NewByteArray(len);
+  env->SetByteArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline std::vector<int64_t> jintarray_to_int64_vector(JNIEnv *env,
+                                                      jintArray dims) {
+  int dim_size = env->GetArrayLength(dims);
+  jint *dim_nums = env->GetIntArrayElements(dims, nullptr);
+  std::vector<int64_t> dim_vec(dim_nums, dim_nums + dim_size);
+  env->ReleaseIntArrayElements(dims, dim_nums, 0);
+  return dim_vec;
+}
+
+inline static int64_t product(const std::vector<int64_t> &vec) {
+  if (vec.empty()) {
+    return 0;
+  }
+  int64_t result = 1;
+  for (int64_t d : vec) {
+    result *= d;
+  }
+  return result;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_loadMobileModel(JNIEnv *env,
+                                                           jclass thiz,
+                                                           jstring model_path) {
+  if (predictor != nullptr) {
+    return JNI_FALSE;
+  }
+  use_arm_kernels();
+  MobileConfig config;
+  std::string model_dir = jstring_to_cpp_string(env, model_path);
+  config.set_model_dir(model_dir);
+  predictor = paddle::lite_api::CreatePaddlePredictor<MobileConfig>(config);
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_clear(JNIEnv *env, jclass thiz) {
+  if (predictor == nullptr) {
+    return JNI_FALSE;
+  }
+  predictor.reset();
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3F(
+    JNIEnv *env, jclass thiz, jint offset, jintArray dims, jfloatArray buf) {
+  std::vector<int64_t> ddim = jintarray_to_int64_vector(env, dims);
+
+  int len = env->GetArrayLength(buf);
+  if ((int64_t)len != product(ddim)) {
+    return JNI_FALSE;
+  }
+
+  float *buffer = env->GetFloatArrayElements(buf, nullptr);
+  std::unique_ptr<Tensor> tensor =
+      predictor->GetInput(static_cast<int>(offset));
+  tensor->Resize(ddim);
+  float *input = tensor->mutable_data<float>();
+  for (int i = 0; i < len; ++i) {
+    input[i] = buffer[i];
+  }
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3B(
+    JNIEnv *env, jclass thiz, jint offset, jintArray dims, jbyteArray buf) {
+  std::vector<int64_t> ddim = jintarray_to_int64_vector(env, dims);
+
+  int len = env->GetArrayLength(buf);
+  if ((int64_t)len != product(ddim)) {
+    return JNI_FALSE;
+  }
+
+  jbyte *buffer = env->GetByteArrayElements(buf, nullptr);
+  std::unique_ptr<Tensor> tensor =
+      predictor->GetInput(static_cast<int>(offset));
+  tensor->Resize(ddim);
+  int8_t *input = tensor->mutable_data<int8_t>();
+  for (int i = 0; i < len; ++i) {
+    input[i] = (int8_t)buffer[i];
+  }
+
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_run(JNIEnv *, jclass) {
+  predictor->Run();
+  return JNI_TRUE;
+}
+
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getFloatOutput(JNIEnv *env,
+                                                          jclass thiz,
+                                                          jint offset) {
+  std::unique_ptr<const Tensor> tensor =
+      predictor->GetOutput(static_cast<int>(offset));
+  int64_t len = product(tensor->shape());
+  return cpp_array_to_jfloatarray(env, tensor->data<float>(), len);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getByteOutput(JNIEnv *env,
+                                                         jclass thiz,
+                                                         jint offset) {
+  std::unique_ptr<const Tensor> tensor =
+      predictor->GetOutput(static_cast<int>(offset));
+  int64_t len = product(tensor->shape());
+  return cpp_array_to_jbytearray(env, tensor->data<int8_t>(), len);
+}
+
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_fetchFloat(JNIEnv *env, jclass thiz,
+                                                      jstring name) {
+  std::string cpp_name = jstring_to_cpp_string(env, name);
+  std::unique_ptr<const Tensor> tensor = predictor->GetTensor(cpp_name);
+  int64_t len = product(tensor->shape());
+  return cpp_array_to_jfloatarray(env, tensor->data<float>(), len);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_fetchByte(JNIEnv *env, jclass thiz,
+                                                     jstring name) {
+  std::string cpp_name = jstring_to_cpp_string(env, name);
+  std::unique_ptr<const Tensor> tensor = predictor->GetTensor(cpp_name);
+  int64_t len = product(tensor->shape());
+  return cpp_array_to_jbytearray(env, tensor->data<int8_t>(), len);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/paddle/fluid/lite/api/android/jni/paddle_lite_jni.h b/paddle/fluid/lite/api/android/jni/paddle_lite_jni.h
new file mode 100644
index 0000000000000000000000000000000000000000..7748d07a9c727f90424f8feb11402a94861cddea
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/paddle_lite_jni.h
@@ -0,0 +1,107 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class com_baidu_paddle_lite_PaddlePredictor */
+
+#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_PADDLE_LITE_JNI_H_
+#define PADDLE_FLUID_LITE_API_ANDROID_JNI_PADDLE_LITE_JNI_H_
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    loadMobileModel
+ * Signature: (Ljava/lang/String;)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_loadMobileModel(JNIEnv *, jclass,
+                                                           jstring);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    clear
+ * Signature: ()Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_clear(JNIEnv *, jclass);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    setInput
+ * Signature: (I[I[F)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3F(JNIEnv *, jclass,
+                                                             jint, jintArray,
+                                                             jfloatArray);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    setInput
+ * Signature: (I[I[B)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3B(JNIEnv *, jclass,
+                                                             jint, jintArray,
+                                                             jbyteArray);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    run
+ * Signature: ()Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_run(JNIEnv *, jclass);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    getFloatOutput
+ * Signature: (I)[F
+ */
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getFloatOutput(JNIEnv *, jclass,
+                                                          jint);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    getByteOutput
+ * Signature: (I)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getByteOutput(JNIEnv *, jclass,
+                                                         jint);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    fetchFloat
+ * Signature: (Ljava/lang/String;)[F
+ */
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_fetchFloat(JNIEnv *, jclass,
+                                                      jstring);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    fetchByte
+ * Signature: (Ljava/lang/String;)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_fetchByte(JNIEnv *, jclass, jstring);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // PADDLE_FLUID_LITE_API_ANDROID_JNI_PADDLE_LITE_JNI_H_
diff --git a/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore b/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..870ec275e827c663c24ab374bbec8c37c8f3d8b0
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore
@@ -0,0 +1,2 @@
+/PaddleLite.class
+/PaddleLiteTest.class
diff --git a/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java b/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
new file mode 100644
index 0000000000000000000000000000000000000000..d0adc6d6ec325c092cb97ea89ba1861d66e207b8
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
@@ -0,0 +1,107 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/** Java Native Interface (JNI) class for Paddle Lite APIs */
+public class PaddlePredictor {
+
+    /** Name of the C++ JNI lib */
+    private static final String JNI_LIB_NAME = "paddle_lite_jni";
+
+    /* Load the C++ JNI lib */
+    static {
+        System.loadLibrary(JNI_LIB_NAME);
+    }
+
+    /**
+     * Loads a mobile lite model, i.e. a model produced by the optimizing passes.
+     *
+     * @param modelPath model file path
+     * @return true if the model was loaded successfully
+     */
+    public static native boolean loadMobileModel(String modelPath);
+
+    /**
+     * Clears the currently loaded model.
+     *
+     * @return true if a loaded model has been cleared
+     */
+    public static native boolean clear();
+
+    /**
+     * Sets input data on the offset-th column of feed data.
+     *
+     * @param offset the offset-th column of feed data to set
+     * @param dims   dimensions of the input data
+     * @param buf    the input data
+     * @return true if the input was set successfully
+     */
+    public static native boolean setInput(int offset, int[] dims, float[] buf);
+
+    /**
+     * Sets input data on the offset-th column of feed data.
+     *
+     * @param offset the offset-th column of feed data to set
+     * @param dims   dimensions of the input data
+     * @param buf    the input data
+     * @return true if the input was set successfully
+     */
+    public static native boolean setInput(int offset, int[] dims, byte[] buf);
+
+    /**
+     * Runs the predict model.
+     *
+     * @return true if the model ran successfully
+     */
+    public static native boolean run();
+
+    /**
+     * Gets the offset-th column of output data as floats.
+     *
+     * @param offset the offset-th column of output data to return
+     * @return model predict output
+     */
+    public static native float[] getFloatOutput(int offset);
+
+    /**
+     * Gets the offset-th column of output data as bytes (int8 on the C++ side).
+     *
+     * @param offset the offset-th column of output data to return
+     * @return model predict output
+     */
+    public static native byte[] getByteOutput(int offset);
+
+    /**
+     * Fetches a tensor's values as float data.
+     *
+     * @param name the tensor's name
+     * @return values of the tensor
+     */
+    public static native float[] fetchFloat(String name);
+
+    /**
+     * Fetches a tensor's values as byte data (int8 on the C++ side).
+     *
+     * @param name the tensor's name
+     * @return values of the tensor
+     */
+    public static native byte[] fetchByte(String name);
+
+    /**
+     * Main function for test
+     */
+    public static void main(String[] args) {
+        System.out.println("Load native library successfully");
+    }
+}
diff --git a/paddle/fluid/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java b/paddle/fluid/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..f2cf21ef000ca1012be7509c8286634bd1aa5acc
--- /dev/null
+++ b/paddle/fluid/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java
@@ -0,0 +1,42 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class PaddlePredictorTest {
+
+    @Test
+    public void run_defaultModel() {
+        PaddlePredictor.loadMobileModel("");
+
+        float[] inputBuffer = new float[10000];
+        for (int i = 0; i < 10000; ++i) {
+            inputBuffer[i] = i;
+        }
+        int[] dims = { 100, 100 };
+
+        PaddlePredictor.setInput(0, dims, inputBuffer);
+        PaddlePredictor.run();
+        float[] output = PaddlePredictor.getFloatOutput(0);
+
+        assertEquals(50.2132f, output[0], 1e-3f);
+        assertEquals(-28.8729f, output[1], 1e-3f);
+
+        PaddlePredictor.clear();
+    }
+
+}
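For reference, below is a minimal end-to-end sketch of driving the new Java binding, mirroring the flow exercised by PaddlePredictorTest (load, set input, run, fetch output, clear). It assumes a build configured with -DWITH_LITE=ON -DWITH_JAVA=ON -DLITE_WITH_ARM=ON, that libpaddle_lite_jni is on java.library.path, and the class name and model directory are placeholders; the model must be one produced by the lite optimizing passes.

```java
import com.baidu.paddle.lite.PaddlePredictor;

public class PaddleLiteDemo {
    public static void main(String[] args) {
        // Hypothetical model directory; replace with a real optimized model.
        if (!PaddlePredictor.loadMobileModel("/path/to/optimized_model")) {
            System.err.println("Model load failed (or a model is already loaded).");
            return;
        }

        // Feed a 100x100 float input into column 0 of the feed data.
        int[] dims = { 100, 100 };
        float[] input = new float[100 * 100];
        for (int i = 0; i < input.length; ++i) {
            input[i] = i;
        }
        PaddlePredictor.setInput(0, dims, input);

        PaddlePredictor.run();

        // Read column 0 of the output data back as floats.
        float[] output = PaddlePredictor.getFloatOutput(0);
        System.out.println("output[0] = " + output[0]);

        PaddlePredictor.clear();
    }
}
```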