提交 0a198e0a 编写于 作者: H Huihuang Zheng

Add CxxConfig and Related Java API and Improve CMakeLists

上级 c152fe93
......@@ -147,7 +147,7 @@ endif()
# for lite, both server and mobile framework.
option(WITH_LITE "Enable lite framework" OFF)
option(WITH_JAVA "Compile PaddlePaddle with Java JNI lib" OFF)
option(LITE_WITH_JAVA "Enable Java JNI lib in lite mode" OFF)
option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
......
......@@ -131,8 +131,8 @@ if (WITH_TESTING)
add_dependencies(test_paddle_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif()
if (WITH_JAVA AND LITE_WITH_ARM)
add_subdirectory(android/jni)
if (LITE_WITH_JAVA AND LITE_WITH_ARM)
add_subdirectory(android)
endif()
#lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
......
if ((NOT LITE_WITH_JAVA) OR (NOT LITE_WITH_ARM))
return()
endif()
add_subdirectory(jni)
if ((NOT WITH_LITE) OR (NOT WITH_JAVA))
if ((NOT LITE_WITH_ARM) OR (NOT LITE_WITH_JAVA))
return()
endif()
......@@ -15,7 +15,9 @@ find_package(JNI REQUIRED)
# Generate PaddlePredictor.jar
include_directories(${JNI_INCLUDE_DIRS})
add_jar(PaddlePredictor src/com/baidu/paddle/lite/PaddlePredictor.java)
add_jar(PaddlePredictor
src/com/baidu/paddle/lite/PaddlePredictor.java
src/com/baidu/paddle/lite/Place.java)
get_target_property(_jarFile PaddlePredictor JAR_FILE)
get_target_property(_classDir PaddlePredictor CLASSDIR)
set(_stubDir "${CMAKE_CURRENT_BINARY_DIR}")
......
......@@ -12,13 +12,22 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/lite/api/android/jni/paddle_lite_jni.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/api/light_api.h"
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/api/paddle_lite_factory_helper.h"
#include "paddle/fluid/lite/api/paddle_place.h"
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/kernels/arm/activation_compute.h"
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include "paddle/fluid/lite/kernels/arm/calib_compute.h"
#include "paddle/fluid/lite/kernels/arm/concat_compute.h"
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include "paddle/fluid/lite/kernels/arm/dropout_compute.h"
......@@ -31,13 +40,6 @@ limitations under the License. */
#include "paddle/fluid/lite/kernels/arm/split_compute.h"
#include "paddle/fluid/lite/kernels/arm/transpose_compute.h"
#include "paddle/fluid/lite/api/light_api.h"
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/api/paddle_lite_factory_helper.h"
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#define ARM_KERNEL_POINTER(kernel_class_name__) \
std::unique_ptr<paddle::lite::kernels::arm::kernel_class_name__> \
p##kernel_class_name__( \
......@@ -47,8 +49,10 @@ limitations under the License. */
extern "C" {
#endif
using paddle::lite_api::CxxConfig;
using paddle::lite_api::MobileConfig;
using paddle::lite_api::PaddlePredictor;
using paddle::lite_api::Place;
using paddle::lite_api::Tensor;
static std::shared_ptr<PaddlePredictor> predictor;
......@@ -59,6 +63,8 @@ static std::shared_ptr<PaddlePredictor> predictor;
*/
static void use_arm_kernels() {
ARM_KERNEL_POINTER(BatchNormCompute);
ARM_KERNEL_POINTER(CalibComputeFp32ToInt8);
ARM_KERNEL_POINTER(CalibComputeInt8ToFp32);
ARM_KERNEL_POINTER(ConvCompute);
ARM_KERNEL_POINTER(ConcatCompute);
ARM_KERNEL_POINTER(ElementwiseAddCompute);
......@@ -129,6 +135,31 @@ inline std::vector<int64_t> jintarray_to_int64_vector(JNIEnv *env,
return dim_vec;
}
/**
* Converts Java com.baidu.paddle.lite.Place to c++ paddle::lite_api::Place.
*/
inline static Place jplace_to_cpp_place(JNIEnv *env, jobject java_place) {
jclass place_jclazz = env->GetObjectClass(java_place);
jmethodID target_method =
env->GetMethodID(place_jclazz, "getTargetInt", "()I");
jmethodID precision_method =
env->GetMethodID(place_jclazz, "getPrecisionInt", "()I");
jmethodID data_layout_method =
env->GetMethodID(place_jclazz, "getDataLayoutInt", "()I");
jmethodID device_method = env->GetMethodID(place_jclazz, "getDevice", "()I");
int target = env->CallIntMethod(java_place, target_method);
int precision = env->CallIntMethod(java_place, precision_method);
int data_layout = env->CallIntMethod(java_place, data_layout_method);
int device = env->CallIntMethod(java_place, device_method);
return Place(static_cast<paddle::lite_api::TargetType>(target),
static_cast<paddle::lite_api::PrecisionType>(precision),
static_cast<paddle::lite_api::DataLayoutType>(data_layout),
device);
}
inline static int64_t product(const std::vector<int64_t> &vec) {
if (vec.empty()) {
return 0;
......@@ -140,6 +171,31 @@ inline static int64_t product(const std::vector<int64_t> &vec) {
return result;
}
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_loadCxxModel(
JNIEnv *env, jclass thiz, jstring model_path, jobject preferred_place,
jobjectArray valid_places) {
if (predictor != nullptr) {
return JNI_FALSE;
}
use_arm_kernels();
int valid_place_count = env->GetArrayLength(valid_places);
std::vector<Place> cpp_valid_places;
for (int i = 0; i < valid_place_count; ++i) {
jobject jplace = env->GetObjectArrayElement(valid_places, i);
cpp_valid_places.push_back(jplace_to_cpp_place(env, jplace));
}
CxxConfig config;
config.set_model_dir(jstring_to_cpp_string(env, model_path));
config.set_preferred_place(jplace_to_cpp_place(env, preferred_place));
config.set_valid_places(cpp_valid_places);
predictor = paddle::lite_api::CreatePaddlePredictor(config);
return predictor == nullptr ? JNI_FALSE : JNI_TRUE;
}
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_loadMobileModel(JNIEnv *env,
jclass thiz,
......@@ -149,9 +205,19 @@ Java_com_baidu_paddle_lite_PaddlePredictor_loadMobileModel(JNIEnv *env,
}
use_arm_kernels();
MobileConfig config;
std::string model_dir = jstring_to_cpp_string(env, model_path);
config.set_model_dir(model_dir);
config.set_model_dir(jstring_to_cpp_string(env, model_path));
predictor = paddle::lite_api::CreatePaddlePredictor(config);
return predictor == nullptr ? JNI_FALSE : JNI_TRUE;
}
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_saveOptimizedModel(
JNIEnv *env, jclass thiz, jstring model_path) {
if (predictor == nullptr) {
return JNI_FALSE;
}
predictor->SaveOptimizedModel(jstring_to_cpp_string(env, model_path));
return JNI_TRUE;
}
......@@ -167,6 +233,9 @@ Java_com_baidu_paddle_lite_PaddlePredictor_clear(JNIEnv *env, jclass thiz) {
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3F(
JNIEnv *env, jclass thiz, jint offset, jintArray dims, jfloatArray buf) {
if (predictor == nullptr) {
return JNI_FALSE;
}
std::vector<int64_t> ddim = jintarray_to_int64_vector(env, dims);
int len = env->GetArrayLength(buf);
......@@ -188,6 +257,9 @@ Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3F(
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3B(
JNIEnv *env, jclass thiz, jint offset, jintArray dims, jbyteArray buf) {
if (predictor == nullptr) {
return JNI_FALSE;
}
std::vector<int64_t> ddim = jintarray_to_int64_vector(env, dims);
int len = env->GetArrayLength(buf);
......@@ -209,6 +281,9 @@ Java_com_baidu_paddle_lite_PaddlePredictor_setInput__I_3I_3B(
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_run(JNIEnv *, jclass) {
if (predictor == nullptr) {
return JNI_FALSE;
}
predictor->Run();
return JNI_TRUE;
}
......
......@@ -21,6 +21,17 @@
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: com_baidu_paddle_lite_PaddlePredictor
* Method: loadCxxModel
* Signature:
* (Ljava/lang/String;Lcom/baidu/paddle/lite/Place;[Lcom/baidu/paddle/lite/Place;)Z
*/
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_loadCxxModel(JNIEnv *, jclass,
jstring, jobject,
jobjectArray);
/*
* Class: com_baidu_paddle_lite_PaddlePredictor
* Method: loadMobileModel
......@@ -30,6 +41,15 @@ JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_loadMobileModel(JNIEnv *, jclass,
jstring);
/*
* Class: com_baidu_paddle_lite_PaddlePredictor
* Method: saveOptimizedModel
* Signature: (Ljava/lang/String;)Z
*/
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_saveOptimizedModel(JNIEnv *, jclass,
jstring);
/*
* Class: com_baidu_paddle_lite_PaddlePredictor
* Method: clear
......
......@@ -24,6 +24,20 @@ public class PaddlePredictor {
System.loadLibrary(JNI_LIB_NAME);
}
/**
* Loads mobile cxx model, which is the model before optimizing passes. The cxx
* model allow users to manage hardware place resources. Caller uses a place at
* Java to control Target, DataLayout, Precision, and Device ID. More details
* about the four fields see our Paddle-Mobile document.
*
*
* @param modelPath modelPath model file path
* @param preferredPlace preferred place to run Cxx Model
* @param validPlaces n * 4 int array, valid places to run Cxx Model
* @return true if load successfully
*/
public static native boolean loadCxxModel(String modelPath, Place preferredPlace, Place[] validPlaces);
/**
* Loads mobile lite model, which is the model after optimizing passes.
*
......@@ -32,6 +46,15 @@ public class PaddlePredictor {
*/
public static native boolean loadMobileModel(String modelPath);
/**
* Saves optimized model, which is the model can be used by
* {@link loadMobileModel}
*
* @param modelPath model file path
* @return true if save successfully
*/
public static native boolean saveOptimizedModel(String modelPath);
/**
* Clears the current loaded model.
*
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
package com.baidu.paddle.lite;
/**
* Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate.
*/
public class Place {
public enum TargetType {
UNKNOWN(0), HOST(1), X86(2), CUDA(3), ARM(4), OPEN_CL(5), ANY(6);
public final int value;
private TargetType(int value) {
this.value = value;
}
}
public enum PrecisionType {
UNKNOWN(0), FLOAT(1), INT8(2), INT32(3), ANY(4);
public final int value;
private PrecisionType(int value) {
this.value = value;
}
}
public enum DataLayoutType {
UNKNOWN(0), NCHW(1), ANY(2);
public final int value;
private DataLayoutType(int value) {
this.value = value;
}
}
public TargetType target;
public PrecisionType precision;
public DataLayoutType layout;
public int device;
public Place() {
target = TargetType.UNKNOWN;
precision = PrecisionType.UNKNOWN;
layout = DataLayoutType.UNKNOWN;
device = 0;
}
public Place(TargetType target) {
this(target, PrecisionType.FLOAT);
}
public Place(TargetType target, PrecisionType precision) {
this(target, precision, DataLayoutType.NCHW);
}
public Place(TargetType target, PrecisionType precision, DataLayoutType layout) {
this(target, precision, layout, 0);
}
public Place(TargetType target, PrecisionType precision, DataLayoutType layout, int device) {
this.target = target;
this.precision = precision;
this.layout = layout;
this.device = device;
}
public boolean isValid() {
return target != TargetType.UNKNOWN && precision != PrecisionType.UNKNOWN && layout != DataLayoutType.UNKNOWN;
}
public int getTargetInt() {
return target.value;
}
public int getPrecisionInt() {
return precision.value;
}
public int getDataLayoutInt() {
return layout.value;
}
public int getDevice() {
return device;
}
}
......@@ -33,6 +33,7 @@ class PaddlePredictorTest {
PaddlePredictor.run();
float[] output = PaddlePredictor.getFloatOutput(0);
assertEquals(output.length, 50000);
assertEquals(output[0], 50.2132f, 1e-3f);
assertEquals(output[1], -28.8729f, 1e-3f);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册