diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..298ea9e213e8c4c11f0431077510d4e325733c65 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 264f5633f683141fb4d5b4fae1537cfaf9e94044..c44c4e171d128c1469343cd01a91cc1d12762b8a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,8 @@ Welcome to Paddle-Lite's documentation! :caption: 使用指南 :name: sec-user-guides + user_guides/cuda + .. toctree:: :maxdepth: 1 :caption: 进阶使用指南 diff --git a/docs/user_guides/cuda.md b/docs/user_guides/cuda.md new file mode 100644 index 0000000000000000000000000000000000000000..45597057bb18c44b60234459f9a49a59b54135f6 --- /dev/null +++ b/docs/user_guides/cuda.md @@ -0,0 +1,110 @@ +# Lite基于CUDA的模型预测 + +Lite支持在x86_64,arm64架构上(如:TX2)进行CUDA的编译运行。 + +## 编译 + +**NOTE:** 如果是在TX2等NVIDIA嵌入式硬件上编译,请使用最新的[Jetpack](https://developer.nvidia.com/embedded/jetpack) 安装依赖库。 + + +一: 下载代码 + +``` +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +``` + +二:编译 + +``` +# 进入代码目录 +cd Paddle-Lite + +# 运行编译脚本 +# 编译结束会在本目录下生成 build_cuda 目录 +# 编译过程中如果提示找不到CUDA,CUDNN,请在环境变量设置CUDA_TOOLKIT_ROOT_DIR, CUDNN_ROOT +# CUDA_TOOLKIT_ROOT_DIR,CUDNN_ROOT分别表示CUDA,CUDNN的根目录 +./lite/tools/build.sh cuda +# 如果使用python接口,需要打开build_python选项 +./lite/tools/build.sh --build_python=ON cuda +``` + +编译结束会在 `build_cuda/inference_lite_lib/python/lib/` 目录下生成 `lite_core.so`。 + +## 运行 + +以下以Yolov3模型为例,介绍如何在Nvidia GPU硬件上运行模型。 + +一: 下载darknet_yolov3模型,模型信息请参考[这里](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/yolov3) + + +``` +# 下载模型 +wget https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/yolov3_infer.tar.gz +tar -zxf yolov3_infer.tar.gz +# 下载图片样例 +wget https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/kite.jpg +``` + +二: 运行 + +**NOTE:**此处示例使用的是python接口,后续会开放C++接口以及示例。 + +``` python +#-*- coding: utf-8 -*- +from __future__ import print_function +import sys +import numpy as np +import cv2 +sys.path.append('build_cuda/inference_lite_lib/python/lib') +from lite_core import * + +def read_img(im_path, resize_h, resize_w): + im = cv2.imread(im_path).astype('float32') + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + h, w, _ = im.shape + im_scale_x = resize_h / float(w) + im_scale_y = resize_w / float(h) + out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC) + mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1)) + std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1)) + out_img = (out_img / 255.0 - mean) / std + out_img = out_img.transpose((2, 0, 1)) + return out_img + +# 配置config +a = CxxConfig() +a.set_model_file('./yolov3_infer/__model__') # 指定模型文件路径 +a.set_param_file('./yolov3_infer/__params__') # 指定参数文件路径 +place_cuda = Place(TargetType.CUDA) +a.set_valid_places([place_cuda]) + +# 创建predictor +predictor = create_paddle_predictor(a) + +# 设置输入 +input_tensor = 
predictor.get_input(0); +height, width = 608, 608 +input_tensor.resize([1, 3, height, width]) +data = read_img('./kite.jpg', height, width).flatten() +input_tensor.set_float_data(data, TargetType.CUDA) + +in2 = predictor.get_input(1); +in2.resize([1, 2]) +in2.set_int32_data([height, width], TargetType.CUDA) + +# 运行 +predictor.run() + +# 获取输出 +output_tensor = predictor.get_output(0); + +print (output_tensor.shape()) +# [100L, 6L] +print (output_tensor.target()) +# TargetType.Host +print (output_tensor.float_data()[:6]) +# [0.0, 0.9862784743309021, 98.51927185058594, 471.2381286621094, 120.73092651367188, 578.33251953125] + +``` + +**NOTE:** 对CUDA的支持还在持续开发中。 diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index cb6a872e061a51f142bd2301171f0559a1ccb129..bac6f80c4721e0c5de201eebfe7e6a39a0bdc73a 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -232,6 +232,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_classify/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_classify/Makefile" COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_cv" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_cv/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_cv/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mask_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile" ) add_dependencies(publish_inference_android_cxx_demos logging gflags) add_dependencies(publish_inference_cxx_lib publish_inference_android_cxx_demos) @@ -251,6 +253,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_classify/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_classify/Makefile" COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_cv" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_cv/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_cv/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mask_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile" ) add_dependencies(tiny_publish_cxx_lib publish_inference_android_cxx_demos) endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt old mode 100644 new mode 100755 index d203b765180f5ffd8e74f6c7f3bfb330ee23ffa4..70f483822ac484576fe6934c0a30e85593e1e93a --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -1,4 +1,4 @@ -if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) lite_cc_library(place SRCS paddle_place.cc DEPS logging) else() lite_cc_library(place SRCS paddle_place.cc DEPS glog) @@ -218,20 +218,11 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) --model_dir=${LITE_MODEL_DIR}/resnet50 SERIAL) add_dependencies(test_resnet50 extern_lite_download_resnet50_tar_gz) - lite_cc_test(test_resnet50_fpga SRCS resnet50_test_fpga.cc + lite_cc_test(test_ssd_fpga SRCS test_ssd_fpga.cc DEPS ${lite_model_test_DEPS} CL_DEPS ${opencl_kernels} 
FPGA_DEPS ${fpga_kernels}) - lite_cc_test(test_inceptionv4 SRCS inceptionv4_test.cc - DEPS ${lite_model_test_DEPS} - CL_DEPS ${opencl_kernels} - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl - --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL) - add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz) - # lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc - # DEPS ${lite_model_test_DEPS}) - # lite_cc_test(model_run_test_image SRCS model_run_test_image.cc # DEPS ${lite_model_test_DEPS} # CL_DEPS ${opencl_kernels} @@ -296,10 +287,10 @@ if (LITE_ON_TINY_PUBLISH) endif() if (LITE_ON_MODEL_OPTIMIZE_TOOL) - message(STATUS "Compiling model_optimize_tool") - lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc + message(STATUS "Compiling opt") + lite_cc_binary(opt SRCS opt.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc DEPS gflags kernel op optimizer mir_passes utils) - add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h) + add_dependencies(opt op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h) endif(LITE_ON_MODEL_OPTIMIZE_TOOL) lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light diff --git a/lite/api/android/jni/native/convert_util_jni.h b/lite/api/android/jni/native/convert_util_jni.h index 5e5d3723e43eb311f64b85f7507a12497d724109..e4adafdc572fdc937f568508aa9d43eb78470d0d 100644 --- a/lite/api/android/jni/native/convert_util_jni.h +++ b/lite/api/android/jni/native/convert_util_jni.h @@ -181,6 +181,7 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env, MobileConfig config; // set model dir + // NOTE: This is a deprecated API and will be removed in latter release. jmethodID model_dir_method = env->GetMethodID( mobileconfig_jclazz, "getModelDir", "()Ljava/lang/String;"); jstring java_model_dir = @@ -190,6 +191,27 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env, config.set_model_dir(cpp_model_dir); } + // set model from file + jmethodID model_file_method = env->GetMethodID( + mobileconfig_jclazz, "getModelFromFile", "()Ljava/lang/String;"); + jstring java_model_file = + (jstring)env->CallObjectMethod(jmobileconfig, model_file_method); + if (java_model_file != nullptr) { + std::string cpp_model_file = jstring_to_cpp_string(env, java_model_file); + config.set_model_from_file(cpp_model_file); + } + + // set model from buffer + jmethodID model_buffer_method = env->GetMethodID( + mobileconfig_jclazz, "getModelFromBuffer", "()Ljava/lang/String;"); + jstring java_model_buffer = + (jstring)env->CallObjectMethod(jmobileconfig, model_buffer_method); + if (java_model_buffer != nullptr) { + std::string cpp_model_buffer = + jstring_to_cpp_string(env, java_model_buffer); + config.set_model_from_buffer(cpp_model_buffer); + } + // set threads jmethodID threads_method = env->GetMethodID(mobileconfig_jclazz, "getThreads", "()I"); diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java index 5c71db0c92b344e44ea2927305580de1be293f75..e150f98f22113ef6bcedd5e9882e0bd2a6378c97 100644 --- a/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java +++ b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java @@ -64,6 +64,44 @@ public class MobileConfig extends ConfigBase { return powerMode.value(); } + /** + * Set model from file. 
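+     * A minimal usage sketch (the model path below is a placeholder):
+     *   MobileConfig config = new MobileConfig();
+     *   config.setModelFromFile("mobilenet_v1.nb");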
+ * + * @return + */ + public void setModelFromFile(String modelFile) { + this.liteModelFile = modelFile; + } + + /** + * Returns name of model_file. + * + * @return liteModelFile + */ + public String getModelFile() { + return liteModelFile; + } + + /** + * Set model from buffer. + * + * @return + */ + public void setModelFromBuffer(String modelBuffer) { + this.liteModelBuffer = modelBuffer; + } + + /** + * Returns model buffer + * + * @return liteModelBuffer + */ + public String getModelBuffer() { + return liteModelBuffer; + } + private PowerMode powerMode = PowerMode.LITE_POWER_HIGH; private int threads = 1; + private String liteModelFile; + private String liteModelBuffer; } diff --git a/lite/api/apis_test.cc b/lite/api/apis_test.cc index ac2c385d53ea0a1785393cd488d115d20c4264f1..bb852297d11a8862460ed6f12e007d727aca9428 100644 --- a/lite/api/apis_test.cc +++ b/lite/api/apis_test.cc @@ -62,7 +62,7 @@ TEST(CXXApi_LightApi, optim_model) { TEST(CXXApi_LightApi, save_and_load_model) { lite::Predictor cxx_api; - lite::LightPredictor light_api(FLAGS_optimized_model); + lite::LightPredictor light_api(FLAGS_optimized_model + ".nb", false); // CXXAPi { diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index f0cb6841d5b73ea600b9e2b7e2f055192811b6c3..718dbe44296f2d197efc5b567cf0cc211835d176 100644 --- a/lite/api/benchmark.cc +++ b/lite/api/benchmark.cc @@ -116,7 +116,7 @@ void Run(const std::vector>& input_shapes, lite_api::MobileConfig config; config.set_threads(FLAGS_threads); config.set_power_mode(static_cast(FLAGS_power_mode)); - config.set_model_dir(model_dir); + config.set_model_from_file(model_dir + ".nb"); auto predictor = lite_api::CreatePaddlePredictor(config); diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index c1e9fc422450adf96d62c68d622907bd7e15b405..9c0e8e1c343b8eb1705e871aa652e3254474391d 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -121,6 +121,7 @@ void Predictor::SaveOpKernelInfo(const std::string &model_dir) { << kpf_path; } +#ifndef LITE_WITH_FPGA lite::Tensor *Predictor::GetInput(size_t offset) { CHECK(input_names_.size() > offset) << "The network has " << input_names_.size() << " inputs" @@ -130,6 +131,17 @@ lite::Tensor *Predictor::GetInput(size_t offset) { << " in exec_scope"; return in_var->GetMutable(); } +#else +lite::Tensor *Predictor::GetInput(size_t offset) { + auto *_feed_list = exec_scope_->FindVar("feed"); + CHECK(_feed_list) << "no feed variable in exec_scope"; + auto *feed_list = _feed_list->GetMutable>(); + if (offset >= feed_list->size()) { + feed_list->resize(offset + 1); + } + return &feed_list->at(offset); +} +#endif // get inputs names std::vector Predictor::GetInputNames() { return input_names_; } @@ -167,6 +179,8 @@ void Predictor::PrepareFeedFetch() { } } +#ifndef LITE_WITH_FPGA + const lite::Tensor *Predictor::GetOutput(size_t offset) const { CHECK(output_names_.size() > offset) << "The network has " << output_names_.size() << " outputs" @@ -186,6 +200,29 @@ std::vector Predictor::GetOutputs() const { } return outputs; } +#else + +const lite::Tensor *Predictor::GetOutput(size_t offset) const { + auto *_fetch_list = exec_scope_->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + auto &fetch_list = *_fetch_list->GetMutable>(); + CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; + return &fetch_list.at(offset); +} + +std::vector Predictor::GetOutputs() const { + auto *_fetch_list = exec_scope_->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + 
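+  // FPGA builds read outputs back directly from the "fetch" variable list held in exec_scope_.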
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); + + std::vector<const lite::Tensor *> outputs; + for (auto &out : fetch_list) { + outputs.push_back(&out); + } + return outputs; +} + +#endif const cpp::ProgramDesc &Predictor::program_desc() const { return program_desc_; @@ -239,7 +276,7 @@ void Predictor::Build(const std::string &model_path, case lite_api::LiteModelType::kNaiveBuffer: CHECK(!model_path.empty()) << "NaiveBuffer backend only supported combined param"; - LoadModelNaive(model_path, scope_.get(), &program_desc_); + LoadModelNaiveFromFile(model_path, scope_.get(), &program_desc_); break; default: LOG(FATAL) << "Unknown model type"; diff --git a/lite/api/cxx_api_test.cc b/lite/api/cxx_api_test.cc index 4d711302cb5880247f4a7b7082185c500b9ad6e9..cdf1e838366f4bcafc1c1c991d8805f115de7345 100644 --- a/lite/api/cxx_api_test.cc +++ b/lite/api/cxx_api_test.cc @@ -101,7 +101,7 @@ TEST(CXXApi, save_model) { TEST(CXXApi, load_model_naive) { lite::Predictor predictor; std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(FLAGS_optimized_model + ".naive", + predictor.Build(FLAGS_optimized_model + ".naive.nb", "", "", valid_places, diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index 1558e286178b461dc04c4366dc3adca81b2dd9de..29d8f4f29ab822f8c9601bbd63a3626abbbf1818 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -18,6 +18,17 @@ namespace paddle { namespace lite { +void LightPredictor::Build(const std::string& lite_model_file, + bool model_from_memory) { + if (model_from_memory) { + LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_); + } else { + LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_); + } + BuildRuntimeProgram(cpp_program_desc_); + PrepareFeedFetch(); +} + void LightPredictor::Build(const std::string& model_dir, const std::string& model_buffer, const std::string& param_buffer, diff --git a/lite/api/light_api.h b/lite/api/light_api.h index d1789a9c98333f6e927ba470717d9227729f2108..aa25ea81c7b62238211f96265a4edc49f2d065a1 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -18,6 +18,7 @@ */ #pragma once +#include #include #include #include @@ -39,12 +40,22 @@ namespace lite { */ class LITE_API LightPredictor { public: - LightPredictor( - const std::string& model_dir, - const std::string& model_buffer = "", - const std::string& param_buffer = "", - bool model_from_memory = false, - lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) { + // constructor of LightPredictor: `lite_model_file` refers to the data in the + // model file or buffer, `model_from_memory` refers to whether to load the + // model from memory. + LightPredictor(const std::string& lite_model_file, + bool model_from_memory = false) { + scope_ = std::make_shared<Scope>(); + Build(lite_model_file, model_from_memory); + } + + // NOTE: This is a deprecated API and will be removed in a later release. + LightPredictor(const std::string& model_dir, + const std::string& model_buffer = "", + const std::string& param_buffer = "", + bool model_from_memory = false, + lite_api::LiteModelType model_type = + lite_api::LiteModelType::kNaiveBuffer) { scope_ = std::make_shared<Scope>(); Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory); } @@ -69,6 +80,10 @@ class LITE_API LightPredictor { void PrepareFeedFetch(); private: + void Build(const std::string& lite_model_file, + bool model_from_memory = false); + + // NOTE: This is a deprecated API and will be removed in a later release. 
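+  // Prefer the Build(lite_model_file, model_from_memory) overload declared above.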
void Build( const std::string& model_dir, const std::string& model_buffer, diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index a0ae28df0958403237114a3d4b94031829019339..3965843250abe45c43490bdbb4aaed58915e0908 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -23,13 +23,17 @@ namespace lite { void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { // LightPredictor Only support NaiveBuffer backend in publish lib - raw_predictor_.reset( - new LightPredictor(config.model_dir(), - config.model_buffer(), - config.param_buffer(), - config.model_from_memory(), - lite_api::LiteModelType::kNaiveBuffer)); - + if (config.lite_model_file().empty()) { + raw_predictor_.reset( + new LightPredictor(config.model_dir(), + config.model_buffer(), + config.param_buffer(), + config.model_from_memory(), + lite_api::LiteModelType::kNaiveBuffer)); + } else { + raw_predictor_.reset(new LightPredictor(config.lite_model_file(), + config.model_from_memory())); + } mode_ = config.power_mode(); threads_ = config.threads(); } diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc index cf646d823d97213a4a14573f72a95d1a55169c12..190890da4c109f39cc52ca5209cd952f8937f780 100644 --- a/lite/api/model_test.cc +++ b/lite/api/model_test.cc @@ -73,7 +73,7 @@ void Run(const std::vector>& input_shapes, const int repeat, const int warmup_times = 0) { lite_api::MobileConfig config; - config.set_model_dir(model_dir); + config.set_model_from_file(model_dir + ".nb"); config.set_power_mode(power_mode); config.set_threads(thread_num); diff --git a/lite/api/model_optimize_tool.cc b/lite/api/opt.cc similarity index 99% rename from lite/api/model_optimize_tool.cc rename to lite/api/opt.cc index fc23e0b54be41bff5b7b65b4e58908546b186bb4..c172169e59ec074b81a07e4fc96cd0363c50a10a 100644 --- a/lite/api/model_optimize_tool.cc +++ b/lite/api/opt.cc @@ -17,7 +17,7 @@ #include #endif // "supported_kernel_op_info.h", "all_kernel_faked.cc" and "kernel_src_map.h" -// are created automatically during model_optimize_tool's compiling period +// are created automatically during opt's compiling period #include #include "all_kernel_faked.cc" // NOLINT #include "kernel_src_map.h" // NOLINT diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index aabb53529221bde53b6b2ee27b2efefee2e6054d..9f071cf7780e27defdd1fcd6be02844618165fb6 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -190,5 +190,27 @@ void ConfigBase::set_threads(int threads) { #endif } +// set model data in combined format, `set_model_from_file` refers to loading +// model from file, set_model_from_buffer refers to loading model from memory +// buffer +void MobileConfig::set_model_from_file(const std::string &x) { + lite_model_file_ = x; +} +void MobileConfig::set_model_from_buffer(const std::string &x) { + lite_model_file_ = x; + model_from_memory_ = true; +} +void MobileConfig::set_model_buffer(const char *model_buffer, + size_t model_buffer_size, + const char *param_buffer, + size_t param_buffer_size) { + LOG(WARNING) << "warning: `set_model_buffer` will be abandened in " + "release/v3.0.0, new method `set_model_from_buffer(const " + "std::string &x)` is recommended."; + model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size); + param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size); + model_from_memory_ = true; +} + } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 
6308699ac91900d161a55ee121e4d9777947fede..307eeb74e8b4cdc3b2d6188eb18490e4dcf89b8f 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -168,22 +168,40 @@ class LITE_API CxxConfig : public ConfigBase { /// MobileConfig is the config for the light weight predictor, it will skip /// IR optimization or other unnecessary stages. class LITE_API MobileConfig : public ConfigBase { + // whether to load the model data from memory. Model data will be loaded from + // a memory buffer if model_from_memory_ is true. + bool model_from_memory_{false}; + + // model data read from file or memory buffer, in combined format. + std::string lite_model_file_; + + // NOTE: This is a deprecated variable and will be removed in a later release. std::string model_buffer_; std::string param_buffer_; - bool model_from_memory_{false}; public: + // set model data in combined format: `set_model_from_file` loads the model + // from a file, while `set_model_from_buffer` loads the model from a memory + // buffer + void set_model_from_file(const std::string& x); + void set_model_from_buffer(const std::string& x); + // return the model data held in lite_model_file_, which is in combined format. + const std::string& lite_model_file() const { return lite_model_file_; } + + // return model_from_memory_, which indicates whether to load the model from + // a memory buffer. + bool model_from_memory() const { return model_from_memory_; } + + // NOTE: This is a deprecated API and will be removed in a later release. void set_model_buffer(const char* model_buffer, size_t model_buffer_size, const char* param_buffer, - size_t param_buffer_size) { - model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size); - param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size); - model_from_memory_ = true; - } + size_t param_buffer_size); - bool model_from_memory() const { return model_from_memory_; } + // NOTE: This is a deprecated API and will be removed in a later release. const std::string& model_buffer() const { return model_buffer_; } + + // NOTE: This is a deprecated API and will be removed in a later release. 
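+  // model_buffer()/param_buffer() below are kept only for backward compatibility;
+  // a minimal sketch of the new flow (the model path is a placeholder):
+  //   lite_api::MobileConfig config;
+  //   config.set_model_from_file("mobilenet_v1.nb");
+  //   auto predictor = lite_api::CreatePaddlePredictor(config);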
const std::string& param_buffer() const { return param_buffer_; } }; diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc index 69d544c3decac9f312bc9eb03cdc6c3702c5032b..9213a24e5c0614550a098c4de8d97b6cf6695177 100644 --- a/lite/api/paddle_api_test.cc +++ b/lite/api/paddle_api_test.cc @@ -72,7 +72,7 @@ TEST(CxxApi, run) { #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK TEST(LightApi, run) { lite_api::MobileConfig config; - config.set_model_dir(FLAGS_model_dir + ".opt2.naive"); + config.set_model_from_file(FLAGS_model_dir + ".opt2.naive.nb"); auto predictor = lite_api::CreatePaddlePredictor(config); @@ -109,16 +109,11 @@ TEST(LightApi, run) { // Demo2 for Loading model from memory TEST(MobileConfig, LoadfromMemory) { // Get naive buffer - auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb"; - auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb"; - std::string model_buffer = lite::ReadFile(model_path); - size_t size_model = model_buffer.length(); - std::string params_buffer = lite::ReadFile(params_path); - size_t size_params = params_buffer.length(); + auto model_file = std::string(FLAGS_model_dir) + ".opt2.naive.nb"; + std::string model_buffer = lite::ReadFile(model_file); // set model buffer and run model lite_api::MobileConfig config; - config.set_model_buffer( - model_buffer.c_str(), size_model, params_buffer.c_str(), size_params); + config.set_model_from_buffer(model_buffer); auto predictor = lite_api::CreatePaddlePredictor(config); auto input_tensor = predictor->GetInput(0); diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 943760d30742b74a0fe9150e4c2d8c8bb5dbc52a..a2e13e156370090bfb9b9390a3389859b88fac3e 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -41,6 +41,7 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_layout_cast_pass); USE_MIR_PASS(memory_optimize_pass); +USE_MIR_PASS(kernel_place_correct_pass) USE_MIR_PASS(elementwise_mul_constant_eliminate_pass) USE_MIR_PASS(npu_subgraph_pass); USE_MIR_PASS(xpu_subgraph_pass); diff --git a/lite/api/python/pybind/pybind.cc b/lite/api/python/pybind/pybind.cc index 7d4ed4e98701a5328b0f05387dc73ad8b93dfe18..2dfe0c49490ecd13e8a3ce480807bdf3875348b7 100644 --- a/lite/api/python/pybind/pybind.cc +++ b/lite/api/python/pybind/pybind.cc @@ -116,6 +116,8 @@ void BindLiteMobileConfig(py::module *m) { py::class_ mobile_config(*m, "MobileConfig"); mobile_config.def(py::init<>()) + .def("set_model_from_file", &MobileConfig::set_model_from_file) + .def("set_model_from_buffer", &MobileConfig::set_model_from_buffer) .def("set_model_dir", &MobileConfig::set_model_dir) .def("model_dir", &MobileConfig::model_dir) .def("set_model_buffer", &MobileConfig::set_model_buffer) diff --git a/lite/api/resnet50_test_fpga.cc b/lite/api/resnet50_test_fpga.cc index ab647f96998f1c0e73476369611218d0a7930c57..75e6f0cbbc43c3cd7eb9bfa89bc004554ea6f85b 100644 --- a/lite/api/resnet50_test_fpga.cc +++ b/lite/api/resnet50_test_fpga.cc @@ -31,11 +31,7 @@ TEST(ResNet50, test) { std::vector valid_places( {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}}); - predictor.Build(FLAGS_model_dir, - "", - "", - Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}, - valid_places); + predictor.Build(FLAGS_model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); diff --git a/lite/api/test_ssd_fpga.cc b/lite/api/test_ssd_fpga.cc new 
file mode 100644 index 0000000000000000000000000000000000000000..bb2d75671a637c8042b39e2e90d70f1ae9e6f2fd --- /dev/null +++ b/lite/api/test_ssd_fpga.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +DEFINE_string(input_file, "", "input_file"); + +namespace paddle { +namespace lite { + +std::vector GetDirectoryFiles(const std::string& dir) { + std::vector files; + std::shared_ptr directory_ptr(opendir(dir.c_str()), + [](DIR* dir) { dir&& closedir(dir); }); + struct dirent* dirent_ptr; + if (!directory_ptr) { + std::cout << "Error opening : " << std::strerror(errno) << dir << std::endl; + return files; + } + + while ((dirent_ptr = readdir(directory_ptr.get())) != nullptr) { + files.push_back(std::string(dirent_ptr->d_name)); + } + return files; +} + +void readFromFile(int num, std::string path, float* data) { + std::ifstream file_stream(path); + // file_stream.open(path); + if (!file_stream.good()) { + std::cout << "file: " << path << " dones not exist!\n"; + exit(-1); + return; + } + // float* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + data[i] = value; + } + file_stream.close(); +} + +// #ifdef LITE_WITH_FPGA +TEST(ResNet50, test) { + lite::Predictor predictor; + std::vector valid_places({ + Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}, + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + // predictor.Build(FLAGS_model_dir, "", "", valid_places); + predictor.Build("", + FLAGS_model_dir + "/model", + FLAGS_model_dir + "/params", + valid_places); + + auto* input_tensor = predictor.GetInput(0); + int width = 300; + int height = 300; + + // std::ifstream file_stream(FLAGS_input_file); + // if (!file_stream.good()) { + // std::cout << "file: " << FLAGS_input_file << " dones not exist!\n"; + // exit(-1); + // return; + // } + + // file_stream >> height; + // file_stream >> width; + + input_tensor->Resize( + DDim(std::vector({1, 3, height, width}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + // readFromFile(item_size, "car.data", data); + + int num = 3 * width * height; + + // for (int i = 0; i < num; ++i) { + // float value = 0; + // file_stream >> value; + // data[i] = value; + // } + // file_stream.close(); + + for (int i = 0; i < 2; ++i) { + predictor.Run(); + } + + auto* out = predictor.GetOutput(0); + for (int i = 0; i < out->dims().production(); i++) { + std::cout << ":" << out->data()[i] << std::endl; + } + + std::string file = "output/" + FLAGS_input_file.substr(6); + std::cout << "file:::" << file << 
std::endl; + + std::ofstream ofs; + ofs.open(file); + for (int i = 0; i < out->dims().production(); i++) { + float value = out->data()[i]; + ofs << value << std::endl; + } + ofs.close(); + + LOG(INFO) << "================== Speed Report ==================="; +} +// #endif + +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc index daf3957bb1fe92cf9d979439407732bba3b0d9a4..6125547b8ba611d016d5d85359a4138b0ede7607 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc @@ -109,7 +109,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size); float* pre_out = pre_din + pre_in_size; #else - float pre_din = tmp_din; + float* pre_din = tmp_din; float* pre_out = pre_din + pre_in_size; #endif prepack_input_nxwc4_dw( diff --git a/lite/backends/arm/math/type_trans.cc b/lite/backends/arm/math/type_trans.cc index 6ded50e75294ad5145b3b88c4c341d4cce09c812..c50abb741ded487efa03d7d46baf2c6f13a8791d 100644 --- a/lite/backends/arm/math/type_trans.cc +++ b/lite/backends/arm/math/type_trans.cc @@ -46,6 +46,7 @@ void fp32_to_int8(const float* din, float inv_scale = 1.f / scale[j % axis_size]; float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vmax = vdupq_n_f32(-127.f); float32x4_t vpoff = vdupq_n_f32(0.5f); float32x4_t vnoff = vdupq_n_f32(-0.5f); const float* din_c = din + j * inner_size; @@ -63,6 +64,14 @@ void fp32_to_int8(const float* din, "fmul v5.4s, v1.4s, %[scale].4s \n" "fmul v6.4s, v2.4s, %[scale].4s \n" "fmul v7.4s, v3.4s, %[scale].4s \n" + "fcmge v8.4s, v4.4s, %[vmax].4s \n" + "fcmge v9.4s, v5.4s, %[vmax].4s \n" + "fcmge v10.4s, v6.4s, %[vmax].4s \n" + "fcmge v11.4s, v7.4s, %[vmax].4s \n" + "bif v4.16b, %[vmax].16b, v8.16b \n" + "bif v5.16b, %[vmax].16b, v9.16b \n" + "bif v6.16b, %[vmax].16b, v10.16b \n" + "bif v7.16b, %[vmax].16b, v11.16b \n" "ldp q0, q1, [%[in]], #32 \n" "subs %[cnt], %[cnt], #1 \n" "FCVTAS v8.4s, v4.4s \n" @@ -79,7 +88,7 @@ void fp32_to_int8(const float* din, "str q8, [%[out]], #16 \n" "bne 0b \n" : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop) - : [scale] "w"(vscale) + : [scale] "w"(vscale), [vmax] "w"(vmax) : "v0", "v1", "v2", @@ -104,15 +113,23 @@ void fp32_to_int8(const float* din, "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" "vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n" - "vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n" "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vcgt.f32 q8, q3, %q[vzero] @ get mask > 0, in3\n" "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" "vbif.f32 q6, %q[vnoff], q10 @ get right offset\n" - "vbif.f32 q7, %q[vnoff], q11 @ get right offset\n" + "vbif.f32 q7, %q[vnoff], q8 @ get right offset\n" "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" "vmla.f32 q6, q2, %q[vscale] @ mul scale\n" "vmla.f32 q7, q3, %q[vscale] @ mul scale\n" + "vcge.f32 q8, q4, %q[vmax] @ q4 >= vmax \n" + "vcge.f32 q9, q5, %q[vmax] @ q4 >= vmax \n" + "vcge.f32 q10, q6, %q[vmax] @ q4 >= vmax \n" + "vbif q4, %q[vmax], q8 @ choose \n" + "vcge.f32 q8, q7, %q[vmax] @ q4 >= vmax \n" + "vbif q5, %q[vmax], q9 @ choose \n" + "vbif q6, %q[vmax], q10 @ choose \n" + "vbif q7, %q[vmax], q8 @ choose \n" "vcvt.s32.f32 q0, q4 @ cvt to int32\n" "vcvt.s32.f32 q1, q5 @ cvt to int32\n" "vcvt.s32.f32 q2, q6 
@ cvt to int32\n" @@ -133,25 +150,16 @@ void fp32_to_int8(const float* din, : [vscale] "w"(vscale), [vpoff] "w"(vpoff), [vnoff] "w"(vnoff), - [vzero] "w"(vzero) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); + [vzero] "w"(vzero), + [vmax] "w"(vmax) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10"); #endif } const float* din_r = din_c + 16 * cnt; signed char* dout_r = dout_c + 16 * cnt; for (int i = 0; i < remain; ++i) { dout_r[i] = saturate_cast(roundf(inv_scale * din_r[i])); + dout_r[i] = dout_r[i] < -127 ? -127 : dout_r[i]; } } } diff --git a/lite/backends/fpga/KD/fpga_cv.cpp b/lite/backends/fpga/KD/fpga_cv.cpp deleted file mode 100644 index 15a20e368b09f193e3f43b574ff3682ce96782ad..0000000000000000000000000000000000000000 --- a/lite/backends/fpga/KD/fpga_cv.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "lite/backends/fpga/KD/fpga_cv.hpp" - -using paddle::zynqmp::float16; - -void fpga_resize(float* input, - int input_width, - int input_height, - int input_channel, - uint8_t* output, - int output_width, - int output_height) { - paddle::zynqmp::InplaceArgs inplace_args = {0, 0, 0}; - paddle::zynqmp::config_inplace(inplace_args); - - paddle::zynqmp::ImageInputArgs input_args = {nullptr}; - input_args.address = nullptr; - input_args.scale_address = nullptr; - - float16* input_image_address = - reinterpret_cast(paddle::zynqmp::fpga_malloc( - input_width * input_height * input_channel * sizeof(float16))); - int index = 0; - - for (int i = 0; i < input_width * input_height * input_channel; i++) { - input_image_address[i] = float16(1.0 * input[i]); - } - - paddle::zynqmp::ResizeArgs resize_args = {0}; - - resize_args.input_width = input_width; - resize_args.input_height = input_height; - resize_args.image_channel = input_channel; - resize_args.output_width = output_width; - resize_args.output_height = output_height; - float height_ratio = static_cast(input_height) / - static_cast(resize_args.output_height); - float width_ratio = static_cast(input_width) / - static_cast(resize_args.output_width); - resize_args.height_ratio = *reinterpret_cast(&height_ratio); - resize_args.width_ratio = *reinterpret_cast(&width_ratio); - - int output_size = - resize_args.output_width * resize_args.output_height * input_channel; - float16* fpga_output = reinterpret_cast( - paddle::zynqmp::fpga_malloc(output_size * sizeof(float16))); - resize_args.input_image_address = input_image_address; - resize_args.output_image_address = fpga_output; - - memset(fpga_output, 0, output_size * sizeof(float16)); - paddle::zynqmp::fpga_flush( - input_image_address, - input_width * input_height * input_channel * sizeof(float16)); - paddle::zynqmp::fpga_flush(resize_args.output_image_address, - output_size * sizeof(float16)); - int ret = paddle::zynqmp::compute_fpga_resize(resize_args); - if (ret == 0) { - 
paddle::zynqmp::fpga_invalidate(resize_args.output_image_address, - output_size * sizeof(float16)); - } - - for (int i = 0; i < output_size; i++) { - output[i] = fpga_output[i]; - } -} diff --git a/lite/backends/fpga/KD/fpga_cv.hpp b/lite/backends/fpga/KD/fpga_cv.hpp deleted file mode 100644 index 6aa52edfbb704a0571fb1052aff6ecf022e49596..0000000000000000000000000000000000000000 --- a/lite/backends/fpga/KD/fpga_cv.hpp +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "lite/backends/fpga/KD/float16.hpp" -#include "lite/backends/fpga/KD/llapi/zynqmp_api.h" -#include "lite/backends/fpga/KD/pe.hpp" - -void fpga_resize(float* input, - int input_width, - int input_height, - int input_channel, - uint8_t* output, - int output_width, - int output_height); diff --git a/lite/backends/fpga/KD/llapi/config.h b/lite/backends/fpga/KD/llapi/config.h deleted file mode 100755 index acf8c8adf4fc5593dcc4238ddc762fdb9fea6760..0000000000000000000000000000000000000000 --- a/lite/backends/fpga/KD/llapi/config.h +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#define PADDLE_LITE_ZU5 -#define FPGA_PRINT_MODE -#define PADDLE_LITE_PROFILE diff --git a/lite/backends/fpga/lite_tensor.h b/lite/backends/fpga/lite_tensor.h index 311fc8a98400e5a6916ba1b9c8de1e6e0bcec4c0..266e0b5ce0ea03108978c3b0a32fbf0e3872c83c 100644 --- a/lite/backends/fpga/lite_tensor.h +++ b/lite/backends/fpga/lite_tensor.h @@ -151,6 +151,10 @@ class TensorLite { size_t offset() const { return offset_; } bool IsInitialized() const { return buffer_->data(); } + void clear() { + buffer_->Free(); + offset_ = 0; + } // Other share data to this. void ShareDataWith(const TensorLite &other); diff --git a/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..1574cb4a69cd0388698707d8d91c1d9c18b625a2 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void concat2(__global const CL_DTYPE* x_data0, __global const CL_DTYPE* x_data1, __global CL_DTYPE* out_data, + int size, int axis_size, int pre_size, int post_size, int total, int total0, int total1) { + const int index = get_global_id(0); + if (index < size){ + for (int i = 0; i < pre_size; i++){ + int offset_out = index * post_size + i * total; + int offset_in = index * post_size + i * total0; + // memcpy(out_data + offset_out, x_data0 + offset_in, post_size); + CL_DTYPE* dst = out_data + offset_out; + CL_DTYPE* src = x_data0 + offset_in; + for (int k = 0; k < post_size; k++){ + *dst++ = *src++; + } + } + }else if (index < axis_size){ + for (int i = 0; i < pre_size; i++){ + int offset_out = index * post_size + i * total; + int offset_in = index * post_size + i * total1; + // memcpy(out_data + offset_out, x_data1 + offset_in, post_size); + CL_DTYPE* dst = out_data + offset_out; + CL_DTYPE* src = x_data1 + offset_in; + for (int k = 0; k < post_size; k++){ + *dst++ = *src++; + } + } + } +} + +__kernel void concat_mul(__global const CL_DTYPE* x_data, __global CL_DTYPE* out_data, + int axis_size, int pre_size, int post_size, int start, int total, int total0) { + const int index = get_global_id(0); + if (index < axis_size){ + for (int i = 0; i < pre_size; i++){ + int offset_out = (start + index) * post_size + i * total; + int offset_in = index * post_size + i * total0; + // memcpy(out_data + offset_out, x_data + offset_in, post_size); + CL_DTYPE* dst = out_data + offset_out; + CL_DTYPE* src = x_data + offset_in; + for (int k = 0; k < post_size; k++){ + *dst++ = *src++; + } + } + } +} diff --git a/lite/backends/opencl/cl_kernel/image/concat_kernel.cl b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..f0335116f87aac34740dd22ac68f2b6265e62445 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +__kernel void concat2(__read_only image2d_t input0, + __read_only image2d_t input1, + __write_only image2d_t output, + int axis_size, int flag, int width) { + const int x = get_global_id(0); // image_width cxw/4 + const int y = get_global_id(1); // image_height nxh + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + int xx = x / width; + if (flag == 0){ + xx = y / width; + } + if (xx < axis_size){ + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(x, y)); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); + }else{ + int new_val = xx - axis_size; + new_val *= width; + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(new_val, y)); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); + } + // WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); +} + +__kernel void concat_mul(__read_only image2d_t input0, + __write_only image2d_t output, + int axis_size, int flag, int width, int start) { + const int x = get_global_id(0); // image_width cxw/4 + const int y = get_global_id(1); // image_height nxh + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + int xx = x / width; + if (flag == 0){ + xx = y / width; + } + + if (xx < axis_size && xx >= start){ + xx -= start; + xx *= width; + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(xx, y)); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); + } + +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..8d7950d6b897df833ada56e2de5be7c6203de9ea --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl @@ -0,0 +1,428 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +__kernel void conv2d_3x3(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height, + __private const int output_c, + __private const int filter_channel, + __private const int filter_width, + __private const int filter_height, + __private const int group) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + + if (out_c >= global_size_dim0 || + out_w >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + + + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; + + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; + + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + +#ifdef BIASE_CH + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + CL_DTYPE4 input[9]; // 3x3 region of input + if (group == 1) { + for (int i = 0; i < input_c; ++i) { // each run for 3x3 + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + + input[0] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[1] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[2] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[3] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[4] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || 
in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[5] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[6] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + input[7] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + input[8] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + int j = 0; + int2 pos_of_weight; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + CL_DTYPE4 weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 1; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 2; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += 
dot(input[j], weight_w); + + j = 3; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 4; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 5; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 6; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 7; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + 
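+      // accumulate this filter tap's contribution to the fourth output channel (w lane)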
output.w += dot(input[j], weight_w); + + j = 8; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + } + } else { // group != 1 + for (int i = 0; i < 4; i++) { + int used_input_channel_num = + (out_c * 4 + i) / (output_c / group) * filter_channel; + for (int f_c = 0; f_c < filter_channel; ++f_c) { + int input_c = used_input_channel_num + f_c; + int input_block = input_c / 4; + int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x, + in_pos_in_one_block.y); + input[0] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[1] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[2] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[3] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[4] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(pos_in.x, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[5] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[6] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + input[7] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + 
(int2)(pos_in.x, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + input[8] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + + CL_DTYPE tmp_out = 0; + for (int j = 0; j < 9; j++) { + int2 pos_of_weight; + pos_of_weight.x = (f_c / 4) * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3; + CL_DTYPE4 weight = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + + int f_c_offset = f_c % 4; + CL_DTYPE f_value; + if (f_c_offset == 0) { + f_value = weight.x; + } else if (f_c_offset == 1) { + f_value = weight.y; + } else if (f_c_offset == 2) { + f_value = weight.z; + } else if (f_c_offset == 3) { + f_value = weight.w; + } + + int input_c_offset = input_c % 4; + CL_DTYPE input_value; + if (input_c_offset == 0) { + input_value = input[j].x; + } else if (input_c_offset == 1) { + input_value = input[j].y; + } else if (input_c_offset == 2) { + input_value = input[j].z; + } else if (input_c_offset == 3) { + input_value = input[j].w; + } + tmp_out += f_value * input_value; + } + + if (i == 0) { + output.x += tmp_out; + } else if (i == 1) { + output.y += tmp_out; + } else if (i == 2) { + output.z += tmp_out; + } else if (i == 3) { + output.w += tmp_out; + } + } + } + } + + output = activation_type4(output); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} diff --git a/lite/backends/opencl/cl_kernel/image/scale_kernel.cl b/lite/backends/opencl/cl_kernel/image/scale_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..739ff1338582b65d87dbd9c92f1ea86e0c49f0ff --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/scale_kernel.cl @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
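A note on the boundary handling used throughout the convolution kernel above: each `select(READ_IMG_TYPE(...), (CL_DTYPE4)(0.0f), (ushort4)(cond << 15))` implements implicit zero padding, since shifting the out-of-range predicate into bit 15 sets the MSB that the vector `select` tests per component, so any 3x3 tap falling outside the input contributes zero. A scalar C++ sketch of the same rule, with illustrative names that are not part of the patch:

```cpp
#include <vector>

// Minimal scalar sketch (not the OpenCL kernel itself): read one tap of a
// 3x3 window with dilation, returning 0 when the tap falls outside the input.
// `img` is assumed to be an H x W row-major buffer; names are illustrative.
float ReadTapOrZero(const std::vector<float>& img, int height, int width,
                    int y, int x, int dy, int dx, int dilation) {
  const int yy = y + dy * dilation;
  const int xx = x + dx * dilation;
  // Same predicate as the (ushort4)(cond << 15) masks in the kernel:
  // out-of-range taps contribute zero (implicit zero padding).
  if (xx < 0 || yy < 0 || xx >= width || yy >= height) return 0.0f;
  return img[yy * width + xx];
}

// Accumulate a 3x3 window centered at (y, x); tap index j maps to
// (dy, dx) = (j / 3 - 1, j % 3 - 1), mirroring the j / 3 and j % 3 indexing
// the kernel uses for the filter image.
float Conv3x3At(const std::vector<float>& img, int height, int width,
                const float w[9], int y, int x, int dilation) {
  float acc = 0.f;
  for (int j = 0; j < 9; ++j) {
    acc += w[j] *
           ReadTapOrZero(img, height, width, y, x, j / 3 - 1, j % 3 - 1,
                         dilation);
  }
  return acc;
}
```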
*/
+
+#include <cl_common.h>
+
+__kernel void scale(__read_only image2d_t input,
+                    __write_only image2d_t output,
+                    __private float scale,
+                    __private float bias){
+
+  const int x = get_global_id(0);  // image_width
+  const int y = get_global_id(1);  // image_height
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+
+  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
+  in = convert_float(scale) * in + convert_float(bias);
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
+}
diff --git a/lite/backends/x86/math/math_function.cc b/lite/backends/x86/math/math_function.cc
index 822b7df936d84c21c226a13a48e8c09a2343f86a..f242e14ad1119e9de78df4841d47ea40d8c751af 100644
--- a/lite/backends/x86/math/math_function.cc
+++ b/lite/backends/x86/math/math_function.cc
@@ -110,11 +110,11 @@ void set_constant(const lite::Context<lite::TargetType::kX86>& context,
                   lite::Tensor* tensor,
                   float value) {
   TensorSetConstantWithTarget func(context, tensor, value);
-  //#ifdef PADDLE_WITH_CUDA
+  // #ifdef PADDLE_WITH_CUDA
   //  tensor->target().apply_visitor(func);
-  //#else
+  // #else
   func();
-  //#endif
+  // #endif
 }
 
 template <typename T>
@@ -128,12 +128,14 @@ struct RowwiseAdd<lite::TargetType::kX86, T> {
     PADDLE_ENFORCE_EQ(vector.numel(), size);
     PADDLE_ENFORCE_EQ(output->dims(), in_dims);
 
-    auto in = lite::fluid::EigenMatrix<T>::From(input);
-    auto vec = lite::fluid::EigenVector<T>::Flatten(vector);
-    auto out = lite::fluid::EigenMatrix<T>::From(*output);
-
+    const T* input_data = input.data<T>();
+    const T* vector_data = vector.data<T>();
+    T* output_data = output->mutable_data<T>();
     for (int64_t i = 0; i < in_dims[0]; ++i) {
-      out.chip(i, 0) = in.chip(i, 0) + vec;
+      for (int64_t j = 0; j < size; ++j) {
+        // Row stride is the row width `size` (== in_dims[1]), not the row count.
+        output_data[i * size + j] = input_data[i * size + j] + vector_data[j];
+      }
     }
   }
 };
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index 379ef67f2996519d0c8007d8f191efbd2166a9e3..3f9fb97ee756eeac870fe5090de182d8c03d170b 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -25,6 +25,7 @@ lite_cc_library(mir_passes
         elimination/elementwise_mul_constant_eliminate_pass.cc
         static_kernel_pick_pass.cc
         variable_place_inference_pass.cc
+        kernel_place_correct_pass.cc
         type_target_cast_pass.cc
         type_layout_cast_pass.cc
         type_precision_cast_pass.cc
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index c5ce74e30e34b5878a534010b6cf8b86f91a1118..b688bbc1083a6ab0f521381c4a988a12badc3141 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -29,6 +29,11 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       act_types.push_back("leaky_relu");
       break;
     }
+    if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
+      act_types.push_back("relu6");
+      act_types.push_back("leaky_relu");
+      break;
+    }
   }
   for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     for (auto act_type : act_types) {
diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index ff5a7a1f25239d9dbfc79491bd137804b16b6cfa..2720404fb03cddaf00c9a25d8287b14d69ca86e8 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -27,10 +27,24 @@ namespace mir {
 
 void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // delete quant node
   std::vector<std::string> quant_op_types = {
-      "fake_quantize_range_abs_max",
"fake_quantize_moving_average_abs_max"}; + "fake_quantize_abs_max", + "fake_quantize_range_abs_max", + "fake_quantize_moving_average_abs_max"}; + /* + for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) { + for (int i = 5; i >= 1; --i){ + fusion::DynamicQuantDequantOpFuser fuser("fake_quantize_abs_max", op_type, + i); + fuser(graph.get()); + } + } + */ + for (auto& op_type : quant_op_types) { fusion::DeleteQuantOpFuser fuser(op_type); fuser(graph.get()); + fusion::DeleteDynamicQuantOpFuser dfuser(op_type); + dfuser(graph.get()); } // fuse quantized node and dequant node diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index da611e4490f4ba7268d9011b3dbb391a63a88305..2c761c6c2a08d24a52db41478456f8db332ef2d2 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { return op_desc; } +void DeleteDynamicQuantOpFuser::BuildPattern() { + auto* input_act_node = + VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X"); + auto* quant_node = + OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_op_type_, "OutScale"); + auto* output_act_node = + VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out"); + + quant_node->LinksFrom({input_act_node}); + output_scale_node->LinksFrom({quant_node}); + output_act_node->LinksFrom({quant_node}); + VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_; +} + +void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto* input_act_node = matched.at("input_act_node"); + auto* quant_node = matched.at("quant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + + // obtain values, save values and relink node + int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto outlinks = output_act_node->outlinks; + for (auto* quantized_node : outlinks) { + auto* op_desc = quantized_node->stmt()->mutable_op_info(); + op_desc->SetAttr("bit_length", bit_length); + IR_NODE_LINK_TO(input_act_node, quantized_node) + } + + // delete nodes and edges + std::unordered_set nodes2rm = { + quant_node, output_scale_node, output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); +} + +cpp::OpDesc DeleteDynamicQuantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} void DequantOpFuser::BuildPattern() { std::string weight_name = ""; if (quantized_op_type_ == "conv2d" || @@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, auto& valid_places = quantized_op->stmt()->op()->valid_places(); int bit_length = quantized_op->stmt()->op_info()->GetAttr("bit_length"); int range = ((1 << (bit_length - 1)) - 1); - float input_scale = - quantized_op->stmt()->op_info()->GetAttr("input_scale"); + float input_scale = 0; + if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) { + input_scale = + quantized_op->stmt()->op_info()->GetAttr("input_scale"); + } float max_range = 
dequant_op->stmt()->op_info()->GetAttr("max_range"); float whole_weight_scale = static_cast(range * range) / max_range / range; @@ -162,8 +214,12 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, for (int i = 0; i < weight_scale_size; i++) { weight_scale.push_back(whole_weight_scale); } +#ifndef LITE_WITH_FPGA op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("input_scale", input_scale); +#endif + if (quantized_op->stmt()->op_info()->HasAttr("input_scale")) { + op_desc.SetAttr("input_scale", input_scale); + } op_desc.SetAttr("weight_scale", weight_scale); // change the weight from the float type to int8 type. @@ -171,12 +227,29 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, temp_tensor.CopyDataFrom(*quantized_weight_t); float* temp_data = temp_tensor.mutable_data(); size_t weight_num = quantized_weight_t->data_size(); + +#ifdef LITE_WITH_FPGA + float* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = temp_data[i] * whole_weight_scale; + } + quantized_weight_t->set_persistable(true); + quantized_weight_t->set_precision(PRECISION(kFloat)); +#else int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); for (size_t i = 0; i < weight_num; i++) { quantized_weight_data[i] = static_cast(temp_data[i]); } quantized_weight_t->set_persistable(true); quantized_weight_t->set_precision(PRECISION(kInt8)); +#endif + + // int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); + // for (size_t i = 0; i < weight_num; i++) { + // quantized_weight_data[i] = static_cast(temp_data[i]); + // } + // quantized_weight_t->set_persistable(true); + // quantized_weight_t->set_precision(PRECISION(kInt8)); // new op and relink nodes auto new_quantized_op = LiteOpRegistry::Global().Create(quantized_op_type_); @@ -464,6 +537,197 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc op_desc; return op_desc; } +// ================dynamic quant fuse============== +// #define DYNAMIC_RANGE +void DynamicQuantDequantOpFuser::BuildPattern() { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + std::string weight_name = ""; + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + weight_name = "Filter"; + } else { + weight_name = "Y"; + } + auto* quant_op_input = VarNode("quant_op_input") + ->assert_is_op_input(quant_type_, "X") + ->AsInput(); +#ifdef DYNAMIC_RANGE + auto* quant_op_in_scale = VarNode("quant_op_in_scale") + ->assert_is_op_input(quant_type_, "InScale") + ->AsIntermediate(); +#endif + auto* quant_op = OpNode("quant_op", quant_type_) + ->assert_is_op(quant_type_) + ->AsIntermediate(); + + auto* quant_op_out_scale = + VarNode("quant_op_out_scale") + ->assert_is_op_output(quant_type_, "OutScale") + ->assert_is_op_input("fake_dequantize_max_abs", "Scale") + ->AsIntermediate(); + + auto* quant_op_out = VarNode("quant_op_out") + ->assert_is_op_output(quant_type_, "Out") + ->assert_is_op_input(op_type_) + ->AsIntermediate(); + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(VarNode(string_format("quantized_op_weight%d", i)) + ->assert_is_op_input(op_type_, weight_name) + ->AsInput()); + + nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_) + ->assert_is_op(op_type_) + ->AsIntermediate()); + + 
nodes.push_back(VarNode(string_format("quantized_op_out%d", i)) + ->assert_is_op_output(op_type_) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate()); + + nodes.push_back( + OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs") + ->assert_is_op("fake_dequantize_max_abs") + ->AsIntermediate()); + nodes.push_back(VarNode(string_format("dequant_op_out%d", i)) + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput()); + } + +#ifdef DYNAMIC_RANGE + quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); +#endif + quant_op->LinksFrom({quant_op_input}); + quant_op_out->LinksFrom({quant_op}); + quant_op_out_scale->LinksFrom({quant_op}); + for (int i = 0; i < times_; i++) { + nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( + {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]}); + nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOffset]}); + nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); + nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kDequantOpOffset]}); + } +} + +void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + auto* quant_op_input = matched.at("quant_op_input"); +#ifdef DYNAMIC_RANGE + auto* quant_op_in_scale = matched.at("quant_op_in_scale"); +#endif + auto* quant_op = matched.at("quant_op"); + + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(matched.at(string_format("quantized_op_weight%d", i))); + nodes.push_back(matched.at(string_format("quantized_op%d", i))); + nodes.push_back(matched.at(string_format("quantized_op_out%d", i))); + nodes.push_back(matched.at(string_format("dequant_op%d", i))); + nodes.push_back(matched.at(string_format("dequant_op_out%d", i))); + } + int bit_length = quant_op->stmt()->op_info()->GetAttr("bit_length"); + auto* scope = quant_op->stmt()->op()->scope(); + auto& valid_places = quant_op->stmt()->op()->valid_places(); + int range = ((1 << (bit_length - 1)) - 1); + +#ifdef DYNAMIC_RANGE + auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name) + ->GetMutable(); + float input_scale = input_scale_t->data()[0] / range; + VLOG(4) << "range: " << range << " input_scale: " << input_scale; +#endif + for (int i = 0; i < times_; i++) { + float max_range = nodes[i * kNumFields + kDequantOpOffset] + ->stmt() + ->op_info() + ->GetAttr("max_range"); + // weight_scale = max(abs(weight)) + float whole_weight_scale = + static_cast(range * range) / max_range / range; + + cpp::OpDesc op_desc = + *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info(); + + auto quantized_weight_var_name = + nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name; + auto quantized_weight_t = + scope->FindVar(quantized_weight_var_name)->GetMutable(); + std::vector weight_scale; + int weight_scale_size; + + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should + // be Cout. 
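The `whole_weight_scale` expression above is easier to verify with concrete numbers. Assuming `max_range` follows the usual `fake_dequantize_max_abs` convention of `range * range / max(|weight|)` (an assumption about the producer model, not something stated in this patch), `range * range / max_range / range` reduces to `max(|weight|) / range`, i.e. the per-channel weight scale. A standalone sketch:

```cpp
#include <cassert>
#include <cmath>

int main() {
  // Hypothetical numbers for illustration only.
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127, as in the fuser
  const float weight_max = 0.5f;                  // max(|weight|) of one filter

  // Assumed convention: fake_dequantize_max_abs carries range^2 / max(|weight|).
  const float max_range = static_cast<float>(range) * range / weight_max;

  // Same expression as in the fuser above.
  const float whole_weight_scale =
      static_cast<float>(range * range) / max_range / range;

  // Algebraically this is weight_max / range: the step between two int8 levels.
  assert(std::fabs(whole_weight_scale - weight_max / range) < 1e-6f);
  return 0;
}
```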
+ weight_scale_size = quantized_weight_t->dims()[0]; + } else if (op_type_ == "mul") { + op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Fc weight: Cin * Cout, the weight_scale_size should be Cout. + weight_scale_size = quantized_weight_t->dims()[1]; + } + for (int i = 0; i < weight_scale_size; i++) { + weight_scale.push_back(whole_weight_scale); + } + // op_desc.SetAttr("enable_int8", true); + // op_desc.SetAttr("input_scale", input_scale); + op_desc.SetAttr("weight_scale", weight_scale); + + Tensor temp_tensor; + temp_tensor.CopyDataFrom(*quantized_weight_t); + float* temp_data = temp_tensor.mutable_data(); + size_t weight_num = quantized_weight_t->data_size(); + quantized_weight_t->set_persistable(true); + std::cout << "DynamicQuantDequantOpFuser::InsertNewNode====================" + "========================================" + << std::endl; +#ifdef LITE_WITH_FPGA + float* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = temp_data[i] * whole_weight_scale; + std::cout << whole_weight_scale << "," << temp_data[i] << "," + << quantized_weight_data[i] << std::endl; + } + quantized_weight_t->set_precision(PRECISION(kFloat)); +#else + int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); + for (size_t i = 0; i < weight_num; i++) { + quantized_weight_data[i] = static_cast(temp_data[i]); + } + quantized_weight_t->set_precision(PRECISION(kInt8)); +#endif + auto quantized_op = LiteOpRegistry::Global().Create(op_type_); + quantized_op->Attach(op_desc, scope); + auto* new_op_node = + graph->GraphCreateInstructNode(quantized_op, valid_places); + IR_NODE_LINK_TO(quant_op_input, new_op_node); + IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], + new_op_node); + IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]); + } +} + +cpp::OpDesc DynamicQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + return op_desc; +} } // namespace fusion } // namespace mir diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h index bef9f4d9573d049700736c166cd0d31b668f7eff..c21df350f96143a09b3229776bf5c013b1988559 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.h +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase { private: std::string quant_op_type_{}; }; +class DeleteDynamicQuantOpFuser : public FuseBase { + public: + explicit DeleteDynamicQuantOpFuser(const std::string& quant_op_type) + : quant_op_type_(quant_op_type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + + private: + std::string quant_op_type_{}; +}; /* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs. 
 */
@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase {
  private:
   std::string quantized_op_type_{};
 };
+// dynamic quant-dequant op fuser
+class DynamicQuantDequantOpFuser : public FuseBase {
+ public:
+  explicit DynamicQuantDequantOpFuser(const std::string& quantized_op_type,
+                                      const std::string& op_type,
+                                      int i)
+      : op_type_(op_type), quant_type_(quantized_op_type), times_(i) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string op_type_{};
+  std::string quant_type_{};
+  int times_{1};
+};
 
 }  // namespace fusion
 }  // namespace mir
diff --git a/lite/core/mir/kernel_place_correct_pass.cc b/lite/core/mir/kernel_place_correct_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dad7687bbec1ddbd7c8c787338005955de964f17
--- /dev/null
+++ b/lite/core/mir/kernel_place_correct_pass.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/kernel_place_correct_pass.h"
+#include <memory>
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void KernelPlaceCorrectPass::Apply(const std::unique_ptr<SSAGraph> &graph) {
+  CorrectArgumentPlace(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(kernel_place_correct_pass,
+                  paddle::lite::mir::KernelPlaceCorrectPass)
+    .BindTargets({TARGET(kFPGA)});
diff --git a/lite/core/mir/kernel_place_correct_pass.h b/lite/core/mir/kernel_place_correct_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..5fab5000862378976c16448f5a82f052ffbc20a5
--- /dev/null
+++ b/lite/core/mir/kernel_place_correct_pass.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * Correct the place of the variables in the SSAGraph: infer each variable's
+ * place from the kernels that produce and consume it.
+ */
+class KernelPlaceCorrectPass : public DebugPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+ private:
+  void CorrectArgumentPlace(SSAGraph* graph) {
+    auto& valid_places = graph->valid_places();
+    auto valid_places_has_target = [&](TargetType t) -> bool {
+      for (auto& p : valid_places) {
+        if (p.target == t) {
+          return true;
+        }
+      }
+      return false;
+    };
+    std::map<std::string, bool> lite_with_targets{
+        {"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
+        {"kFPGA", valid_places_has_target(TARGET(kFPGA))}};
+    VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"];
+    VLOG(4) << "lite_with_targets['kFPGA']:" << lite_with_targets["kFPGA"];
+
+    VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
+    for (auto& x : graph->StmtTopologicalOrder()) {
+      auto& inst = x->AsStmt();
+      // The IoCopyOp is a tool operator and does not support type inference.
+      // On FPGA and OpenCL we have io_copy + calib + layout tool ops, so type
+      // inference is still needed for tool operators on those targets.
+      if ((!lite_with_targets["kFPGA"]) && (!lite_with_targets["kOpenCL"])) {
+        if (inst.op_type() == "io_copy") {
+          VLOG(3) << "inst.op_type() == 'io_copy', continue";
+          continue;
+        }
+      }
+      // deal with inputs
+      VLOG(4) << "checking op " << inst.op_info()->Repr();
+
+      auto get_argname = [&](
+          const std::string& node_name,
+          const std::map<std::string, std::vector<std::string>>& argname_map)
+          -> std::string {
+        for (auto& ele : argname_map) {
+          auto it =
+              std::find(ele.second.begin(), ele.second.end(), node_name);
+          if (it != ele.second.end()) return ele.first;
+        }
+        return "";
+      };
+
+      bool need_correct_place = true;
+
+      std::vector<TargetType> in_types;
+      std::vector<TargetType> out_types;
+      for (auto* x_in : x->inlinks) {
+        std::string node_name = x_in->AsArg().name;
+        std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
+        CHECK(arg_name.size() > 0) << "cannot find op argument for node "
+                                   << node_name;
+        VLOG(4) << "-- input arg_name:" << arg_name << " "
+                << "-- node name:" << node_name;
+        auto type = inst.picked_kernel().GetInputDeclType(arg_name);
+        if (!x_in->AsArg().type) {
+          need_correct_place &= false;
+        } else {
+          if (in_types.empty()) {
+            in_types.push_back(x_in->AsArg().type->target());
+          } else {
+            if (in_types[0] != x_in->AsArg().type->target()) {
+              need_correct_place &= false;
+            }
+          }
+        }
+      }
+
+      for (auto* x_out : x->outlinks) {
+        std::string node_name = x_out->AsArg().name;
+        std::string arg_name =
+            get_argname(node_name, inst.op_info()->outputs());
+        CHECK(arg_name.size() > 0) << "cannot find op argument for node "
+                                   << node_name << " in Inst "
+                                   << inst.op_type();
+        VLOG(4) << "-- output arg_name " << arg_name;
+        auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
+        if (!x_out->AsArg().type) {
+          need_correct_place &= false;
+        } else {
+          if (out_types.empty()) {
+            out_types.push_back(x_out->AsArg().type->target());
+          } else {
+            if (out_types[0] != x_out->AsArg().type->target()) {
+              need_correct_place &= false;
+            }
+          }
+        }
+      }
+
+      // Nothing to correct if any input or output target is still unknown.
+      if (in_types.empty() || out_types.empty()) continue;
+
+      auto this_type = inst.picked_kernel().target();
+      bool io_target_same = (in_types[0] == out_types[0]);
+      need_correct_place &= (io_target_same && (in_types[0] != this_type));
+      if (need_correct_place) {
+        // update this kernel's valid place;
+        UpdateTarget(inst, in_types[0]);
+      }
+    }
+  }
+
+  // Re-pick this statement's kernels for the corrected target.
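Stripped of the graph bookkeeping, the per-statement rule that `CorrectArgumentPlace` applies can be restated as a small pure function. This is a simplified sketch with a stand-in `TargetLike` enum, not the pass itself:

```cpp
#include <vector>

// Distilled rule: if every input and every output variable already carries
// the same target, and that target differs from the target of the kernel that
// was picked, the kernel should be re-picked for the variables' target.
enum class TargetLike { kHost, kARM, kFPGA, kOpenCL };

bool NeedsTargetCorrection(const std::vector<TargetLike>& input_targets,
                           const std::vector<TargetLike>& output_targets,
                           TargetLike picked_kernel_target) {
  if (input_targets.empty() || output_targets.empty()) return false;
  for (auto t : input_targets)
    if (t != input_targets.front()) return false;   // inputs disagree
  for (auto t : output_targets)
    if (t != output_targets.front()) return false;  // outputs disagree
  return input_targets.front() == output_targets.front() &&
         input_targets.front() != picked_kernel_target;
}
```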
+ void UpdateTarget(mir::Node::Stmt& inst, TargetType new_target) { // NOLINT + auto new_place = inst.place(); + new_place.target = new_target; + std::vector places; + places.push_back(new_place); + inst.ResetKernels(places); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc index 4a90e530a46c4d42d2ba032da1828973dfc1bcef..52fd39182a7132777231929d49c319bb961cf7f9 100644 --- a/lite/core/mir/node.cc +++ b/lite/core/mir/node.cc @@ -53,6 +53,11 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, } valid_kernels_ = op_->CreateKernels(valid_places); } +void mir::Node::Stmt::ResetKernels(const std::vector &valid_places) { + CHECK(op_) << "change valid place failed, not created op"; + valid_kernels_.clear(); + valid_kernels_ = op_->CreateKernels(valid_places); +} mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) { auto &x = AsArg(); diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h index e2c8a68bde6ee18506de73a7531716695b3d54f1..e7c44d2be689a9d890158c097e198314413d1ba3 100644 --- a/lite/core/mir/node.h +++ b/lite/core/mir/node.h @@ -53,6 +53,7 @@ class Node { const std::vector& valid_places, lite::Scope* scope = nullptr); + void ResetKernels(const std::vector& valid_places); std::string op_type() const { return op_info()->Type(); } const OpInfo* op_info() const; OpInfo* mutable_op_info(); diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc old mode 100644 new mode 100755 index 2b5b65ce5903ede41137311c585c0e87eaaa0e9d..0d4c642877f7beccfe37ebb92a5f6e7e508d37b0 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -140,10 +140,12 @@ void SSAGraph::Build(const Program &program, arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; } + /* if (var_types.count(name) && !arg_node->arg()->type) { arg_node->arg()->type = LiteType::GetTensorTy( TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); } + */ if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); DirectedLink(arg_node, op_node); @@ -153,10 +155,12 @@ void SSAGraph::Build(const Program &program, auto *arg_node = &node_storage_.back(); arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; + /* if (var_types.count(name) && !arg_node->arg()->type) { arg_node->arg()->type = LiteType::GetTensorTy( TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); } + */ if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); diff --git a/lite/core/mir/subgraph/subgraph_pass_test.cc b/lite/core/mir/subgraph/subgraph_pass_test.cc index a56c364f975fa6c3f82e1bbbb4489c93eb6ab724..252517939990d8ce48083badb342c22fae1459c6 100644 --- a/lite/core/mir/subgraph/subgraph_pass_test.cc +++ b/lite/core/mir/subgraph/subgraph_pass_test.cc @@ -157,7 +157,7 @@ std::shared_ptr TestModel( lite_api::LiteModelType::kNaiveBuffer); // Load optimized model lite_api::MobileConfig mobile_config; - mobile_config.set_model_dir(optimized_model_dir); + mobile_config.set_model_from_file(optimized_model_dir + ".nb"); mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); mobile_config.set_threads(1); predictor = lite_api::CreatePaddlePredictor(mobile_config); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index ae74bd8d4d5647139a13509dfda0bb2b41ecc5c7..85c22db45c6d3f8d6e00daf9cc74643ad308ba73 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ 
b/lite/core/mir/type_target_cast_pass.cc @@ -101,7 +101,6 @@ void TypeTargetTransformPass::AddIoCopyInst( auto io_copy_output_name = string_format("%s/target_trans", in->AsArg().name.c_str()); // string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); - if (copied_nodes->count(in->AsArg().name)) { // Remove the old link RemoveDirectedLink(in, inst_node); @@ -116,12 +115,14 @@ void TypeTargetTransformPass::AddIoCopyInst( } else { // TODO(MyPandaShaoxiang) should set same place with input? auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); - // Set the place for io_copy_output_arg node, the target should be equal to - // to.target() - // The precision and layout should be equal to from.precision(), - // from.layout() +// Set the place for io_copy_output_arg node, the target should be equal to +// to.target() +// The precision and layout should be equal to from.precision(), +// from.layout() +#ifndef LITE_WITH_FPGA io_copy_output_arg->AsArg().type = LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); +#endif auto* io_copy_inst = graph->NewInstructNode(); bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h old mode 100644 new mode 100755 index ddd94484ac4bb8d96d5c55300c985d21b44f1843..bebafb88a8bcacbdd639d523831c0a61031191e3 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -77,6 +77,7 @@ class Optimizer { #endif "static_kernel_pick_pass", // pick original kernel from graph "variable_place_inference_pass", // inference arg/var's + "kernel_place_correct_pass", // info(target/precision/layout/device) // using kernel info "argument_type_display_pass", // debug pass: show arg-type-node's @@ -108,7 +109,9 @@ class Optimizer { "runtime_context_assign_pass", "argument_type_display_pass", +#ifndef LITE_WITH_FPGA "memory_optimize_pass", +#endif "npu_subgraph_pass", "xpu_subgraph_pass"}}; RunPasses(passes_local); diff --git a/lite/core/program.cc b/lite/core/program.cc old mode 100644 new mode 100755 index 41d178f015d723aff739e608501e4619f8b10f5d..4f6ea2ce470724c0b00993478c47eb0315b5a1e5 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -137,11 +137,16 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { void RuntimeProgram::Run() { for (auto& inst : instructions_) { +#ifndef LITE_WITH_FPGA if (inst.is_feed_fetch_op()) continue; + std::string op_type = inst.op()->op_info()->Type(); +#endif inst.Run(); #ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE +#ifndef LITE_WITH_FPGA LITE_PRECISION_PROFILE(inst) +#endif #endif // LITE_WITH_PRECISION_PROFILE #endif // LITE_WITH_PROFILE } diff --git a/lite/core/tensor.h b/lite/core/tensor.h old mode 100644 new mode 100755 diff --git a/lite/core/version.h.in b/lite/core/version.h.in index 3082adc5abecb20f5ce19032177fc7cdb75299ff..d34c32073b852a50b5d26984ed4812ac4f38a870 100644 --- a/lite/core/version.h.in +++ b/lite/core/version.h.in @@ -42,7 +42,7 @@ static std::string version() { std::string tag = paddlelite_tag(); if (tag.empty()) { - ss << paddlelite_branch() << "(" << paddlelite_commit() << ")"; + ss << paddlelite_commit(); } else { ss << tag; } diff --git a/lite/demo/cxx/README.md b/lite/demo/cxx/README.md index 3217a7ed49006325715e22f8aa82d155bc8bf927..a0162ddfdc83a8245e8d0d8d8862f0413cac5d8e 100644 --- a/lite/demo/cxx/README.md +++ b/lite/demo/cxx/README.md @@ -1,8 +1,77 @@ # C++ Demo + +> 欢迎加入PaddleLite百度官方QQ群(696965088),会有专业同学解答您的疑问与困惑。 + 1. 
环境准备 - - 保证Android NDK在/opt目录下 + - 一台可以编译PaddleLite的电脑 - 一台armv7或armv8架构的安卓手机 -2. 编译并运行全量api的demo(注:当编译模式为tiny_pubish时将不存在该demo) + +2. 人脸识别和佩戴口罩判断的Demo + +参考[源码编译](https://paddlepaddle.github.io/Paddle-Lite/v2.2.0/source_compile/)准备编译环境。 + +执行下面命令,下载PaddleLite代码。 +```shell +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite +``` + +进入PaddleLite根目录,编译预测库。 +```shell +./lite/tools/build.sh \ + --arm_os=android \ + --arm_abi=armv8 \ + --arm_lang=gcc \ + --android_stl=c++_static \ + --build_extra=ON \ + --shutdown_log=OFF \ + tiny_publish +``` + +进入编译目录,下载模型和图片的压缩包,编译可执行文件。 +```shell +cd build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mask_detection +wget https://paddle-inference-dist.bj.bcebos.com/mask_detection.tar.gz +tar zxvf mask_detection.tar.gz +make +``` + +当然,大家也可以通过PaddleHub下载人脸检测模型和口罩佩戴判断模型。 +``` +# 下载paddlehub以后,通过python执行以下代码 +import paddlehub as hub +pyramidbox_lite_mask = hub.Module(name="pyramidbox_lite_mask") +# 将模型保存在test_program文件夹之中 +pyramidbox_lite_mask.processor.save_inference_model(dirname="test_program") +通过以上命令,可以获得人脸检测和口罩佩戴判断模型,分别存储在pyramidbox_lite和mask_detector之中。文件夹中的__model__是模型结构文件,__param__文件是权重文件。 +``` + +电脑连接安卓手机,将可执行文件、测试图片、模型文件、预测库push到安卓手机上。 +``` +adb push mask_detection /data/local/tmp/ +adb push test.jpg /data/local/tmp/ +adb push face_detection /data/local/tmp +adb push mask_classification /data/local/tmp +adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/ +adb shell chmod +x /data/local/tmp/mask_detection +``` + +进入安卓手机,执行demo。 +``` +adb shell +cd /data/local/tmp +export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH +./mask_detection face_detection mask_classification test.jpg +``` + +回到电脑端,将结果取出,查看如下效果图。 +``` +adb pull /data/local/tmp/test_mask_detection_result.jpg ./ +``` + +![test_mask_detection_result](https://user-images.githubusercontent.com/7383104/74279176-6200cd00-4d55-11ea-9fc0-83cfc2b3b37d.jpg) + +3. 编译并运行全量api的demo(注:当编译模式为tiny_pubish时将不存在该demo) ```shell cd inference_lite_lib.android.armv8/demo/cxx/mobile_full wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz @@ -17,7 +86,7 @@ adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && ``` 运行成功将在控制台输出预测结果的前10个类别的预测概率 -3. 编译并运行轻量级api的demo +4. 编译并运行轻量级api的demo ```shell cd ../mobile_light make @@ -29,7 +98,7 @@ adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && ``` 运行成功将在控制台输出预测结果的前10个类别的预测概率 -4. 编译并运行ssd目标检测的demo +5. 编译并运行ssd目标检测的demo ```shell cd ../ssd_detection wget https://paddle-inference-dist.bj.bcebos.com/mobilenetv1-ssd.tar.gz @@ -46,7 +115,7 @@ adb pull /data/local/tmp/test_ssd_detection_result.jpg ./ ``` 运行成功将在ssd_detection目录下看到生成的目标检测结果图像: test_ssd_detection_result.jpg -5. 编译并运行yolov3目标检测的demo +6. 编译并运行yolov3目标检测的demo ```shell cd ../yolov3_detection wget https://paddle-inference-dist.bj.bcebos.com/mobilenetv1-yolov3.tar.gz @@ -63,7 +132,7 @@ adb pull /data/local/tmp/test_yolov3_detection_result.jpg ./ ``` 运行成功将在yolov3_detection目录下看到生成的目标检测结果图像: test_yolov3_detection_result.jpg -6. 编译并运行物体分类的demo +7. 
编译并运行物体分类的demo ```shell cd ../mobile_classify wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz @@ -71,41 +140,41 @@ tar zxvf mobilenet_v1.tar.gz ./model_optimize_tool optimize model make -adb -s emulator-5554 push mobile_classify /data/local/tmp/ -adb -s emulator-5554 push test.jpg /data/local/tmp/ -adb -s emulator-5554 push labels.txt /data/local/tmp/ -adb -s emulator-5554 push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/ -adb -s emulator-5554 shell chmod +x /data/local/tmp/mobile_classify -adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && +adb push mobile_classify /data/local/tmp/ +adb push test.jpg /data/local/tmp/ +adb push labels.txt /data/local/tmp/ +adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/ +adb shell chmod +x /data/local/tmp/mobile_classify +adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && /data/local/tmp/mobile_classify /data/local/tmp/mobilenetv1opt2 /data/local/tmp/test.jpg /data/local/tmp/labels.txt" ``` 运行成功将在控制台输出预测结果的前5个类别的预测概率 - 如若想看前10个类别的预测概率,在运行命令输入topk的值即可 eg: ```shell - adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && + adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && /data/local/tmp/mobile_classify /data/local/tmp/mobilenetv1opt2/ /data/local/tmp/test.jpg /data/local/tmp/labels.txt 10" ``` - 如若想看其他模型的分类结果, 在运行命令输入model_dir 及其model的输入大小即可 eg: ```shell - adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && + adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && /data/local/tmp/mobile_classify /data/local/tmp/mobilenetv2opt2/ /data/local/tmp/test.jpg /data/local/tmp/labels.txt 10 224 224" ``` -9. 编译含CV预处理库模型单测demo +8. 
编译含CV预处理库模型单测demo ```shell cd ../test_cv wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz tar zxvf mobilenet_v1.tar.gz ./model_optimize_tool optimize model make -adb -s emulator-5554 push test_model_cv /data/local/tmp/ -adb -s emulator-5554 push test.jpg /data/local/tmp/ -adb -s emulator-5554 push labels.txt /data/local/tmp/ -adb -s emulator-5554 push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/ -adb -s emulator-5554 shell chmod +x /data/local/tmp/test_model_cv -adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && +adb push test_model_cv /data/local/tmp/ +adb push test.jpg /data/local/tmp/ +adb push labels.txt /data/local/tmp/ +adb push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/ +adb shell chmod +x /data/local/tmp/test_model_cv +adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && /data/local/tmp/test_model_cv /data/local/tmp/mobilenetv1opt2 /data/local/tmp/test.jpg /data/local/tmp/labels.txt" ``` 运行成功将在控制台输出预测结果的前10个类别的预测概率 diff --git a/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7 new file mode 100644 index 0000000000000000000000000000000000000000..dd6d4b0960160e140e2f051b78814d2fee08d5e0 --- /dev/null +++ b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7 @@ -0,0 +1,61 @@ +ARM_ABI = arm7 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mask_detection: fetch_opencv mask_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mask_detection.o -o mask_detection $(CXX_LIBS) $(LDFLAGS) + +mask_detection.o: mask_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mask_detection.o -c mask_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mask_detection.o + rm -f mask_detection diff --git a/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..c2f601ed2f68c342b47c5add451f84c537f978de --- /dev/null +++ b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 @@ -0,0 +1,61 @@ +ARM_ABI = arm8 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mask_detection: fetch_opencv mask_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mask_detection.o -o mask_detection $(CXX_LIBS) $(LDFLAGS) + +mask_detection.o: mask_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mask_detection.o -c mask_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mask_detection.o + rm -f mask_detection diff --git a/lite/demo/cxx/mask_detection/mask_detection.cc b/lite/demo/cxx/mask_detection/mask_detection.cc new file mode 100644 index 0000000000000000000000000000000000000000..748b84365fc70aa59171a6bf8847f554308fdc8c --- /dev/null +++ b/lite/demo/cxx/mask_detection/mask_detection.cc @@ -0,0 +1,246 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" // NOLINT + +using namespace paddle::lite_api; // NOLINT + +struct Object { + int batch_id; + cv::Rect rec; + int class_id; + float prob; +}; + +int64_t ShapeProduction(const shape_t& shape) { + int64_t res = 1; + for (auto i : shape) res *= i; + return res; +} + +// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up +void neon_mean_scale(const float* din, + float* dout, + int size, + const std::vector mean, + const std::vector scale) { + if (mean.size() != 3 || scale.size() != 3) { + std::cerr << "[ERROR] mean or scale size must equal to 3\n"; + exit(1); + } + float32x4_t vmean0 = vdupq_n_f32(mean[0]); + float32x4_t vmean1 = vdupq_n_f32(mean[1]); + float32x4_t vmean2 = vdupq_n_f32(mean[2]); + float32x4_t vscale0 = vdupq_n_f32(scale[0]); + float32x4_t vscale1 = vdupq_n_f32(scale[1]); + float32x4_t vscale2 = vdupq_n_f32(scale[2]); + + float* dout_c0 = dout; + float* dout_c1 = dout + size; + float* dout_c2 = dout + size * 2; + + int i = 0; + for (; i < size - 3; i += 4) { + float32x4x3_t vin3 = vld3q_f32(din); + float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); + float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); + float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); + float32x4_t vs0 = vmulq_f32(vsub0, vscale0); + float32x4_t vs1 = vmulq_f32(vsub1, vscale1); + float32x4_t vs2 = vmulq_f32(vsub2, vscale2); + vst1q_f32(dout_c0, vs0); + vst1q_f32(dout_c1, vs1); + vst1q_f32(dout_c2, vs2); + + din += 12; + dout_c0 += 4; + dout_c1 += 4; + dout_c2 += 4; + } + for (; i < size; i++) { + *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; + } +} + +void pre_process(const cv::Mat& img, + int width, + int height, + const std::vector& mean, + const std::vector& scale, + float* data, + bool is_scale = false) { + cv::Mat resized_img; + cv::resize( + img, resized_img, cv::Size(width, height), 0.f, 0.f, cv::INTER_CUBIC); + cv::Mat imgf; + float scale_factor = is_scale ? 
1.f / 256 : 1.f; + resized_img.convertTo(imgf, CV_32FC3, scale_factor); + const float* dimg = reinterpret_cast(imgf.data); + neon_mean_scale(dimg, data, width * height, mean, scale); +} + +void RunModel(std::string det_model_dir, + std::string class_model_dir, + std::string img_path) { + // Prepare + cv::Mat img = imread(img_path, cv::IMREAD_COLOR); + float shrink = 0.2; + int width = img.cols; + int height = img.rows; + int s_width = static_cast(width * shrink); + int s_height = static_cast(height * shrink); + + // Detection + MobileConfig config; + config.set_model_dir(det_model_dir); + + // Create Predictor For Detction Model + std::shared_ptr predictor = + CreatePaddlePredictor(config); + + // Get Input Tensor + std::unique_ptr input_tensor0(std::move(predictor->GetInput(0))); + input_tensor0->Resize({1, 3, s_height, s_width}); + auto* data = input_tensor0->mutable_data(); + + // Do PreProcess + std::vector detect_mean = {104.f, 117.f, 123.f}; + std::vector detect_scale = {0.007843, 0.007843, 0.007843}; + pre_process(img, s_width, s_height, detect_mean, detect_scale, data, false); + + // Detection Model Run + predictor->Run(); + + // Get Output Tensor + std::unique_ptr output_tensor0( + std::move(predictor->GetOutput(0))); + auto* outptr = output_tensor0->data(); + auto shape_out = output_tensor0->shape(); + int64_t out_len = ShapeProduction(shape_out); + + // Filter Out Detection Box + float detect_threshold = 0.3; + std::vector detect_result; + for (int i = 0; i < out_len / 6; ++i) { + if (outptr[1] >= detect_threshold) { + Object obj; + int xmin = static_cast(width * outptr[2]); + int ymin = static_cast(height * outptr[3]); + int xmax = static_cast(width * outptr[4]); + int ymax = static_cast(height * outptr[5]); + int w = xmax - xmin; + int h = ymax - ymin; + cv::Rect rec_clip = + cv::Rect(xmin, ymin, w, h) & cv::Rect(0, 0, width, height); + obj.rec = rec_clip; + detect_result.push_back(obj); + } + outptr += 6; + } + + // Classification + config.set_model_dir(class_model_dir); + + // Create Predictor For Classification Model + predictor = CreatePaddlePredictor(config); + + // Get Input Tensor + std::unique_ptr input_tensor1(std::move(predictor->GetInput(0))); + int classify_w = 128; + int classify_h = 128; + input_tensor1->Resize({1, 3, classify_h, classify_w}); + auto* input_data = input_tensor1->mutable_data(); + int detect_num = detect_result.size(); + std::vector classify_mean = {0.5f, 0.5f, 0.5f}; + std::vector classify_scale = {1.f, 1.f, 1.f}; + float classify_threshold = 0.5; + for (int i = 0; i < detect_num; ++i) { + cv::Rect rec_clip = detect_result[i].rec; + cv::Mat roi = img(rec_clip); + + // Do PreProcess + pre_process(roi, + classify_w, + classify_h, + classify_mean, + classify_scale, + input_data, + true); + + // Classification Model Run + predictor->Run(); + + // Get Output Tensor + std::unique_ptr output_tensor1( + std::move(predictor->GetOutput(1))); + auto* outptr = output_tensor1->data(); + + // Draw Detection and Classification Results + cv::rectangle(img, rec_clip, cv::Scalar(0, 0, 255), 2, cv::LINE_AA); + std::string text = outptr[1] > classify_threshold ? 
"wear mask" : "no mask"; + int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; + double font_scale = 1.f; + int thickness = 1; + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); + float new_font_scale = rec_clip.width * 0.7 * font_scale / text_size.width; + text_size = + cv::getTextSize(text, font_face, new_font_scale, thickness, nullptr); + cv::Point origin; + origin.x = rec_clip.x + 5; + origin.y = rec_clip.y + text_size.height + 5; + cv::putText(img, + text, + origin, + font_face, + new_font_scale, + cv::Scalar(0, 255, 255), + thickness, + cv::LINE_AA); + + std::cout << "detect face, location: x=" << rec_clip.x + << ", y=" << rec_clip.y << ", width=" << rec_clip.width + << ", height=" << rec_clip.height + << ", wear mask: " << (outptr[1] > classify_threshold) + << std::endl; + } + + // Write Result to Image File + int start = img_path.find_last_of("/"); + int end = img_path.find_last_of("."); + std::string img_name = img_path.substr(start + 1, end - start - 1); + std::string result_name = img_name + "_mask_detection_result.jpg"; + cv::imwrite(result_name, img); +} + +int main(int argc, char** argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] + << " detction_model_dir classification_model_dir image_path\n"; + exit(1); + } + std::string detect_model_dir = argv[1]; + std::string classify_model_dir = argv[2]; + std::string img_path = argv[3]; + RunModel(detect_model_dir, classify_model_dir, img_path); + return 0; +} diff --git a/lite/demo/cxx/ssd_detection/ssd_detection.cc b/lite/demo/cxx/ssd_detection/ssd_detection.cc index 011733eb87f551141c52ab8e23d9625c93c742fc..2408afcbf64a24924eca119a9d9481dc030250c9 100644 --- a/lite/demo/cxx/ssd_detection/ssd_detection.cc +++ b/lite/demo/cxx/ssd_detection/ssd_detection.cc @@ -82,8 +82,8 @@ void neon_mean_scale(const float* din, } for (; i < size; i++) { *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; - *(dout_c0++) = (*(din++) - mean[1]) * scale[1]; - *(dout_c0++) = (*(din++) - mean[2]) * scale[2]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; } } @@ -188,13 +188,12 @@ void RunModel(std::string model_dir, std::string img_path) { std::move(predictor->GetOutput(0))); auto* outptr = output_tensor->data(); auto shape_out = output_tensor->shape(); - int64_t cnt = 1; - for (auto& i : shape_out) { - cnt *= i; - } + int64_t cnt = ShapeProduction(shape_out); auto rec_out = detect_object(outptr, static_cast(cnt / 6), 0.6f, img); - std::string result_name = - img_path.substr(0, img_path.find(".")) + "_ssd_detection_result.jpg"; + int start = img_path.find_last_of("/"); + int end = img_path.find_last_of("."); + std::string img_name = img_path.substr(start + 1, end - start - 1); + std::string result_name = img_name + "_ssd_detection_result.jpg"; cv::imwrite(result_name, img); } diff --git a/lite/gen_code/paddle_infer.h b/lite/gen_code/paddle_infer.h old mode 100644 new mode 100755 index e01ffc25e29ca94166e8fe12b0643ae9e914001d..dc2d56422cd710778a36c5e85f42e701fbfcbf0f --- a/lite/gen_code/paddle_infer.h +++ b/lite/gen_code/paddle_infer.h @@ -46,7 +46,6 @@ class Tensor { */ class PaddlePredictor { public: - void Init(); std::unique_ptr GetTensor(const std::string &id) const; std::unique_ptr GetMutableTensor(const std::string &id); diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc old mode 100644 new mode 100755 index 266ae1fc916af4303aca274c39b9b4923fdbb154..0b92317ac51b0af24443ec24436f6a483198dbbc --- 
a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -62,6 +62,10 @@ void CastCompute::Run() { int32_t* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); + } else if (param.in_dtype == 3 && param.out_dtype == 5) { + const auto* x_data = param.X->data(); + auto* o_data = param.Out->mutable_data(); + memcpy(o_data, x_data, sizeof(float) * param.X->numel()); } else { LOG(FATAL) << "other has not been implemented"; } diff --git a/lite/kernels/arm/fc_compute.h b/lite/kernels/arm/fc_compute.h index 2e5f2345e824b13d78a1575d3374652b8474c7fd..4f8a82a8689c1f221ee146176ff7074602cad1c9 100644 --- a/lite/kernels/arm/fc_compute.h +++ b/lite/kernels/arm/fc_compute.h @@ -95,7 +95,7 @@ class FcCompute : public KernelLite { CHECK_GE(x_dims.size(), 2UL); CHECK_EQ(w_dims.size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); + CHECK_GE(param.output->dims().size(), 2UL); m_ = x_dims.Slice(0, param.in_num_col_dims).production(); k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); diff --git a/lite/kernels/arm/fill_constant_compute.cc b/lite/kernels/arm/fill_constant_compute.cc index ad475538576b9cc73a43bac49cba1a6cf1c73edb..f265a3284bbff6b69f2861ef0cb00ac6a6d9012e 100644 --- a/lite/kernels/arm/fill_constant_compute.cc +++ b/lite/kernels/arm/fill_constant_compute.cc @@ -60,25 +60,10 @@ class FillConstantCompute : public KernelLite { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - if (param.dtype == static_cast(lite::core::FluidType::FP32)) { - auto data = param.Out->template mutable_data(); - for (int i = 0; i < param.Out->numel(); i++) { - data[i] = param.value; - } - } else if (param.dtype == - static_cast(lite::core::FluidType::INT32)) { - auto data = param.Out->template mutable_data(); - for (int i = 0; i < param.Out->numel(); i++) { - data[i] = param.value; - } - } else if (param.dtype == - static_cast(lite::core::FluidType::INT8)) { - auto data = param.Out->template mutable_data(); - for (int i = 0; i < param.Out->numel(); i++) { - data[i] = param.value; - } - } else { - LOG(FATAL) << "not supported dtype " << param.dtype; + // auto data = param.Out->template mutable_data(); + auto data = param.Out->template mutable_data(); + for (int i = 0; i < param.Out->numel(); i++) { + data[i] = param.value; } } @@ -94,32 +79,38 @@ class FillConstantBatchLikeCompute auto& param = *param_.get_mutable(); auto& context = ctx_->As(); - if (param.input->lod().size() && param.input_dim_idx == 0) { - auto odims = param.out->dims(); - odims[param.output_dim_idx] = param.input->lod().back().size() - 1; - param.out->Resize(odims); + // auto data = param.out->template mutable_data(); + auto data = param.out->template mutable_data(); + for (int i = 0; i < param.out->numel(); i++) { + data[i] = param.value; } - if (param.dtype == static_cast(lite::core::FluidType::FP32)) { - auto data = param.out->template mutable_data(); - for (int i = 0; i < param.out->numel(); i++) { - data[i] = param.value; - } - } else if (param.dtype == - static_cast(lite::core::FluidType::INT32)) { - auto data = param.out->template mutable_data(); - for (int i = 0; i < param.out->numel(); i++) { - data[i] = param.value; - } - } else if (param.dtype == - static_cast(lite::core::FluidType::INT8)) { - auto data = param.out->template mutable_data(); - for (int i = 0; i < param.out->numel(); i++) { - data[i] = param.value; - } - } else { - LOG(FATAL) << "not supported dtype " << param.dtype; - } + // if (param.input->lod().size() && 
param.input_dim_idx == 0) { + // auto odims = param.out->dims(); + // odims[param.output_dim_idx] = param.input->lod().back().size() - 1; + // param.out->Resize(odims); + // } + + // if (param.dtype == static_cast(lite::core::FluidType::FP32)) { + // auto data = param.out->template mutable_data(); + // for (int i = 0; i < param.out->numel(); i++) { + // data[i] = param.value; + // } + // } else if (param.dtype == + // static_cast(lite::core::FluidType::INT32)) { + // auto data = param.out->template mutable_data(); + // for (int i = 0; i < param.out->numel(); i++) { + // data[i] = param.value; + // } + // } else if (param.dtype == + // static_cast(lite::core::FluidType::INT8)) { + // auto data = param.out->template mutable_data(); + // for (int i = 0; i < param.out->numel(); i++) { + // data[i] = param.value; + // } + // } else { + // LOG(FATAL) << "not supported dtype " << param.dtype; + // } } virtual ~FillConstantBatchLikeCompute() = default; @@ -142,8 +133,9 @@ REGISTER_LITE_KERNEL(fill_constant, {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("ShapeTensorList", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); + REGISTER_LITE_KERNEL(fill_constant_batch_size_like, kARM, kAny, diff --git a/lite/kernels/arm/lookup_table_compute.cc b/lite/kernels/arm/lookup_table_compute.cc index af9426f3f4a7d9dd0d1260143b7b3e8aea15a034..5af21af78fbbbe0425cad63e3047c330b79129b5 100644 --- a/lite/kernels/arm/lookup_table_compute.cc +++ b/lite/kernels/arm/lookup_table_compute.cc @@ -36,7 +36,7 @@ void LookupTableCompute::Run() { auto table_dim = w->dims(); int64_t ids_numel = ids->numel(); - auto ids_data = ids->data(); + auto ids_data = ids->data(); int64_t row_number = table_dim[0]; int64_t row_width = table_dim[1]; @@ -75,7 +75,6 @@ REGISTER_LITE_KERNEL(lookup_table, .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); - REGISTER_LITE_KERNEL(lookup_table_v2, kARM, kFloat, diff --git a/lite/kernels/bm/bridges/CMakeLists.txt b/lite/kernels/bm/bridges/CMakeLists.txt index f9d1f8feea580a9eb21f7d7c11f604f76de98400..688e307a6475073461415da2ca1c8f2cc6c88aac 100644 --- a/lite/kernels/bm/bridges/CMakeLists.txt +++ b/lite/kernels/bm/bridges/CMakeLists.txt @@ -15,7 +15,12 @@ lite_cc_library(subgraph_bridge_softmax_op_bm SRCS softmax_op.cc DEPS ${subgraph lite_cc_library(subgraph_bridge_mul_op_bm SRCS mul_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_batch_norm_op_bm SRCS batch_norm_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_scale_op_bm SRCS scale_op.cc DEPS ${bm_subgraph_bridge_deps}) - +lite_cc_library(subgraph_bridge_concat_op_bm SRCS concat_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_dropout_op_bm SRCS dropout_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_transpose_op_bm SRCS transpose_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_reshape_op_bm SRCS reshape_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_norm_op_bm SRCS norm_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_prior_box_op_bm SRCS prior_box_op.cc DEPS ${bm_subgraph_bridge_deps}) set(bm_subgraph_bridges subgraph_bridge_registry subgraph_bridge_engine @@ -28,4 +33,10 @@ set(bm_subgraph_bridges 
        subgraph_bridge_mul_op_bm
        subgraph_bridge_batch_norm_op_bm
        subgraph_bridge_scale_op_bm
+       subgraph_bridge_concat_op_bm
+       subgraph_bridge_dropout_op_bm
+       subgraph_bridge_transpose_op_bm
+       subgraph_bridge_reshape_op_bm
+       subgraph_bridge_norm_op_bm
+       subgraph_bridge_prior_box_op_bm
        CACHE INTERNAL "bm_subgraph_bridges")
diff --git a/lite/kernels/bm/bridges/act_op.cc b/lite/kernels/bm/bridges/act_op.cc
index 905bce25066634600099ab86c3eee4254d551dc2..0d3c4e0b83598358958ae670e554949deb7d1926 100644
--- a/lite/kernels/bm/bridges/act_op.cc
+++ b/lite/kernels/bm/bridges/act_op.cc
@@ -45,7 +45,14 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   for (size_t i = 0; i < output_dims.size(); i++) {
     i_output_shape_data[i] = static_cast<int32_t>(output_shape_data[i]);
   }
-  CHECK_EQ(op_type, "relu");
+  float alpha = 0.f;
+  if (op_type == "relu") {
+  } else if (op_type == "leaky_relu") {
+    alpha = op_info->GetAttr<float>("alpha");
+  } else {
+    LOG(FATAL) << "[BM] unsupported act type";
+    return FAILED;
+  }
   add_relu_layer(graph->GetCompilerHandle(),
                  const_cast<const int*>(&i_x_shape_data[0]),
                  x_dims.size(),
@@ -53,7 +60,7 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                  const_cast<const int*>(&i_output_shape_data[0]),
                  output_dims.size(),
                  static_cast<const char*>(output_var_name.c_str()),
-                 0.f,
+                 alpha,
                  -1.f);
   graph->AddNode(output_var_name);
   return SUCCESS;
@@ -65,3 +72,6 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 }  // namespace paddle
 REGISTER_SUBGRAPH_BRIDGE(relu, kBM, paddle::lite::subgraph::bm::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
+                         kBM,
+                         paddle::lite::subgraph::bm::ActConverter);
diff --git a/lite/kernels/bm/bridges/concat_op.cc b/lite/kernels/bm/bridges/concat_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0b568aa4d161b5af8d17a83cdedddc446fcd8237
--- /dev/null
+++ b/lite/kernels/bm/bridges/concat_op.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + // input + auto x_names = op_info->Input("X"); + auto x_type = kernel->GetInputDeclType("X"); + CHECK(x_type->layout() == DATALAYOUT(kNCHW)); + // output + auto output_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + const int64_t* output_shape_data = + const_cast(&output_dims.data()[0]); + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_shape_data[i]); + } + const int32_t input_num = x_names.size(); + int32_t** shape = new int32_t*[input_num]; + int32_t* dim = new int32_t[input_num]; + const char** name = new const char*[input_num]; + for (size_t i = 0; i < x_names.size(); i++) { + auto x = scope->FindMutableTensor(x_names[i]); + name[i] = x_names[i].c_str(); + auto x_dims = x->dims(); + dim[i] = x_dims.size(); + const int64_t* x_shape_data = const_cast(&x_dims.data()[0]); + shape[i] = new int32_t[x_dims.size()]; + for (size_t j = 0; j < x_dims.size(); j++) { + shape[i][j] = static_cast(x_shape_data[j]); + } + } + + auto axis = op_info->GetAttr("axis"); + add_concat_layer(graph->GetCompilerHandle(), + input_num, + shape, + dim, + name, + const_cast(&i_output_shape_data[0]), + output_dims.size(), + static_cast(output_var_name.c_str()), + axis); + for (size_t i = 0; i < x_names.size(); i++) { + delete[] shape[i]; + } + delete[] shape; + delete[] name; + delete[] dim; + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle +REGISTER_SUBGRAPH_BRIDGE(concat, + kBM, + paddle::lite::subgraph::bm::ConcatConverter); diff --git a/lite/kernels/bm/bridges/conv_op.cc b/lite/kernels/bm/bridges/conv_op.cc index ab48ade68f42dd0b75f58dbee1c369f4868b69d4..ffe5a59aca8124a0f7999a71b35947d11e37b4fe 100644 --- a/lite/kernels/bm/bridges/conv_op.cc +++ b/lite/kernels/bm/bridges/conv_op.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/operators/conv_op.h" #include #include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/bm/bridges/utility.h" @@ -58,10 +57,10 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { std::vector i_output_shape_data(output_dims.size()); for (size_t i = 0; i < input_dims.size(); i++) { - i_input_shape_data[i] = static_cast(input_shape_data[i]); + i_input_shape_data[i] = static_cast(input_shape_data[i]); } for (size_t i = 0; i < output_dims.size(); i++) { - i_output_shape_data[i] = static_cast(output_shape_data[i]); + i_output_shape_data[i] = static_cast(output_shape_data[i]); } const float* filter_data = const_cast(filter->mutable_data()); @@ -69,7 +68,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto paddings = op_info->GetAttr>("paddings"); auto strides = op_info->GetAttr>("strides"); auto dilations = op_info->GetAttr>("dilations"); - add_conv_layer(graph->GetCompilerHandle(), const_cast(&i_input_shape_data[0]), input_dims.size(), @@ -104,3 +102,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { REGISTER_SUBGRAPH_BRIDGE(conv2d, kBM, paddle::lite::subgraph::bm::ConvConverter); +REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d, + kBM, + paddle::lite::subgraph::bm::ConvConverter); diff --git a/lite/kernels/bm/bridges/dropout_op.cc b/lite/kernels/bm/bridges/dropout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3364e866a3525c225916179152669d6456a42efc --- /dev/null +++ b/lite/kernels/bm/bridges/dropout_op.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + // input + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + const int64_t* x_shape_data = const_cast(&x_dims.data()[0]); + std::vector i_x_shape_data(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + i_x_shape_data[i] = static_cast(x_shape_data[i]); + } + // output + auto output_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + const int64_t* output_shape_data = + const_cast(&output_dims.data()[0]); + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_shape_data[i]); + } + auto dropout_prob = op_info->GetAttr("dropout_prob"); + auto dropout_implementation = + op_info->GetAttr("dropout_implementation"); + CHECK_EQ(dropout_implementation, "downgrade_in_infer"); + add_const_binary_layer(graph->GetCompilerHandle(), + static_cast(x_var_name.c_str()), + const_cast(&i_x_shape_data[0]), + x_dims.size(), + 1.f - dropout_prob, + static_cast(output_var_name.c_str()), + BINARY_MUL, + 0); + + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle +REGISTER_SUBGRAPH_BRIDGE(dropout, + kBM, + paddle::lite::subgraph::bm::DropoutConverter); diff --git a/lite/kernels/bm/bridges/elementwise_ops.cc b/lite/kernels/bm/bridges/elementwise_ops.cc index 7d158110ee519febf3761ee5662ef7c27de2ca4d..2fdbfd8c3f74879a52f5d3a8057953ab800887ef 100644 --- a/lite/kernels/bm/bridges/elementwise_ops.cc +++ b/lite/kernels/bm/bridges/elementwise_ops.cc @@ -14,6 +14,7 @@ #include #include #include +#include #include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/bm/bridges/utility.h" #include "lite/kernels/npu/bridges/registry.h" @@ -68,42 +69,52 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { for (size_t i = 0; i < output_dims.size(); i++) { i_output_shape_data[i] = static_cast(output_shape_data[i]); } - if (y_is_const) { - CHECK_EQ(op_type, "elementwise_add"); - } + auto axis = op_info->GetAttr("axis"); int op_code{-1}; + int eltwise_if_code{-1}; float coeff[2] = {1.f, 1.f}; if (op_type == "elementwise_mul") { - op_code = 0; + op_code = BINARY_MUL; + eltwise_if_code = 0; } else if (op_type == "elementwise_add") { - op_code = 1; + op_code = BINARY_ADD; + eltwise_if_code = 1; } else if (op_type == "elementwise_sub") { - op_code = 1; + op_code = BINARY_SUB; + eltwise_if_code = 1; coeff[1] = -1.f; } else { LOG(FATAL) << "UNSUPPORTED ELTWISE OPERATION: " << op_type; } - if (!y_is_const) { - add_eltwise_layer(graph->GetCompilerHandle(), - input_num, - shape, - dim, - name, - const_cast(&i_output_shape_data[0]), - output_dims.size(), - static_cast(output_var_name.c_str()), - op_code, - coeff); - } else { - const float* y_data = const_cast(y->mutable_data()); - const float* x_data = const_cast(x->mutable_data()); - 
bm_add_const_tensor(graph->GetCompilerHandle(), - name[1], - shape[0], - dim[0], - static_cast(DTYPE_FP32), - static_cast(y_data)); - + const float* y_data = const_cast(y->mutable_data()); + const float* x_data = const_cast(x->mutable_data()); + auto unique_op_name = lite::subgraph::bm::UniqueName("expand_ndims"); + std::vector i_expand_shape_data(3); + if (y_is_const) { + if (dim[0] == dim[1] || 2 == dim[0]) { + bm_add_const_tensor(graph->GetCompilerHandle(), + name[1], + shape[1], + dim[1], + static_cast(DTYPE_FP32), + static_cast(y_data)); + } else if (1 == dim[1] && 1 == axis) { + add_expand_ndims_layer(graph->GetCompilerHandle(), + name[1], + shape[1], + dim[1], + static_cast(y_data), + -1, + 2, + static_cast(unique_op_name.c_str())); + name[1] = static_cast(unique_op_name.c_str()); + dim[1] = 3; + i_expand_shape_data[0] = i_y_shape_data[0]; + i_expand_shape_data[1] = 1; + i_expand_shape_data[2] = 1; + shape[1] = &i_expand_shape_data[0]; + y_data = nullptr; + } add_binary_layer_v2(graph->GetCompilerHandle(), name[0], shape[0], @@ -111,12 +122,23 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { 0, static_cast(x_data), name[1], - shape[0], - dim[0], + shape[1], + dim[1], 0, static_cast(y_data), static_cast(output_var_name.c_str()), - 0); + op_code); + } else { + add_eltwise_layer(graph->GetCompilerHandle(), + input_num, + shape, + dim, + name, + const_cast(&i_output_shape_data[0]), + output_dims.size(), + static_cast(output_var_name.c_str()), + eltwise_if_code, + coeff); } delete[] shape; delete[] name; @@ -133,3 +155,9 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { REGISTER_SUBGRAPH_BRIDGE(elementwise_add, kBM, paddle::lite::subgraph::bm::ElementwiseConverter); +REGISTER_SUBGRAPH_BRIDGE(elementwise_mul, + kBM, + paddle::lite::subgraph::bm::ElementwiseConverter); +REGISTER_SUBGRAPH_BRIDGE(elementwise_sub, + kBM, + paddle::lite::subgraph::bm::ElementwiseConverter); diff --git a/lite/kernels/bm/bridges/mul_op.cc b/lite/kernels/bm/bridges/mul_op.cc index add4c89d2b967e8a817ddca8b86bc95ed039f7ab..06ec177bceb883758c42d45c9b07006a83b3c9f6 100644 --- a/lite/kernels/bm/bridges/mul_op.cc +++ b/lite/kernels/bm/bridges/mul_op.cc @@ -41,8 +41,10 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { } // add reshape layer int i_x_reshape_shape_data[2]; - for (size_t i = 0; i < 2; i++) { - i_x_reshape_shape_data[i] = static_cast(x_shape_data[i]); + i_x_reshape_shape_data[0] = static_cast(x_shape_data[0]); + i_x_reshape_shape_data[1] = 1; + for (size_t i = 1; i < x_dims.size(); i++) { + i_x_reshape_shape_data[1] *= static_cast(x_shape_data[i]); } int reshape_param[] = {0, -1}; auto unique_op_reshape_name = diff --git a/lite/kernels/bm/bridges/norm_op.cc b/lite/kernels/bm/bridges/norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..69b2ac130dcd3252ffcc97ac24a3b5d65afff6f0 --- /dev/null +++ b/lite/kernels/bm/bridges/norm_op.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int NormConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_name = lite::subgraph::bm::UniqueName(op_type); + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + auto output_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + const int64_t* x_shape_data = const_cast(&x_dims.data()[0]); + const int64_t* output_shape_data = + const_cast(&output_dims.data()[0]); + std::vector i_x_shape_data(x_dims.size()); + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + i_x_shape_data[i] = static_cast(x_shape_data[i]); + } + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_shape_data[i]); + } + + float one = 1.f; + auto epsilon = op_info->GetAttr("epsilon"); + add_normalize_layer(graph->GetCompilerHandle(), + const_cast(&i_x_shape_data[0]), + x_dims.size(), + static_cast(x_var_name.c_str()), + const_cast(&i_output_shape_data[0]), + output_dims.size(), + static_cast(output_var_name.c_str()), + static_cast(unique_op_name.c_str()), + 0, + 1, + &one, + epsilon); + + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(norm, kBM, paddle::lite::subgraph::bm::NormConverter); diff --git a/lite/kernels/bm/bridges/paddle_use_bridges.h b/lite/kernels/bm/bridges/paddle_use_bridges.h index 417d016c78d4d9d0464e52827c01bdc90afe484f..fdaf70de6a4777ae016326a22721c845a79b7d93 100644 --- a/lite/kernels/bm/bridges/paddle_use_bridges.h +++ b/lite/kernels/bm/bridges/paddle_use_bridges.h @@ -15,10 +15,24 @@ #pragma once USE_SUBGRAPH_BRIDGE(relu, kBM); +USE_SUBGRAPH_BRIDGE(leaky_relu, kBM); USE_SUBGRAPH_BRIDGE(conv2d, kBM); +USE_SUBGRAPH_BRIDGE(depthwise_conv2d, kBM); USE_SUBGRAPH_BRIDGE(elementwise_add, kBM); +USE_SUBGRAPH_BRIDGE(elementwise_mul, kBM); +USE_SUBGRAPH_BRIDGE(elementwise_sub, kBM); USE_SUBGRAPH_BRIDGE(pool2d, kBM); USE_SUBGRAPH_BRIDGE(softmax, kBM); USE_SUBGRAPH_BRIDGE(mul, kBM); USE_SUBGRAPH_BRIDGE(batch_norm, kBM); USE_SUBGRAPH_BRIDGE(scale, kBM); +USE_SUBGRAPH_BRIDGE(concat, kBM); +USE_SUBGRAPH_BRIDGE(dropout, kBM); +USE_SUBGRAPH_BRIDGE(transpose, kBM); +USE_SUBGRAPH_BRIDGE(transpose2, kBM); +USE_SUBGRAPH_BRIDGE(reshape, kBM); +USE_SUBGRAPH_BRIDGE(reshape2, kBM); +USE_SUBGRAPH_BRIDGE(flatten, kBM); +USE_SUBGRAPH_BRIDGE(flatten2, kBM); +USE_SUBGRAPH_BRIDGE(norm, kBM); +USE_SUBGRAPH_BRIDGE(prior_box, kBM); diff --git a/lite/kernels/bm/bridges/pool_op.cc b/lite/kernels/bm/bridges/pool_op.cc index 8b0c0cfffbde847571805a89ef94a0e80e31fa9b..cd48db5b726d1dcb3b65e4c3a70141a09d452bdc 100644 --- a/lite/kernels/bm/bridges/pool_op.cc +++ b/lite/kernels/bm/bridges/pool_op.cc @@ -65,6 +65,12 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { if (pooling_type == "avg") { average_exclusive = 
op_info->GetAttr("exclusive"); } + if (global_pooling) { + paddings[0] = 0; + paddings[1] = 0; + ksize[0] = i_x_shape_data[2]; + ksize[1] = i_x_shape_data[3]; + } add_pooling_layer( graph->GetCompilerHandle(), const_cast(&i_x_shape_data[0]), diff --git a/lite/kernels/bm/bridges/prior_box_op.cc b/lite/kernels/bm/bridges/prior_box_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..17c3fbf03473a480e0bb736241e6095055999098 --- /dev/null +++ b/lite/kernels/bm/bridges/prior_box_op.cc @@ -0,0 +1,340 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +typedef struct __tag_st_priorbox_param { + std::vector min_sizes; + std::vector max_sizes; + std::vector aspect_ratios; + std::vector variances; + float step_w; + float step_h; + float offset; + int32_t img_w; + int32_t img_h; + int32_t prior_num; + bool min_max_aspect_ratios_order; + bool clip; + bool flip; +} st_priorbox_param; + +inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, + bool flip, + std::vector* output_aspect_ratior) { + constexpr float epsilon = 1e-6; + output_aspect_ratior->clear(); + output_aspect_ratior->push_back(1.0f); + for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { + float ar = input_aspect_ratior[i]; + bool already_exist = false; + for (size_t j = 0; j < output_aspect_ratior->size(); ++j) { + if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) { + already_exist = true; + break; + } + } + if (!already_exist) { + output_aspect_ratior->push_back(ar); + if (flip) { + output_aspect_ratior->push_back(1.0f / ar); + } + } + } +} + +float* compute_priorbox_kernel(OpLite* op, st_priorbox_param* param) { + auto op_info = op->op_info(); + auto scope = op->scope(); + // inputs + auto in_var_name = op_info->Input("Input").front(); + auto in = scope->FindVar(in_var_name)->GetMutable(); + auto in_dims = in->dims(); + auto img_var_name = op_info->Input("Image").front(); + auto img = scope->FindVar(img_var_name)->GetMutable(); + auto img_dims = img->dims(); + // outputs + auto boxes_var_name = op_info->Output("Boxes").front(); + auto boxes = scope->FindVar(boxes_var_name)->GetMutable(); + auto var_var_name = op_info->Output("Variances").front(); + auto var = scope->FindVar(var_var_name)->GetMutable(); + std::vector expand_aspect_ratios; + ExpandAspectRatios(param->aspect_ratios, param->flip, &expand_aspect_ratios); + param->aspect_ratios.clear(); + for (size_t i = 0; i < expand_aspect_ratios.size(); i++) { + param->aspect_ratios.push_back(expand_aspect_ratios[i]); + } + param->prior_num = param->aspect_ratios.size() * param->min_sizes.size(); + if (param->max_sizes.size() > 0) { + param->prior_num += param->max_sizes.size(); + } + int32_t win1 = in_dims[3]; + int32_t hin1 = in_dims[2]; + 
DDim shape_out({hin1, win1, param->prior_num, 4}); + boxes->Resize(shape_out); + var->Resize(shape_out); + // boxes->mutable_data(); + // var->mutable_data(); + float* cpu_data = + static_cast(malloc(sizeof(float) * boxes->data_size() * 2)); + CHECK(cpu_data != nullptr); + const int32_t width = in_dims[3]; + const int32_t height = in_dims[2]; + int32_t img_width = param->img_w; + int32_t img_height = param->img_h; + if (img_width == 0 || img_height == 0) { + img_width = img_dims[3]; + img_height = img_dims[2]; + } + float step_w = param->step_w; + float step_h = param->step_h; + if (step_w == 0.f || step_h == 0.f) { + step_w = static_cast(img_width) / width; + step_h = static_cast(img_height) / height; + } + float offset = param->offset; + int32_t channel_size = height * width * param->prior_num * 4; + int32_t idx = 0; + /////////////////////////////////////////////////////////////////////// + for (int32_t h = 0; h < height; ++h) { + for (int32_t w = 0; w < width; ++w) { + float center_x = (w + offset) * step_w; + float center_y = (h + offset) * step_h; + float box_width = 0.f; + float box_height = 0.f; + float* min_buf = reinterpret_cast(malloc(sizeof(float) * 4)); + float* max_buf = reinterpret_cast(malloc(sizeof(float) * 4)); + float* com_buf = reinterpret_cast( + malloc(sizeof(float) * expand_aspect_ratios.size() * 4)); + CHECK(min_buf != nullptr); + CHECK(max_buf != nullptr); + CHECK(com_buf != nullptr); + // LOG(INFO) << "the number of min_size is " << min_sizes_.size(); + for (size_t s = 0; s < param->min_sizes.size(); ++s) { + int32_t min_idx = 0; + int32_t max_idx = 0; + int32_t com_idx = 0; + int32_t min_size = param->min_sizes[s]; + //! first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size; + //! xmin + min_buf[min_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + min_buf[min_idx++] = (center_y - box_height / 2.f) / img_height; + //! xmax + min_buf[min_idx++] = (center_x + box_width / 2.f) / img_width; + //! ymax + min_buf[min_idx++] = (center_y + box_height / 2.f) / img_height; + if (param->max_sizes.size() > 0) { + int max_size = param->max_sizes[s]; + //! second prior: aspect_ratio = 1, size = sqrt(min_size * max_size) + box_width = box_height = sqrtf(min_size * max_size); + //! xmin + max_buf[max_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + max_buf[max_idx++] = (center_y - box_height / 2.f) / img_height; + //! xmax + max_buf[max_idx++] = (center_x + box_width / 2.f) / img_width; + //! ymax + max_buf[max_idx++] = (center_y + box_height / 2.f) / img_height; + } + //! rest of priors + for (size_t r = 0; r < expand_aspect_ratios.size(); ++r) { + float ar = expand_aspect_ratios[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar); + box_height = min_size / sqrt(ar); + //! xmin + com_buf[com_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + com_buf[com_idx++] = (center_y - box_height / 2.f) / img_height; + //! xmax + com_buf[com_idx++] = (center_x + box_width / 2.f) / img_width; + //! 
ymax + com_buf[com_idx++] = (center_y + box_height / 2.f) / img_height; + } + if (param->min_max_aspect_ratios_order) { + memcpy(cpu_data + idx, min_buf, sizeof(float) * min_idx); + idx += min_idx; + memcpy(cpu_data + idx, max_buf, sizeof(float) * max_idx); + idx += max_idx; + memcpy(cpu_data + idx, com_buf, sizeof(float) * com_idx); + idx += com_idx; + } else { + memcpy(cpu_data + idx, com_buf, sizeof(float) * com_idx); + idx += com_idx; + memcpy(cpu_data + idx, max_buf, sizeof(float) * max_idx); + idx += max_idx; + } + } + free(min_buf); + free(max_buf); + free(com_buf); + } + } + //! clip the prior's coordidate such that it is within [0, 1] + if (param->clip) { + for (int32_t d = 0; d < channel_size; ++d) { + cpu_data[d] = std::min(std::max(cpu_data[d], 0.f), 1.f); + } + } + //! set the variance. + float* ptr = cpu_data + channel_size; + int count = 0; + for (int32_t h = 0; h < height; ++h) { + for (int32_t w = 0; w < width; ++w) { + for (int32_t i = 0; i < param->prior_num; ++i) { + for (int j = 0; j < 4; ++j) { + ptr[count] = param->variances[j]; + ++count; + } + } + } + } + return cpu_data; +} + +int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + // inputs + auto in_var_name = op_info->Input("Input").front(); + auto in = scope->FindVar(in_var_name)->GetMutable(); + auto in_dims = in->dims(); + auto img_var_name = op_info->Input("Image").front(); + auto img = scope->FindVar(img_var_name)->GetMutable(); + auto img_dims = img->dims(); + std::vector i_input_shape_data(in_dims.size()); + for (size_t i = 0; i < in_dims.size(); i++) { + i_input_shape_data[i] = static_cast(in_dims[i]); + } + // outputs + auto boxes_var_name = op_info->Output("Boxes").front(); + auto boxes = scope->FindVar(boxes_var_name)->GetMutable(); + auto var_var_name = op_info->Output("Variances").front(); + auto unique_op_name = lite::subgraph::bm::UniqueName(op_type); + // param + st_priorbox_param param; + param.clip = op_info->GetAttr("clip"); + param.min_sizes = op_info->GetAttr>("min_sizes"); + param.max_sizes = op_info->GetAttr>("max_sizes"); + param.aspect_ratios = op_info->GetAttr>("aspect_ratios"); + param.variances = op_info->GetAttr>("variances"); + param.offset = op_info->GetAttr("offset"); + if (op_info->HasAttr("flip")) { + param.flip = op_info->GetAttr("flip"); + } + if (op_info->HasAttr("img_w")) { + param.img_w = op_info->GetAttr("img_w"); + } + if (op_info->HasAttr("img_h")) { + param.img_h = op_info->GetAttr("img_h"); + } + if (op_info->HasAttr("step_w")) { + param.step_w = op_info->GetAttr("step_w"); + } + if (op_info->HasAttr("step_h")) { + param.step_h = op_info->GetAttr("step_h"); + } + if (op_info->HasAttr("prior_num")) { + param.prior_num = op_info->GetAttr("prior_num"); + } + if (op_info->HasAttr("min_max_aspect_ratios_order")) { + param.min_max_aspect_ratios_order = + op_info->GetAttr("min_max_aspect_ratios_order"); + } + float* cpu_data = compute_priorbox_kernel(op, ¶m); + compute_priorbox_kernel(op, param); + auto boxes_dims = boxes->dims(); + std::vector i_pri_out_shape_data(boxes_dims.size()); + for (size_t i = 0; i < boxes_dims.size(); i++) { + i_pri_out_shape_data[i] = static_cast(boxes_dims[i]); + } + i_pri_out_shape_data[0] *= 2; + add_priorbox_layer(graph->GetCompilerHandle(), + const_cast(&i_input_shape_data[0]), + in_dims.size(), + static_cast(in_var_name.c_str()), + 
const_cast(&i_pri_out_shape_data[0]), + boxes_dims.size(), + static_cast(unique_op_name.c_str()), + static_cast(cpu_data), + param.min_sizes.size(), + const_cast(¶m.min_sizes[0]), + param.max_sizes.size(), + const_cast(¶m.max_sizes[0]), + param.aspect_ratios.size(), + const_cast(¶m.aspect_ratios[0]), + static_cast(param.flip), + static_cast(param.clip), + param.variances.size(), + const_cast(¶m.variances[0]), + param.img_h, + param.img_w, + param.step_h, + param.step_w, + param.offset); + std::vector i_output_shape_data(boxes_dims.size()); + for (size_t i = 0; i < boxes_dims.size(); i++) { + i_output_shape_data[i] = static_cast(boxes_dims[i]); + } + int32_t* shape[2]; + int dim[2]; + const char* name[2]; + dim[0] = boxes_dims.size(); + dim[1] = boxes_dims.size(); + name[0] = static_cast(boxes_var_name.c_str()); + name[1] = static_cast(var_var_name.c_str()); + shape[0] = &i_output_shape_data[0]; + shape[1] = &i_output_shape_data[0]; + int split_size = 2; + add_tf_split_layer(graph->GetCompilerHandle(), + const_cast(&i_pri_out_shape_data[0]), + boxes_dims.size(), + static_cast(unique_op_name.c_str()), + 2, + shape, + dim, + name, + boxes_dims.size(), + 0, + &split_size, + 0); + graph->AddNode(boxes_var_name); + graph->AddNode(var_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(prior_box, + kBM, + paddle::lite::subgraph::bm::PriorBoxConverter); diff --git a/lite/kernels/bm/bridges/reshape_op.cc b/lite/kernels/bm/bridges/reshape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..64f4ffe9f3909b28c5237a8be88ba54fec4b1b83 --- /dev/null +++ b/lite/kernels/bm/bridges/reshape_op.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + std::vector i_x_shape_data(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + i_x_shape_data[i] = static_cast(x_dims[i]); + } + auto output_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_dims[i]); + } + // auto axis = op_info->GetAttr("axis"); + add_reshape_layer_v2(graph->GetCompilerHandle(), + static_cast(x_var_name.c_str()), + const_cast(&i_x_shape_data[0]), + x_dims.size(), + static_cast(output_var_name.c_str()), + const_cast(&i_output_shape_data[0]), + output_dims.size()); + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(reshape, + kBM, + paddle::lite::subgraph::bm::ReshapeConverter); +REGISTER_SUBGRAPH_BRIDGE(reshape2, + kBM, + paddle::lite::subgraph::bm::ReshapeConverter); +REGISTER_SUBGRAPH_BRIDGE(flatten, + kBM, + paddle::lite::subgraph::bm::ReshapeConverter); +REGISTER_SUBGRAPH_BRIDGE(flatten2, + kBM, + paddle::lite::subgraph::bm::ReshapeConverter); diff --git a/lite/kernels/bm/bridges/softmax_op.cc b/lite/kernels/bm/bridges/softmax_op.cc index fc08d9db4f78520750d11339b15d5e39bfdc675f..5de58872ac0e5a0536f2c746357f38f4ff664688 100644 --- a/lite/kernels/bm/bridges/softmax_op.cc +++ b/lite/kernels/bm/bridges/softmax_op.cc @@ -48,7 +48,10 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { for (size_t i = 0; i < length; i++) { i_output_shape_data[i] = static_cast(output_shape_data[i]); } - auto axis = op_info->GetAttr("axis"); + int32_t axis = -1; + if (op_info->HasAttr("axis")) { + axis = op_info->GetAttr("axis"); + } if (axis < 0) { axis += x_dims.size(); } diff --git a/lite/kernels/bm/bridges/transpose_op.cc b/lite/kernels/bm/bridges/transpose_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bab24a96b9920212337f6afd3c1c73f582a48975 --- /dev/null +++ b/lite/kernels/bm/bridges/transpose_op.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + auto output_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + const int64_t* x_shape_data = const_cast(&x_dims.data()[0]); + const int64_t* output_shape_data = + const_cast(&output_dims.data()[0]); + std::vector i_x_shape_data(x_dims.size()); + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + i_x_shape_data[i] = static_cast(x_shape_data[i]); + } + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_shape_data[i]); + } + auto axis = op_info->GetAttr>("axis"); + CHECK_EQ(axis.size(), x_dims.size()); + add_transpose_layer_v2(graph->GetCompilerHandle(), + static_cast(x_var_name.c_str()), + const_cast(&i_x_shape_data[0]), + x_dims.size(), + DTYPE_FP32, + static_cast(output_var_name.c_str()), + NULL, + const_cast(&axis[0])); + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(transpose, + kBM, + paddle::lite::subgraph::bm::TransposeConverter); +REGISTER_SUBGRAPH_BRIDGE(transpose2, + kBM, + paddle::lite::subgraph::bm::TransposeConverter); diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc index 83f9fe3bed2ce751ad57962b9b5e35ed23d40ab5..2e47102d767becdea0f0d3d50aa30d6933d6ef8d 100644 --- a/lite/kernels/bm/subgraph_compute.cc +++ b/lite/kernels/bm/subgraph_compute.cc @@ -54,7 +54,7 @@ int SubgraphEngine::BuildDeviceProgram() { } std::string net_name = "paddle_bitmain"; __bmcompile_opt( - graph.GetCompilerHandle(), const_cast(net_name.c_str()), 2); + graph.GetCompilerHandle(), const_cast(net_name.c_str()), 1); void* bmodel_data = nullptr; unsigned int data_size = 0; bm_hd_ = static_cast(ctx.GetHandle()); @@ -109,7 +109,6 @@ int SubgraphEngine::BuildDeviceProgram() { net_info_->output_dtypes[i], stage.output_shapes[i]); } - return status; } diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt old mode 100644 new mode 100755 index 428cc213ce63b8d24193a44f23d61fea78f63d6a..c212fb9b0465824b7a87eef2e87033bf967736e5 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -4,6 +4,7 @@ add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps}) #lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) #lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS 
multiclass_nms_compute_host any) diff --git a/lite/kernels/host/multiclass_nms_compute.cc b/lite/kernels/host/multiclass_nms_compute.cc index 9cbc798d46ecb3cf98159e9b4762c8692ec8c1eb..a4af3548e89c637bffae32944f239997e7d0e41b 100644 --- a/lite/kernels/host/multiclass_nms_compute.cc +++ b/lite/kernels/host/multiclass_nms_compute.cc @@ -426,8 +426,14 @@ REGISTER_LITE_KERNEL(multiclass_nms, kNCHW, paddle::lite::kernels::host::MulticlassNmsCompute, def) - .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("BBoxes", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindInput("Scores", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .BindOutput("Index", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) diff --git a/lite/kernels/host/one_hot_compute.cc b/lite/kernels/host/one_hot_compute.cc new file mode 100755 index 0000000000000000000000000000000000000000..e0af6f5173f367bb9b2e06de10499ee36806379c --- /dev/null +++ b/lite/kernels/host/one_hot_compute.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "lite/backends/fpga/KD/debugger.hpp" +#include "lite/kernels/host/one_hot_compute.h" +#include "lite/utils/paddle_enforce.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void OneHotCompute::Run() { + auto& param = Param(); + param.Out->mutable_data(); + int depth = param.depth; + if (param.depth_tensor) { + auto* depth_tensor = param.depth_tensor; + auto* depth_data = depth_tensor->data(); + depth = depth_data[0]; + auto in_dims = param.X->dims(); + DDim out_dims(in_dims); + out_dims[out_dims.size() - 1] = depth; + param.Out->Resize(out_dims); + } + + auto* p_in_data = param.X->data(); + auto numel = param.X->numel(); + auto* p_out_data = param.Out->mutable_data(); + + for (int i = 0; i < param.Out->numel(); ++i) { + p_out_data[i] = 0; + } + + if (param.allow_out_of_range) { + for (int i = 0; i < numel; ++i) { + if (p_in_data[i] >= 0 && p_in_data[i] < param.depth) { + *(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT + } + } + } else { + for (int i = 0; i < numel; ++i) { + PADDLE_ENFORCE_GE( + p_in_data[i], 0, "Illegal index value, should be at least 0."); + PADDLE_ENFORCE_LT(p_in_data[i], + param.depth, + "Illegal index value, should be less than depth (%d).", + param.depth); + *(p_out_data + i * param.depth + (int)(p_in_data[i])) = 1.0; // NOLINT + } + } +} +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(one_hot, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::OneHotCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .Finalize(); diff --git a/lite/kernels/host/one_hot_compute.h b/lite/kernels/host/one_hot_compute.h new file mode 100755 index 0000000000000000000000000000000000000000..3a6c47fee31bc28f130c3de782c0c912c9f4b769 --- /dev/null +++ b/lite/kernels/host/one_hot_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class OneHotCompute + : public KernelLite { + public: + void Run() override; + + virtual ~OneHotCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc index 02f99787e60e73d91ca8f65cb42dcd4c56e7212b..10c50d20b9c52f72d09c4519716e2defb047a23f 100644 --- a/lite/kernels/host/reshape_compute.cc +++ b/lite/kernels/host/reshape_compute.cc @@ -46,17 +46,21 @@ REGISTER_LITE_KERNEL(reshape, paddle::lite::kernels::host::ReshapeCompute, def) .BindInput("X", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .BindInput("ShapeTensor", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .BindInput("Shape", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .BindOutput("Out", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(reshape2, diff --git a/lite/kernels/npu/bridges/fc_op.cc b/lite/kernels/npu/bridges/fc_op.cc index 3d028172154e58c1ed191b4d4eb780e9937458a5..d9d42cd8c73a321449649bca658333fdd5f57325 100644 --- a/lite/kernels/npu/bridges/fc_op.cc +++ b/lite/kernels/npu/bridges/fc_op.cc @@ -34,27 +34,29 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto input_type = kernel->GetInputDeclType("Input"); CHECK(input_type->precision() == PRECISION(kFloat)); CHECK(input_type->layout() == DATALAYOUT(kNCHW)); - auto input = scope->FindMutableTensor(input_name); + auto input = scope->FindTensor(input_name); auto input_dims = input->dims(); - CHECK_GE(input_dims.size(), 2UL); + auto w_name = op_info->Input("W").front(); auto w_type = kernel->GetInputDeclType("W"); CHECK(w_type->precision() == PRECISION(kFloat)); CHECK(w_type->layout() == DATALAYOUT(kNCHW)); - auto w = scope->FindMutableTensor(w_name); + auto w = scope->FindTensor(w_name); auto w_dims = w->dims(); CHECK_EQ(w_dims.size(), 2UL); + auto out_name = op_info->Output("Out").front(); auto out_type = kernel->GetOutputDeclType("Out"); CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->layout() == DATALAYOUT(kNCHW)); + auto out = scope->FindTensor(out_name); + auto out_dims = out->dims(); + int in_num_col_dims = op_info->GetAttr("in_num_col_dims"); int m = input_dims.Slice(0, in_num_col_dims).production(); int k = input_dims.Slice(in_num_col_dims, input_dims.size()).production(); int n = w_dims[1]; CHECK_EQ(k * n, w_dims.production()); - VLOG(3) << "[NPU] input dims: " << input_dims << " w dims: " << w_dims - << " m: " << m << " k: " << k << " n: " << n; // Create input node and reshape it to (m, k, 1, 1) std::shared_ptr input_node = nullptr; @@ -76,7 +78,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { transpose_w.Resize({n, k, 1, 1}); transpose_w.set_persistable(true); auto transpose_w_data = transpose_w.mutable_data(); - auto w_data = w->mutable_data(); + auto w_data = w->data(); for (int i = 0; i < k; i++) 
{ for (int j = 0; j < n; j++) { transpose_w_data[j * k + i] = w_data[i * n + j]; @@ -85,10 +87,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto trans_w_node = graph->Add(w_name, transpose_w); // FC node - auto fc_node = graph->Add(out_name + "/fc"); + auto fc_node = graph->Add(out_name); auto fc_op = fc_node->data(); fc_op->set_input_x(*reshaped_input_node->data()); fc_op->set_input_w(*trans_w_node->data()); + // Add bias node if bias tensor exists if (HasInputArg(op_info, scope, "Bias")) { std::shared_ptr bias_node = nullptr; @@ -99,19 +102,23 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto bias_type = kernel->GetInputDeclType("Bias"); CHECK(bias_type->precision() == PRECISION(kFloat)); CHECK(bias_type->layout() == DATALAYOUT(kNCHW)); - auto bias = scope->FindMutableTensor(bias_name); + auto bias = scope->FindTensor(bias_name); auto bias_dims = bias->dims(); CHECK_EQ(bias_dims.production(), n); bias_node = graph->Add(bias_name, *bias, {1, n, 1, 1}); } fc_op->set_input_b(*bias_node->data()); } - // Reshape output of FC node from (m, n, 1, 1) to (m, n) + + // Reshape output of FC node from (m, n, 1, 1) to out_shape auto reshaped_fc_node = graph->Add(out_name); auto reshaped_fc_op = reshaped_fc_node->data(); reshaped_fc_op->set_input_tensor(*fc_node->data()); - reshaped_fc_op->set_attr_shape({m, n}); + auto out_shape = out_dims.Vectorize(); + reshaped_fc_op->set_attr_shape( + ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); reshaped_fc_op->set_attr_axis(0); + return REBUILD_WHEN_SHAPE_CHANGED; } diff --git a/lite/kernels/npu/bridges/softmax_op.cc b/lite/kernels/npu/bridges/softmax_op.cc index 24bbb790e08b4b0ff675173af8faad3b07f8f2e0..0ca3bc131d1f0910b9282ec53656bee53bbc5444 100644 --- a/lite/kernels/npu/bridges/softmax_op.cc +++ b/lite/kernels/npu/bridges/softmax_op.cc @@ -42,7 +42,7 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto out_type = kernel->GetOutputDeclType("Out"); CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->layout() == DATALAYOUT(kNCHW)); - auto axis = op_info->GetAttr("axis"); + int axis = op_info->HasAttr("axis") ? 
op_info->GetAttr("axis") : -1; if (axis < 0) { axis += x_rank; } diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index a802d1a9ebafc81f1a2b2c04fc190d6a8a7818c3..e81fdf307e94fbb6593962052b911c34a944777a 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -20,7 +20,9 @@ add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps} cl_image_converter) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(concat_opencl OPENCL basic SRCS concat_compute.cc DEPS ${cl_kernel_deps}) add_kernel(nearest_interp_opencl OPENCL basic SRCS nearest_interp_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(scale_opencl OPENCL basic SRCS scale_compute.cc DEPS ${cl_kernel_deps}) lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context @@ -83,6 +85,15 @@ lite_cc_test(test_conv_image2d_opencl SRCS conv_image2d_compute_test.cc lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc DEPS layout_opencl op_registry program context cl_image_converter ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + +lite_cc_test(test_concat_opencl SRCS concat_compute_test.cc + DEPS concat_opencl layout_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + lite_cc_test(test_nearest_interp_opencl SRCS nearest_interp_compute_test.cc DEPS nearest_interp_opencl layout_opencl op_registry program context cl_image_converter ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + +lite_cc_test(test_scale_opencl SRCS scale_compute_test.cc + DEPS scale_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/concat_compute.cc b/lite/kernels/opencl/concat_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f25439ed00a9ff579bbd59a543dba3c8c3b090b --- /dev/null +++ b/lite/kernels/opencl/concat_compute.cc @@ -0,0 +1,372 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/opencl/concat_compute.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +template <> +void ConcatCompute::PrepareForRun() { + auto& context = ctx_->As(); + concat_param_ = param_.get_mutable(); + if (concat_param_->x.size() == 2) { + kernel_func_name_ = "concat2"; + } else { + kernel_func_name_ = "concat_mul"; + } + context.cl_context()->AddKernel( + kernel_func_name_, "image/concat_kernel.cl", build_options_); + // UpdateParams(); + auto axis = concat_param_->axis; + auto inputs = concat_param_->x; + auto out_dims = concat_param_->output->dims(); + auto* axis_tensor = concat_param_->axis_tensor; + if (axis_tensor != nullptr) { + // auto* axis_tensor_data = axis_tensor->data(TARGET(kARM)); + // axis = axis_tensor_data[0]; + } + auto in_dims = inputs[0]->dims(); + axis_size_ = out_dims[axis]; + axis_ = axis; + for (int i = 0; i < axis; i++) { + pre_size_ *= in_dims[i]; + } + for (int i = axis + 1; i < in_dims.size(); i++) { + post_size_ *= in_dims[i]; + } + for (int i = 1; i < inputs.size(); i++) { + auto dims = inputs[i]->dims(); + // auto flag = CHECK_EQ_OR_FALSE(in_dims.size(), dims.size()); + if (in_dims.size() != dims.size()) { + printf("input shape must be same \n"); + return; + } + for (int i = 0; i < dims.size(); i++) { + if (i != axis) { + if (in_dims[i] != dims[i]) { + printf("input shape must be same \n"); + return; + } + } + } + } +} + +template <> +void ConcatCompute::Run() { + auto& param = *param_.get_mutable(); + const auto& x_dims = param.output->dims(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* out_buf = param.output->mutable_data( + image_shape["width"], image_shape["height"]); + const auto& y_dims = param.output->dims(); // useless: check dim only + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + + auto inputs = param.x; + int arg_idx = 0; + int width = inputs[0]->dims()[-1]; + auto global_work_size = + cl::NDRange{static_cast(image_shape["width"]), + static_cast(image_shape["height"])}; + VLOG(4) << TargetToStr(param.output->target()); + VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " + << image_shape["height"]; + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + int flag = 1; // cxw + switch (axis_) { + case 0: + width = x_dims[2]; // n + flag = 0; + break; + case 1: + width = x_dims[3]; // c + break; + case 2: + width = x_dims[0]; // h + flag = 0; + break; + case 3: + case -1: + width = x_dims[1]; // w + break; + default: + printf("this axis: %d does not support \n", axis_); + } + if (inputs.size() == 2) { + auto* x_buf0 = inputs[0]->data(); + auto* x_buf1 = inputs[1]->data(); + cl_int status = kernel.setArg(arg_idx, *x_buf0); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *x_buf1); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + status = + kernel.setArg(++arg_idx, 
static_cast(inputs[0]->dims()[axis_])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, flag); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, width); + CL_CHECK_FATAL(status); + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_context()->GetCommandQueue().finish(); + } else { + auto start = 0; + for (int i = 0; i < inputs.size(); i++) { + arg_idx = 0; + auto* x_buf = inputs[i]->data(); + cl_int status = kernel.setArg(arg_idx, *x_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, axis_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, start); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, flag); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, width); + CL_CHECK_FATAL(status); + CL_CHECK_FATAL(status); + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_context()->GetCommandQueue().finish(); + start += inputs[i]->dims()[axis_]; + } + } +} + +template <> +std::string ConcatCompute::doc() { + return "Concat using cl::Image, kFloat"; +} + +template <> +void ConcatCompute::PrepareForRun() { + auto& context = ctx_->As(); + concat_param_ = param_.get_mutable(); + if (concat_param_->x.size() == 2) { + kernel_func_name_ = "concat2"; + } else { + kernel_func_name_ = "concat_mul"; + } + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/concat_kernel.cl", build_options_); + + // UpdateParams(); + auto axis = concat_param_->axis; + auto inputs = concat_param_->x; + auto out_dims = concat_param_->output->dims(); + auto* axis_tensor = concat_param_->axis_tensor; + if (axis_tensor != nullptr) { + // auto* axis_tensor_data = axis_tensor->data(TARGET(kARM)); + // axis = axis_tensor_data[0]; + } + auto in_dims = inputs[0]->dims(); + axis_size_ = out_dims[axis]; + axis_ = axis; + for (int i = 0; i < axis; i++) { + pre_size_ *= in_dims[i]; + } + for (int i = axis + 1; i < in_dims.size(); i++) { + post_size_ *= in_dims[i]; + } + for (int i = 1; i < inputs.size(); i++) { + auto dims = inputs[i]->dims(); + if (in_dims.size() != dims.size()) { + printf("input shape must be same \n"); + return; + } + for (int i = 0; i < dims.size(); i++) { + if (i != axis) { + if (in_dims[i] != dims[i]) { + printf("input shape must be same \n"); + return; + } + } + } + } +} + +template <> +void ConcatCompute::Run() { + auto& param = *param_.get_mutable(); + const auto& x_dims = param.output->dims(); + auto image_shape = InitImageDimInfoWith(x_dims); + auto* out_buf = + param.output->mutable_data(TARGET(kOpenCL)); + const auto& y_dims = param.output->dims(); // useless: check dim only + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + + auto inputs = param.x; + int arg_idx = 0; + auto global_work_size = cl::NDRange{axis_size_}; + int total = axis_size_ * post_size_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + if (inputs.size() == 2) { + auto* x_buf0 = inputs[0]->data(); + auto* x_buf1 = inputs[1]->data(); + auto axis0 = inputs[0]->dims()[axis_]; + int total0 = axis0 * post_size_; + int total1 = (axis_size_ - axis0) * 
post_size_; + cl_int status = kernel.setArg(arg_idx, *x_buf0); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *x_buf1); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(axis0)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, axis_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, pre_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, post_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, total); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, total0); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, total1); + CL_CHECK_FATAL(status); + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_buf, event_); + } else { + auto start = 0; + for (int i = 0; i < inputs.size(); i++) { + arg_idx = 0; + int size = inputs[i]->dims()[axis_]; + auto* x_buf = inputs[i]->data(); + global_work_size = cl::NDRange{static_cast(size)}; + int total0 = size * post_size_; + cl_int status = kernel.setArg(arg_idx, *x_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(size)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, pre_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, post_size_); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, start); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, total); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, total0); + CL_CHECK_FATAL(status); + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_buf, event_); + start += size; + } + } +} + +template <> +std::string ConcatCompute::doc() { + return "Concat using cl::Buffer, kFloat"; +} + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +typedef paddle::lite::kernels::opencl::ConcatCompute + Concat_buffer; + +typedef paddle::lite::kernels::opencl::ConcatCompute + Concat_image; + +REGISTER_LITE_KERNEL( + concat, kOpenCL, kFloat, kImageDefault, Concat_image, ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kInt32), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .Finalize(); + +REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kInt32), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/opencl/concat_compute.h b/lite/kernels/opencl/concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..7bed6a18146d76043fbfcd72236ba39c5607328b --- /dev/null +++ 
b/lite/kernels/opencl/concat_compute.h @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/operators/op_params.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +template +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + void PrepareForRun() override; + + void Run() override; + + std::string doc(); // override; + + // protected: + // void UpdateParams(); + + int axis_size_ = 1; + int post_size_ = 1; + int pre_size_ = 1; + int axis_ = 1; + param_t* concat_param_{nullptr}; + std::string kernel_func_name_{}; + std::string build_options_{"-DCL_DTYPE_float"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/opencl/concat_compute_test.cc b/lite/kernels/opencl/concat_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..37e7b6658be2eaa60285474b3766ce462ea3779b --- /dev/null +++ b/lite/kernels/opencl/concat_compute_test.cc @@ -0,0 +1,390 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// // +// // Licensed under the Apache License, Version 2.0 (the "License"); +// // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, software +// // distributed under the License is distributed on an "AS IS" BASIS, +// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// // See the License for the specific language governing permissions and +// // limitations under the License. 
+ +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/image_helper.h" + +namespace paddle { +namespace lite { + +template +void concat2_compute_ref(const dtype *in0, + const dtype *in1, + const int axis, + const DDim in0_dim, + const DDim in1_dim, + const DDim out_dim, + dtype *out_data) { + int pre_size = 1; + int post_size = 1; + for (int i = 0; i < axis; i++) { + pre_size *= in0_dim[i]; + } + for (int i = axis + 1; i < in0_dim.size(); i++) { + post_size *= in0_dim[i]; + } + int axis_size = out_dim[axis]; + for (int i = 0; i < pre_size; i++) { + for (int j = 0; j < axis_size; j++) { + if (j < in0_dim[axis]) { + memcpy(out_data, in0, sizeof(dtype) * post_size); + in0 += post_size; + out_data += post_size; + } + } + } +} + +template +void concat_mul_compute_ref(std::vector ins_data, + std::vector ins_dim, + int axis, + const DDim out_dim, + dtype *out_data) { + int pre_size = 1; + int post_size = 1; + for (int i = 0; i < axis; i++) { + pre_size *= ins_dim[0][i]; + } + for (int i = axis + 1; i < ins_dim[0].size(); i++) { + post_size *= ins_dim[0][i]; + } + int axis_size = out_dim[axis]; + for (int i = 0; i < pre_size; i++) { + for (int j = 0; j < ins_data.size(); j++) { + int size = post_size * ins_dim[j][axis]; + memcpy(out_data, ins_data[j], sizeof(dtype) * size); + out_data += size; + } + } +} +#if 1 // concat_buffer +TEST(opencl_concat_buffer, compute) { + // prepare data + const DDim x0_dim = DDim(std::vector{1, 2, 3, 4}); + const DDim x1_dim = DDim(std::vector{1, 2, 3, 4}); + const DDim x2_dim = DDim(std::vector{1, 2, 3, 4}); + const DDim out_dim = DDim(std::vector{1, 6, 3, 4}); + lite::Tensor x0, x1, x2, out, out_ref; + x0.Resize(x0_dim); + x1.Resize(x1_dim); + x2.Resize(x2_dim); + out.Resize(out_dim); + out_ref.Resize(out_dim); + + auto *x0_data = x0.mutable_data(TARGET(kOpenCL)); + auto *x1_data = x1.mutable_data(TARGET(kOpenCL)); + auto *x2_data = x2.mutable_data(TARGET(kOpenCL)); + std::default_random_engine engine; + std::uniform_real_distribution dist(-10, 10); + auto *mapped_x0 = static_cast( + TargetWrapperCL::Map(x0_data, 0, sizeof(float) * x0_dim.production())); + auto *mapped_x1 = static_cast( + TargetWrapperCL::Map(x1_data, 0, sizeof(float) * x1_dim.production())); + auto *mapped_x2 = static_cast( + TargetWrapperCL::Map(x2_data, 0, sizeof(float) * x2_dim.production())); + for (int i = 0; i < x0_dim.production(); i++) { + mapped_x0[i] = dist(engine); + } + for (int i = 0; i < x1_dim.production(); i++) { + mapped_x1[i] = dist(engine); + } + for (int i = 0; i < x2_dim.production(); i++) { + mapped_x2[i] = dist(engine); + } + + // set param and kernel, then run + operators::ConcatParam param; + std::vector ins; + ins.push_back(&x0); + ins.push_back(&x1); + ins.push_back(&x2); + auto axis = 1; + param.x = ins; + param.output = &out; + param.axis = axis; + + std::vector ins_data; + std::vector ins_dim; + + ins_data.push_back(mapped_x0); + ins_data.push_back(mapped_x1); + ins_data.push_back(mapped_x2); + ins_dim.push_back(x0_dim); + ins_dim.push_back(x1_dim); + ins_dim.push_back(x2_dim); + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + auto kernels = KernelRegistry::Global().Create( + "concat", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + kernel->SetParam(param); + std::unique_ptr concat_context(new KernelContext); + context->As().CopySharedTo( + 
&(concat_context->As())); + kernel->SetContext(std::move(concat_context)); + kernel->Launch(); + + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target cl tensor."; + } + + // run compute ref and check + auto *out_ref_data = out_ref.mutable_data(TARGET(kARM)); + concat_mul_compute_ref(ins_data, ins_dim, axis, out_dim, out_ref_data); + + auto *out_data = out.mutable_data(); + auto *mapped_out = static_cast( + TargetWrapperCL::Map(out_data, 0, sizeof(float) * out_dim.production())); + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(mapped_out[i], out_ref_data[i], 1e-6); + } + TargetWrapperCL::Unmap(out_data, mapped_out); + TargetWrapperCL::Unmap(x0_data, mapped_x0); + TargetWrapperCL::Unmap(x1_data, mapped_x1); + TargetWrapperCL::Unmap(x2_data, mapped_x2); +} +#endif // concat_buffer + +// #define LOOP_TEST +// #define PRINT_RESULT +TEST(concat_image2d_fp32, compute) { + LOG(INFO) << "main steps of test: host -> layout(buf2img) -> concat(img) -> " + "layout(img2buf) " + "-> host"; + +#ifdef LOOP_TEST + for (int n = 1; n <= 100; n += 33) { + for (auto c : {1, 3}) { + for (int h = 12; h <= 100; h += 13) { + for (int w = 12; w <= 100; w += 25) { + for (atuo &axis : {0, 1, 2, 3}) { +#else + const int n = 1; + const int c = 2; + const int h = 3; + const int w = 4; + const int axis = 1; +#endif // LOOP_TEST + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c + << " " << h << " " << w << " ========"; + LOG(INFO) << "======== axis: " << axis; + // set layout kernels + auto buf_to_img_kernels = + KernelRegistry::Global().Create("layout", + TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kImageDefault)); + auto buf_to_img_kernels1 = + KernelRegistry::Global().Create("layout", + TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kImageDefault)); + auto img_to_buf_kernels = KernelRegistry::Global().Create( + "layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)); + auto concat_img_kernels = + KernelRegistry::Global().Create("concat", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(buf_to_img_kernels1.empty()); + ASSERT_FALSE(img_to_buf_kernels.empty()); + ASSERT_FALSE(concat_img_kernels.empty()); + + auto buf_to_img_kernel = std::move(buf_to_img_kernels.front()); + auto buf_to_img_kernel1 = std::move(buf_to_img_kernels1.front()); + auto img_to_buf_kernel = std::move(img_to_buf_kernels.front()); + auto concat_img_kernel = std::move(concat_img_kernels.front()); + LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc(); + LOG(INFO) << "get 1st-1 kernel: " << buf_to_img_kernel1->doc(); + LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc(); + LOG(INFO) << "get 3rd kernel: " << concat_img_kernel->doc(); + + // set tensors about op param + LOG(INFO) << "set tensors about op param"; + lite::Tensor x0, x1, y, concat_in0, concat_in1, concat_out, y_ref; + operators::LayoutParam BufferToImageParam0, BufferToImageParam1; + operators::LayoutParam ImageToBufferParam; + BufferToImageParam0.x = &x0; + BufferToImageParam0.y = &concat_in0; + BufferToImageParam1.x = &x1; + BufferToImageParam1.y = &concat_in1; + ImageToBufferParam.x = &concat_out; + ImageToBufferParam.y = &y; + std::vector ins; + 
operators::ConcatParam concatParam; + ins.push_back(&concat_in0); + ins.push_back(&concat_in1); + concatParam.x = ins; + concatParam.axis = axis; + concatParam.output = &concat_out; + + const DDim x0_dim = DDim(std::vector{n, c, h, w}); + DDim x1_dim = DDim(std::vector{n, c, h, w}); + DDim out_dim = DDim(std::vector{n, c, h, w}); + x1_dim[axis] += 2; + out_dim[axis] = x0_dim[axis] + x1_dim[axis]; + x0.Resize(x0_dim); + x1.Resize(x1_dim); + y.Resize(out_dim); + concat_in0.Resize(x0_dim); + concat_in1.Resize(x1_dim); + concat_out.Resize(out_dim); + y_ref.Resize(out_dim); + auto concat_image2d_shape = + paddle::lite::kernels::opencl::InitImageDimInfoWith(out_dim); + auto concat_image2d_shape_in0 = + paddle::lite::kernels::opencl::InitImageDimInfoWith(x0_dim); + auto concat_image2d_shape_in1 = + paddle::lite::kernels::opencl::InitImageDimInfoWith(x1_dim); + + // initialize tensors + LOG(INFO) << "initialize tensors"; + auto *x_data0 = x0.mutable_data(TARGET(kOpenCL)); + auto *x_data1 = x1.mutable_data(TARGET(kOpenCL)); + auto *y_data = y.mutable_data(TARGET(kOpenCL)); + auto *y_data_ref = y_ref.mutable_data(TARGET(kARM)); + auto *mapped_x0 = static_cast(TargetWrapperCL::Map( + x_data0, 0, sizeof(float) * x0_dim.production())); + auto *mapped_x1 = static_cast(TargetWrapperCL::Map( + x_data1, 0, sizeof(float) * x1_dim.production())); + auto *mapped_y = static_cast(TargetWrapperCL::Map( + y_data, 0, sizeof(float) * out_dim.production())); + for (int i = 0; i < x0_dim.production(); ++i) { + mapped_x0[i] = static_cast(i) - x0_dim.production() / 2; + } + for (int i = 0; i < x1_dim.production(); ++i) { + mapped_x1[i] = static_cast(i) - x1_dim.production() / 2; + } + for (int i = 0; i < out_dim.production(); ++i) { + mapped_y[i] = static_cast(0); + } + auto *concat_in_data0 = concat_in0.mutable_data( + concat_image2d_shape_in0["width"], + concat_image2d_shape_in0["height"]); + auto *concat_in_data1 = concat_in1.mutable_data( + concat_image2d_shape_in1["width"], + concat_image2d_shape_in1["height"]); + auto *concat_out_data = concat_out.mutable_data( + concat_image2d_shape["width"], concat_image2d_shape["height"]); + + // set context and kernel args + LOG(INFO) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + buf_to_img_kernel->SetParam(BufferToImageParam0); + std::unique_ptr buf_to_img_context( + new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context->As())); + buf_to_img_kernel->SetContext(std::move(buf_to_img_context)); + buf_to_img_kernel1->SetParam(BufferToImageParam1); + std::unique_ptr buf_to_img_context1( + new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context1->As())); + buf_to_img_kernel1->SetContext(std::move(buf_to_img_context1)); + + img_to_buf_kernel->SetParam(ImageToBufferParam); + std::unique_ptr img_to_buf_context( + new KernelContext); + context->As().CopySharedTo( + &(img_to_buf_context->As())); + img_to_buf_kernel->SetContext(std::move(img_to_buf_context)); + + concat_img_kernel->SetParam(concatParam); + std::unique_ptr concat_img_context( + new KernelContext); + context->As().CopySharedTo( + &(concat_img_context->As())); + concat_img_kernel->SetContext(std::move(concat_img_context)); + + // run kernels + LOG(INFO) << "run kernel: buf_to_img_kernel"; + buf_to_img_kernel->Launch(); + buf_to_img_kernel1->Launch(); + LOG(INFO) << "run kernel: concat_img_kernel"; + concat_img_kernel->Launch(); + LOG(INFO) << "run kernel: img_to_buf_kernel"; + img_to_buf_kernel->Launch(); + + // compute 
ref cpu + std::vector ins_ptr; + std::vector in_dim; + ins_ptr.push_back(mapped_x0); + ins_ptr.push_back(mapped_x1); + in_dim.push_back(x0_dim); + in_dim.push_back(x1_dim); + concat_mul_compute_ref( + ins_ptr, in_dim, axis, out_dim, y_data_ref); +// result +#ifdef PRINT_RESULT + LOG(INFO) << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < out_dim.production(); ++eidx) { + std::cout << mapped_x0[eidx] << ", " << mapped_x1[eidx] << " -> " + << mapped_y[eidx] << std::endl; + } +#endif // PRINT_RESULT + + // check result: compare kernel output and cpu output(y_data_ref) + for (int eidx = 0; eidx < out_dim.production(); eidx++) { + EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6); + if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx + << " / " << x0_dim.production() << ", y_data_ref[" + << eidx << "]:" << y_data_ref[eidx] << ", mapped_y[" + << eidx << "]:" << mapped_y[eidx]; + break; + } + } + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data0, mapped_x0); + TargetWrapperCL::Unmap(x_data1, mapped_x1); + TargetWrapperCL::Unmap(y_data, mapped_y); +#ifdef LOOP_TEST + } // axis + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} +} // namespace lite +} // namespace paddle + +// concat buffer +USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def); + +// concat image2d fp32 +USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); +USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW); +USE_LITE_KERNEL(concat, kOpenCL, kFloat, kImageDefault, ImageDefault); diff --git a/lite/kernels/opencl/conv_compute.cc b/lite/kernels/opencl/conv_compute.cc index 0cc256478a80f17ce2efe15b8e43adc38a789921..c3d3e2a6c27f794268ef42ac97ab492ddd4e9de1 100644 --- a/lite/kernels/opencl/conv_compute.cc +++ b/lite/kernels/opencl/conv_compute.cc @@ -362,6 +362,20 @@ void ConvImageCompute::PrepareForRun() { filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); impl_ = &ConvImageCompute::Conv2d1x1; + } else if (kernel_h == 3 && kernel_w == 3) { + // conv2d_3x3 + kernel_func_names_.push_back("conv2d_3x3"); + kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl"); + + CLImageConverterFolder converter; + const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); + std::vector filter_image_v(filter_image_dims[0] * + filter_image_dims[1] * 4); // 4 : RGBA + converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); + filter_gpu_image_.mutable_data( + filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); + + impl_ = &ConvImageCompute::Conv2d3x3; } else if (kernel_h == 5 && kernel_w == 5) { // conv2d_5x5 kernel_func_names_.push_back("conv2d_5x5"); @@ -582,6 +596,184 @@ void ConvImageCompute::Conv2d1x1() { CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); } + +void ConvImageCompute::Conv2d3x3() { + const auto& param = *param_.get_mutable(); + auto input_dims = param.x->dims(); + auto paddings = *param.paddings; + auto strides = param.strides; + + auto* input_image = param.x->data(); + auto* filter_image = filter_gpu_image_.data(); + auto filter_dims = param.filter->dims(); + auto output_dims = param.output->dims(); + + int input_width = input_dims[3]; + int input_height = input_dims[2]; + int input_channel = input_dims[1]; + int output_width = output_dims[3]; + int output_height = output_dims[2]; + int output_channel = output_dims[1]; + int filter_width = filter_dims[3]; + 
int filter_height = filter_dims[2]; + int filter_channel = filter_dims[1]; + auto out_image_shape = InitImageDimInfoWith(output_dims); + auto* out_image = param.output->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + int offset = static_cast(param.filter->dims()[2]) / 2 - + static_cast(paddings[0]); + + // calc input_c_block + auto input_image_shape = InitImageDimInfoWith(input_dims); + int input_c_block = input_image_shape["width"] / input_dims[3]; + int input_c = input_dims[1]; + auto dilations = *param.dilations; + + // re-calc group + int new_groups{param.groups}; + if (filter_dims[0] == output_dims[1] && filter_dims[1] == input_dims[1]) { + new_groups = 1; + } else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) { + new_groups = input_channel / filter_channel; + } + /* TODO(ysh329): mobile has no case below + else { + LOG(FATAL) << "Not support conv3x3 case with" + << " input_dims:" << input_dims << " output_dims:" << + output_dims + << " filter_dims:" << filter_dims; + } + */ + + const std::vector& default_work_size = + DefaultWorkSize(output_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + + VLOG(4) << "============ conv2d params ============"; + VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," + << input_image_shape["height"]; + VLOG(4) << "input_c_block: " << input_c_block; + VLOG(4) << "input_c: " << input_c; + VLOG(4) << "input_image: " << input_image; + VLOG(4) << "input_dims: " << input_dims; + VLOG(4) << "filter_dims: " << filter_dims; + VLOG(4) << "filter_image: " << filter_image; + VLOG(4) << "output_dims: " << output_dims; + VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " + << out_image_shape["height"]; + VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; + VLOG(4) << "has bias: " << has_bias; + VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; + VLOG(4) << "strides: " << strides[0] << "," << strides[1]; + VLOG(4) << "offset: " << offset; + VLOG(4) << "dilations.size : " << dilations.size(); + VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + VLOG(4) << "param.groups(groups):" << param.groups; + VLOG(4) << "new_groups:" << new_groups; + VLOG(4) << "default work size{c_block, w, nh}: " + << "{" << c_block << ", " << w << ", " << nh << "" + << "}"; + + CHECK_GE(dilations.size(), 2); + CHECK(dilations[0] == dilations[1]); + CHECK_GE(input_dims.size(), 4); + CHECK_GE(paddings.size(), 2); + CHECK(paddings[0] == paddings[1]); + CHECK_GE(strides.size(), 2); + CHECK(strides[0] == strides[1]); + + const cl::Image2D* bias_image = nullptr; + if (has_bias) { + bias_image = bias_gpu_image_.data(); + } + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_names_[0] << build_options_[0]; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + VLOG(4) << "kernel_key: " << kernel_key.str(); + VLOG(4) << "kernel ready ... 
" << kernel_key.str(); + VLOG(4) << "w: " << w; + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, c_block); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, w); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, nh); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *input_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *filter_image); + CL_CHECK_FATAL(status); + if (has_bias) { + VLOG(4) << "set bias_image: "; + status = kernel.setArg(++arg_idx, *bias_image); + CL_CHECK_FATAL(status); + } + status = kernel.setArg(++arg_idx, *out_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, strides[0]); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, offset); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_c_block); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, dilations[0]); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_channel); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_channel); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, new_groups); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(default_work_size.data()[1]), + static_cast(default_work_size.data()[2])}; + + VLOG(4) << "out_image: " << out_image; + VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," + << global_work_size[1] << "," << global_work_size[2] << "}"; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_image, event_); +} + void ConvImageCompute::Conv2d5x5() { const auto& param = *param_.get_mutable(); auto input_dims = param.x->dims(); @@ -726,6 +918,7 @@ void ConvImageCompute::Conv2d5x5() { CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); } + void ConvImageCompute::Conv2d7x7() { const auto& param = *param_.get_mutable(); auto input_dims = param.x->dims(); diff --git a/lite/kernels/opencl/conv_compute.h b/lite/kernels/opencl/conv_compute.h index 5b98767af0a740ce4a0adbc671000a36a156240e..d5dd65cdc855ebc25624e8316866a5944a2418b8 100644 --- a/lite/kernels/opencl/conv_compute.h +++ b/lite/kernels/opencl/conv_compute.h @@ -71,6 +71,7 @@ class ConvImageCompute : public KernelLite 1) { + filter_channel = 1; + } + + const int oh = + ConvOutputSize(ih, ksize, dilation, pad, pad, stride); + const int ow = + ConvOutputSize(iw, ksize, dilation, pad, pad, stride); + SHADOW_LOG << "to get kernel ..."; + auto kernels = + KernelRegistry::Global().Create("conv2d", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + CHECK(batch_size == 1) << "conv3x3 only supprt batch_size == 1"; + + auto kernel = std::move(kernels.front()); + SHADOW_LOG << "created conv2d kernel"; + + SHADOW_LOG << "prepare kernel ------"; + + lite::Tensor input, filter, 
bias, output; + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + param.groups = group; + if (bias_flag) { + param.bias = &bias; + } + if (relu_flag == "relu") { + param.fuse_relu = true; + } else if (relu_flag == "None") { + param.fuse_relu = false; + } else if (relu_flag == "relu6") { + param.activation_param.Relu_clipped_coef = 6.f; + param.activation_param.has_active = true; + param.activation_param.active_type = + lite_api::ActivationType::kRelu6; + } + + std::vector paddings = {pad, pad, pad, pad}; + std::vector dilations = {dilation, dilation}; + + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); + param.strides = std::vector{stride, stride}; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + std::unique_ptr conv_1x1_context( + new KernelContext); + context->As().CopySharedTo( + &(conv_1x1_context->As())); + kernel->SetContext(std::move(conv_1x1_context)); + + const DDim& input_dim = + lite::DDim{std::vector({batch_size, ic, ih, iw})}; + + const DDim& filter_dim = lite::DDim{ + std::vector({oc, filter_channel, ksize, ksize})}; + const DDim& out_dim = + lite::DDim{std::vector({batch_size, oc, oh, ow})}; + // element wise bias + const DDim& bias_dim = lite::DDim{std::vector({oc})}; + + LOG(INFO) << "input_dim:" << input_dim + << " filter_dim:" << filter_dim + << " out_dim:" << out_dim; + + param.x->Resize(input_dim); + param.filter->Resize(filter_dim); + param.output->Resize(out_dim); + if (bias_flag) { + param.bias->Resize(bias_dim); + } + + kernel->SetParam(param); + + size_t input_image_width = iw * ((ic + 3) / 4); + size_t input_image_height = ih * batch_size; + + size_t out_image_width = ow * ((oc + 3) / 4); + size_t out_image_height = oh * batch_size; + + size_t bias_image_width = ow * ((oc + 3) / 4); + size_t bias_image_height = oh * batch_size; + + size_t filter_image_width = ksize * ((filter_channel + 3) / 4); + size_t filter_image_height = oc * ksize; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + + std::vector input_v(batch_size * ic * ih * iw); + std::vector filter_v(oc * filter_channel * ksize * ksize); + std::vector output_v(batch_size * oc * oh * ow); + std::vector bias_v(oc); + + SHADOW_LOG << "gen input and filter ..."; + for (int i = 0; i < input_v.size(); ++i) { + input_v[i] = i; // gen(engine); + } + for (int i = 0; i < filter_v.size(); ++i) { + filter_v[i] = 1; // gen(engine); + } + + SHADOW_LOG << "after gen input and filter ..."; + SHADOW_LOG << "input_v.size(): " << input_v.size(); + SHADOW_LOG << "filter_v.size(): " << filter_v.size(); + SHADOW_LOG << "output_v.size(): " << output_v.size(); + SHADOW_LOG << "bias_v.size(): " << bias_v.size(); + SHADOW_LOG << "input_dim.production(): " + << input_dim.production(); + SHADOW_LOG << "filter_dim.production(): " + << filter_dim.production(); + SHADOW_LOG << "out_dim.production(): " << out_dim.production(); + SHADOW_LOG << "bias_dim.production(): " << bias_dim.production(); + SHADOW_LOG << "input_image_height:" << input_image_height + << " input_image_width:" << input_image_width; + SHADOW_LOG << "filter_image_height:" << filter_image_height + << " filter_image_width:" << filter_image_width; + SHADOW_LOG << "4 * input_image_height *input_image_width: " + << 4 * input_image_height * input_image_width; + SHADOW_LOG << "4 * filter_image_width * filter_image_height: " + << 4 
* filter_image_width * filter_image_height; + + CHECK(input_dim.production() == input_v.size()); + CHECK_LE(input_dim.production(), + 4 * input_image_height * input_image_width); + CHECK(filter_dim.production() == filter_v.size()); + CHECK_LE(filter_dim.production(), + 4 * filter_image_width * filter_image_height); + + paddle::lite::CLImageConverterDefault default_convertor; + SHADOW_LOG << "set mapped input ..."; + std::vector x_image_v(input_image_width * + input_image_height * 4); // 4 :RGBA + std::vector filter_image_v( + filter_image_width * filter_image_height * 4); // 4 : RGBA + std::vector bias_image_v( + bias_image_width * bias_image_height * 4); // 4 : RGBA + std::vector out_image_v(out_image_width * + out_image_height * 4); // 4 :RGBA + + default_convertor.NCHWToImage( + input_v.data(), x_image_v.data(), input_dim); + SHADOW_LOG << "输入: ---- "; + for (int i = 0; i < input_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << input_v[i]; + } + SHADOW_LOG << "输入image : ---- "; + for (int i = 0; i < x_image_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << x_image_v[i]; + } + SHADOW_LOG << "set mapped filter ..."; + CLImageConverterFolder folder_convertor; + + folder_convertor.NCHWToImage( + filter_v.data(), filter_image_v.data(), filter_dim); + SHADOW_LOG << "卷积核: ---- "; + for (int i = 0; i < filter_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << filter_v[i]; + } + SHADOW_LOG << "卷积核image: ---- "; + for (int i = 0; i < filter_image_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << filter_image_v[i]; + } + auto* input_image2d = input.mutable_data( + input_image_width, input_image_height, x_image_v.data()); + // assign filter as target arm + filter.Assign(filter_v.data(), + filter_dim); + // filter kernel + // auto* filter_image2d = filter.mutable_data( + // filter_image_width, + // filter_image_height, + // filter_image_v.data()); + + if (bias_flag) { + for (int i = 0; i < bias_dim.production(); ++i) { + bias_v[i] = static_cast(gen(engine)); + } + bias.Assign(bias_v.data(), + bias_dim); + // CLImageConverterFolder folder_convertor; + // folder_convertor.NCHWToImage( + // bias_v.data(), bias_image_v.data(), + // bias_dim); + // + // auto* bias_data = bias.mutable_data( + // bias_image_width, bias_image_height, + // bias_image_v.data()); + } + + SHADOW_LOG << "resize output ..."; + output.Resize(out_dim); + + // cpu conv basic calc + lite::Tensor out_ref; + out_ref.Resize(out_dim); + + SHADOW_LOG << "prepare kernel ready"; + + SHADOW_LOG << "kernel launch ..."; + kernel->Launch(); + SHADOW_LOG << "mutable output ..."; + auto* output_image2d = output.mutable_data( + out_image_width, out_image_height); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + SHADOW_LOG << "--- Find the sync event for the target cl " + "tensor. 
---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + + TargetWrapperCL::ImgcpySync(out_image_v.data(), + output.data(), + out_image_width, + out_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + + DDim out_image_shape = + default_convertor.InitImageDimInfoWith(output.dims()); + + default_convertor.ImageToNCHW(out_image_v.data(), + output_v.data(), + out_image_shape, + output.dims()); + + SHADOW_LOG << "输出: ---- "; + for (int i = 0; i < output_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << output_v[i]; + } + + SHADOW_LOG << "输出image: ---- "; + for (int i = 0; i < out_image_v.size(); i++) { + SHADOW_LOG << "(" << i << ")" << out_image_v[i]; + } + SHADOW_LOG << "mutable_data out_ref_data: "; + + // run cpu ref + auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); + + SHADOW_LOG << " conv_basic beigin ..... "; + + conv_basic(input_v.data(), + out_ref_data, + batch_size, + oc, + oh, + ow, + ic, + ih, + iw, + filter_v.data(), + bias_v.data(), // mapped_bias, + group, + ksize, + ksize, + stride, + stride, + dilation, + dilation, + pad, + pad, + bias_flag, + relu_flag); + SHADOW_LOG << " conv_basic end ..... "; + + SHADOW_LOG << " out_dim: " << out_dim; + const DDim& out_image_dims = lite::DDim{std::vector( + {static_cast(out_image_width), + static_cast(out_image_height)})}; + +#ifdef PRINT_RESULT + for (int i = 0; i < out_dim.production(); i++) { + VLOG(4) << "output_v[" << i << "]:" << output_v[i] + << " out_ref_data[" << i << "]:" << out_ref_data[i]; + } +#endif + + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2); + if (abs(output_v[i] - out_ref_data[i]) > 1e-2) { + LOG(FATAL) << "error idx:" << i; + } + } + +#ifdef LOOP_TEST + } + } + } + } + } + } +#else +// nothing to do. +#endif +} +#undef LOOP_TEST +#undef PRINT_RESULT + // #define PRINT_RESULT // #define LOOP_TEST TEST(conv2d, compute_image2d_5x5) { diff --git a/lite/kernels/opencl/scale_compute.cc b/lite/kernels/opencl/scale_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a7d4d4f61d452bb6193277766ecf94fd6034c6b --- /dev/null +++ b/lite/kernels/opencl/scale_compute.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class ScaleComputeImage2D : public KernelLite { + public: + using param_t = operators::ScaleParam; + + std::string doc() const override { return "Scale using cl::Image2D, kFloat"; } + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/scale_kernel.cl", build_options_); + } + + void Run() override { + const auto& param = *param_.get_mutable(); + const auto& in_dims = param.x->dims(); + auto* x_img = param.x->data(); + const float scale = param.scale; + const float bias = param.bias; + + LOG(INFO) << "x_image" << x_img; + auto out_image_shape = InitImageDimInfoWith(in_dims); + LOG(INFO) << "out_image_shape = " << out_image_shape["width"] << " " + << out_image_shape["height"]; + auto* out_img = param.output->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + LOG(INFO) << "out_image" << out_img; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + auto global_work_size = + cl::NDRange{static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])}; + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, scale); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, bias); + CL_CHECK_FATAL(status); + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); + } + + private: + std::string kernel_func_name_{"scale"}; + std::string build_options_{"-DCL_DTYPE_float"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(scale, + kOpenCL, + kFloat, + kImageDefault, + paddle::lite::kernels::opencl::ScaleComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/scale_compute_test.cc b/lite/kernels/opencl/scale_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..72381fee4f62e029172286fd70aae9fcd6380515 --- /dev/null +++ b/lite/kernels/opencl/scale_compute_test.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { + +void scale(const float* input_data, + const DDim& in_dim, + float* output_data, + const float scale, + const float bias) { + for (int i = 0; i < in_dim.production(); i++) { + output_data[i] = input_data[i] * scale + bias; + } +} + +TEST(scale_image2d_fp32, compute) { + LOG(INFO) << "to get kernel ..."; + auto kernels = KernelRegistry::Global().Create( + "scale", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + auto kernel = std::move(kernels.front()); + + LOG(INFO) << "get kernel:" << kernel->doc(); + + lite::Tensor x, out; + operators::ScaleParam param; + param.x = &x; + param.output = &out; + param.scale = 1.5f; + param.bias = 0.3f; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr scale_context(new KernelContext); + context->As().CopySharedTo( + &(scale_context->As())); + kernel->SetContext(std::move(scale_context)); + + const DDim in_dim = DDim(std::vector{4, 11, 107, 107}); + const DDim out_dim = DDim(std::vector{4, 11, 107, 107}); + x.Resize(in_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-5, 5); + std::vector input_v(4 * 11 * 107 * 107); + for (auto& i : input_v) { + i = dist(engine); + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "image_shape = " << image_shape[0] << " " << image_shape[1]; + std::vector x_image_data(image_shape.production() * 4); // 4 : RGBA + default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + image_shape[0], image_shape[1], x_image_data.data()); + LOG(INFO) << "x_image:" << x_image; + + auto* out_image = + out.mutable_data(image_shape[0], image_shape[1]); + LOG(INFO) << "out_image:" << out_image; + kernel->Launch(); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl tensor. 
---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target cl tensor."; + } + + std::unique_ptr out_ref(new float[out_dim.production()]); + scale(input_v.data(), in_dim, out_ref.get(), 1.5f, 0.3f); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + float* out_image_data = new float[image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + image_shape[0], + image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + float* out_data = new float[image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, image_shape, out_dim); + + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(out_data[i], out_ref[i], 1e-6); + } +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(scale, kOpenCL, kFloat, kImageDefault, image2d); diff --git a/lite/kernels/xpu/bridges/softmax_op.cc b/lite/kernels/xpu/bridges/softmax_op.cc index d964f29a86ac00034c61706af35f8ca220921ec0..740764015082a4c21bdef443e76e90065b2a99cb 100644 --- a/lite/kernels/xpu/bridges/softmax_op.cc +++ b/lite/kernels/xpu/bridges/softmax_op.cc @@ -41,7 +41,7 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto out_type = kernel->GetOutputDeclType("Out"); CHECK(out_type->precision() == PRECISION(kFloat)); CHECK(out_type->layout() == DATALAYOUT(kNCHW)); - auto axis = op_info->GetAttr("axis"); + int axis = op_info->HasAttr("axis") ? op_info->GetAttr("axis") : -1; // X node std::shared_ptr x_node = nullptr; diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc index 0dcb8e1eeab4b07d533a1bfc57cb8d9ca38b4d82..5d00570703f2caaf71ff5b5e1e6c3ad9e27eb6f2 100644 --- a/lite/model_parser/model_parser.cc +++ b/lite/model_parser/model_parser.cc @@ -20,6 +20,7 @@ #include "lite/core/scope.h" #include "lite/core/tensor.h" #include "lite/core/variable.h" +#include "lite/core/version.h" #include "lite/model_parser/desc_apis.h" #include "lite/model_parser/naive_buffer/combined_params_desc.h" #include "lite/model_parser/naive_buffer/param_desc.h" @@ -536,7 +537,7 @@ void SaveCombinedParamsNaive(const std::string &path, } pt_desc.Save(); - table.SaveToFile(path); + table.AppendToFile(path); } void SaveModelNaive(const std::string &model_dir, @@ -545,30 +546,46 @@ void SaveModelNaive(const std::string &model_dir, bool combined) { MkDirRecur(model_dir); // Save program - const std::string prog_path = model_dir + "/__model__.nb"; + const std::string prog_path = model_dir + ".nb"; naive_buffer::BinaryTable table; naive_buffer::proto::ProgramDesc nb_proto_prog(&table); naive_buffer::ProgramDesc nb_prog(&nb_proto_prog); TransformProgramDescCppToAny(cpp_prog, &nb_prog); nb_proto_prog.Save(); - table.SaveToFile(prog_path); + // Save meta_version(uint16) into file + naive_buffer::BinaryTable meta_version_table; + meta_version_table.Require(sizeof(uint16_t)); + uint16_t meta_version = 0; + memcpy(meta_version_table.cursor(), &meta_version, sizeof(uint16_t)); + meta_version_table.Consume(sizeof(uint16_t)); + meta_version_table.SaveToFile(prog_path); + + // Save lite_version(char[16]) into file + const int paddle_version_length = 16 * sizeof(char); + naive_buffer::BinaryTable paddle_version_table; + paddle_version_table.Require(paddle_version_length); + std::string paddle_version = version(); + memcpy(paddle_version_table.cursor(), + paddle_version.c_str(), + paddle_version_length); + 
paddle_version_table.Consume(paddle_version_length); + paddle_version_table.AppendToFile(prog_path); + VLOG(4) << "paddle_version:" << paddle_version << std::endl; + + // Save topology_size(uint64) into file + naive_buffer::BinaryTable topology_size_table; + topology_size_table.Require(sizeof(uint64_t)); + uint64_t topology_size = table.size(); + memcpy(topology_size_table.cursor(), &topology_size, sizeof(uint64_t)); + topology_size_table.Consume(sizeof(uint64_t)); + topology_size_table.AppendToFile(prog_path); + + // save topology data into model file + table.AppendToFile(prog_path); // Save Params - // NOTE: Only main block be used now. - if (combined) { - const std::string combined_params_path = model_dir + "/param.nb"; - SaveCombinedParamsNaive(combined_params_path, exec_scope, cpp_prog); - } else { - auto prog = cpp_prog; - auto &main_block_desc = *prog.GetBlock(0); - for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { - auto &var = *main_block_desc.GetVar(i); - if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) - continue; - const std::string path = model_dir + "/" + var.Name() + ".nb"; - SaveParamNaive(path, exec_scope, var.Name()); - } - } + SaveCombinedParamsNaive(prog_path, exec_scope, cpp_prog); + LOG(INFO) << "Save naive buffer model in '" << model_dir << "' successfully"; } #endif @@ -638,14 +655,15 @@ void LoadParamNaive(const std::string &path, } void LoadCombinedParamsNaive(const std::string &path, + const uint64_t &offset, lite::Scope *scope, const cpp::ProgramDesc &cpp_prog, bool params_from_memory) { naive_buffer::BinaryTable table; if (params_from_memory) { - table.LoadFromMemory(path.c_str(), path.length()); + table.LoadFromMemory(path.c_str() + offset, path.length() - offset); } else { - table.LoadFromFile(path); + table.LoadFromFile(path, offset, 0); } naive_buffer::proto::CombinedParamsDesc pt_desc(&table); pt_desc.Load(); @@ -693,7 +711,7 @@ void LoadModelNaive(const std::string &model_dir, // NOTE: Only main block be used now. if (combined) { const std::string combined_params_path = model_dir + "/param.nb"; - LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, false); + LoadCombinedParamsNaive(combined_params_path, 0, scope, *cpp_prog, false); } else { auto &prog = *cpp_prog; auto &main_block_desc = *prog.GetBlock(0); @@ -718,6 +736,84 @@ void LoadModelNaive(const std::string &model_dir, VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully"; } +/* + * Binary structure of naive_buffer model: model.nb + * ---------------------------------------------------------- + * | | PART | Precision | Length(byte) | + * | 1 | meta_version | uint16_t | 2 | + * | 2 | opt_version | char[16] | 16 | + * | 3 | topo_size | uint64_t | 8 | + * | 4 | topo_data | char[] | topo_size byte | + * | 5 | param_data | char[] | | + * ---------------------------------------------------------- + * Meaning of each part: + * meta_version: meata_version, 0 default. + * opt_version: lite_version of opt tool that transformed this model. + * topo_size: length of `topo_data`. + * topo_data: contains model's topology data. + * param_data: contains model's params data. +*/ + +// usage: LoadModelNaiveFromFile is used for loading model from file. 
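// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of this patch): a minimal, standalone
// reader that walks the model.nb header exactly as laid out in the table
// above (uint16 meta_version, char[16] opt_version, uint64 topo_size, then
// topo_data and param_data). It assumes the file was produced by the raw
// memcpy-based writer above on a little-endian host; the function name
// PeekNaiveBufferHeader and the example path are hypothetical.
#include <cstdint>
#include <cstdio>
#include <string>

bool PeekNaiveBufferHeader(const std::string &path) {
  FILE *fp = fopen(path.c_str(), "rb");
  if (!fp) return false;
  uint16_t meta_version = 0;
  char opt_version[17] = {0};  // 16 bytes on disk + terminator for printing
  uint64_t topo_size = 0;
  const bool ok = fread(&meta_version, sizeof(meta_version), 1, fp) == 1 &&
                  fread(opt_version, 1, 16, fp) == 16 &&
                  fread(&topo_size, sizeof(topo_size), 1, fp) == 1;
  fclose(fp);
  if (ok) {
    // topo_data starts at byte 2 + 16 + 8 = 26; param_data at 26 + topo_size.
    printf("meta_version=%u opt_version=%s topo_size=%llu\n",
           static_cast<unsigned>(meta_version),
           opt_version,
           static_cast<unsigned long long>(topo_size));
  }
  return ok;
}
// Example call (hypothetical path): PeekNaiveBufferHeader("mobilenet_v1.nb");
// ---------------------------------------------------------------------------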
+template +void ReadModelDataFromFile(T *data, + const std::string &prog_path, + uint64_t *offset, + const uint64_t &size) { + naive_buffer::BinaryTable data_table; + data_table.LoadFromFile(prog_path, *offset, size); + memcpy(data, data_table.cursor(), size); + *offset = *offset + size; +} + +void LoadModelNaiveFromFile(const std::string &filename, + Scope *scope, + cpp::ProgramDesc *cpp_prog) { + CHECK(cpp_prog); + CHECK(scope); + cpp_prog->ClearBlocks(); + // ModelFile + const std::string prog_path = filename; + + // Offset + uint64_t offset = 0; + + // (1)get meta version + uint16_t meta_version; + ReadModelDataFromFile( + &meta_version, prog_path, &offset, sizeof(uint16_t)); + VLOG(4) << "Meta_version:" << meta_version; + + // (2)get opt version + char opt_version[16]; + const uint64_t paddle_version_length = 16 * sizeof(char); + ReadModelDataFromFile( + opt_version, prog_path, &offset, paddle_version_length); + VLOG(4) << "Opt_version:" << opt_version; + + // (3)get topo_size + uint64_t topo_size; + ReadModelDataFromFile( + &topo_size, prog_path, &offset, sizeof(uint64_t)); + + // (4)get topo data + naive_buffer::BinaryTable topo_table; + topo_table.LoadFromFile(prog_path, offset, topo_size); + offset = offset + topo_size; + // transform topo_data into cpp::ProgramDesc + naive_buffer::proto::ProgramDesc nb_proto_prog(&topo_table); + nb_proto_prog.Load(); + naive_buffer::ProgramDesc nb_prog(&nb_proto_prog); + TransformProgramDescAnyToCpp(nb_prog, cpp_prog); + + // (5)Load Params + LoadCombinedParamsNaive(prog_path, offset, scope, *cpp_prog, false); + + VLOG(4) << "Load naive buffer model in '" << filename << "' successfully"; +} + +// warning: this is an old inference and is not suggested. +// todo: this inference will be abandened in release/v3.0.0 void LoadModelNaiveFromMemory(const std::string &model_buffer, const std::string ¶m_buffer, Scope *scope, @@ -741,7 +837,64 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer, // Load Params // NOTE: Only main block be used now. 
// only combined Params are supported in Loading Model from memory - LoadCombinedParamsNaive(param_buffer, scope, *cpp_prog, true); + LoadCombinedParamsNaive(param_buffer, 0, scope, *cpp_prog, true); + + VLOG(4) << "Load model from naive buffer memory successfully"; +} + +// Usage: LoadModelNaiveFromMemory loads a naive-buffer model from a memory buffer. +template <typename T> +void ReadModelDataFromBuffer(T *data, + const std::string &model_buffer, + uint64_t *offset, + const uint64_t &size) { + naive_buffer::BinaryTable data_table; + data_table.LoadFromMemory(model_buffer.c_str() + *offset, size); + memcpy(data, data_table.cursor(), size); + *offset = *offset + size; +} +void LoadModelNaiveFromMemory(const std::string &model_buffer, + Scope *scope, + cpp::ProgramDesc *cpp_prog) { + CHECK(cpp_prog); + CHECK(scope); + cpp_prog->ClearBlocks(); + + // Offset + uint64_t offset = 0; + + // (1) get meta version + uint16_t meta_version; + ReadModelDataFromBuffer<uint16_t>( + &meta_version, model_buffer, &offset, sizeof(uint16_t)); + VLOG(4) << "Meta_version:" << meta_version; + + // (2) get opt version + char opt_version[16]; + const uint64_t paddle_version_length = 16 * sizeof(char); + ReadModelDataFromBuffer<char>( + opt_version, model_buffer, &offset, paddle_version_length); + VLOG(4) << "Opt_version:" << opt_version; + + // (3) get topo_size and topo_data + uint64_t topo_size; + ReadModelDataFromBuffer<uint64_t>( + &topo_size, model_buffer, &offset, sizeof(uint64_t)); + naive_buffer::BinaryTable table; + table.LoadFromMemory(model_buffer.c_str() + offset, topo_size); + offset = offset + topo_size; + + naive_buffer::proto::ProgramDesc nb_proto_prog(&table); + nb_proto_prog.Load(); + naive_buffer::ProgramDesc nb_prog(&nb_proto_prog); + + // Transform to cpp::ProgramDesc + TransformProgramDescAnyToCpp(nb_prog, cpp_prog); + + // Load params + // NOTE: Only the main block is used now. + // Only combined params are supported when loading a model from memory. + LoadCombinedParamsNaive(model_buffer, offset, scope, *cpp_prog, true); VLOG(4) << "Load model from naive buffer memory successfully"; } diff --git a/lite/model_parser/model_parser.h b/lite/model_parser/model_parser.h index bca7533c24af517994dae677c7b63e088f2ef1ca..e4641f69ada380c91f69280290dd020ea27d2ad1 100644 --- a/lite/model_parser/model_parser.h +++ b/lite/model_parser/model_parser.h @@ -94,15 +94,22 @@ void LoadParamNaive(const std::string& path, lite::Scope* scope, const std::string& name); +// warning: this legacy interface will be abandoned in release/v3.0.0; +// LoadModelNaiveFromFile is suggested instead.
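Since the note above recommends the new single-file entry points, here is a minimal, hedged usage sketch (not taken from the patch): the model path is hypothetical, and the `ReadFile` helper and include paths mirror what `model_parser_test.cc` uses.

```cpp
// Illustrative usage of the new single-file loaders (assumed include paths).
#include <string>

#include "lite/core/scope.h"
#include "lite/model_parser/model_parser.h"
#include "lite/utils/io.h"  // lite::ReadFile, as used in model_parser_test.cc

void LoadNaiveModelExample() {
  paddle::lite::Scope scope;
  paddle::lite::cpp::ProgramDesc prog;

  // Load directly from a model.nb file produced by the opt tool ...
  paddle::lite::LoadModelNaiveFromFile("model.nb", &scope, &prog);

  // ... or read the whole file into memory and load from the buffer.
  std::string buffer = paddle::lite::ReadFile("model.nb");
  paddle::lite::LoadModelNaiveFromMemory(buffer, &scope, &prog);
}
```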
void LoadModelNaive(const std::string& model_dir, lite::Scope* scope, cpp::ProgramDesc* prog, bool combined = true); - +void LoadModelNaiveFromFile(const std::string& filename, + lite::Scope* scope, + cpp::ProgramDesc* prog); void LoadModelNaiveFromMemory(const std::string& model_buffer, const std::string& param_buffer, lite::Scope* scope, cpp::ProgramDesc* cpp_prog); +void LoadModelNaiveFromMemory(const std::string& model_buffer, + lite::Scope* scope, + cpp::ProgramDesc* cpp_prog); } // namespace lite } // namespace paddle diff --git a/lite/model_parser/model_parser_test.cc b/lite/model_parser/model_parser_test.cc index 58083027849cc007bce80bd10004d0a13259fda7..d9c0f501c37862236cacd2624dc70c8cf1dacc86 100644 --- a/lite/model_parser/model_parser_test.cc +++ b/lite/model_parser/model_parser_test.cc @@ -121,17 +121,23 @@ TEST(ModelParser, SaveModelNaive) { SaveModelNaive(save_pb_model_path, scope, prog); } +TEST(ModelParser, LoadModelNaiveFromFile) { + CHECK(!FLAGS_model_dir.empty()); + cpp::ProgramDesc prog; + Scope scope; + + auto model_path = std::string(FLAGS_model_dir) + ".saved.naive.nb"; + LoadModelNaiveFromFile(model_path, &scope, &prog); +} + TEST(ModelParser, LoadModelNaiveFromMemory) { CHECK(!FLAGS_model_dir.empty()); cpp::ProgramDesc prog; Scope scope; - auto model_path = std::string(FLAGS_model_dir) + ".saved.naive/__model__.nb"; - auto params_path = std::string(FLAGS_model_dir) + ".saved.naive/param.nb"; + auto model_path = std::string(FLAGS_model_dir) + ".saved.naive.nb"; std::string model_buffer = lite::ReadFile(model_path); - std::string params_buffer = lite::ReadFile(params_path); - - LoadModelNaiveFromMemory(model_buffer, params_buffer, &scope, &prog); + LoadModelNaiveFromMemory(model_buffer, &scope, &prog); } } // namespace lite diff --git a/lite/model_parser/naive_buffer/naive_buffer.cc b/lite/model_parser/naive_buffer/naive_buffer.cc index cefaf0c28a34a70c095362e9972c9ef99d5fa80c..02538602fb5b5ae319d1041d501a87c212e47d2d 100644 --- a/lite/model_parser/naive_buffer/naive_buffer.cc +++ b/lite/model_parser/naive_buffer/naive_buffer.cc @@ -44,24 +44,37 @@ void BinaryTable::SaveToFile(const std::string &filename) const { fclose(fp); } -void BinaryTable::LoadFromFile(const std::string &filename) { - // get file size +void BinaryTable::AppendToFile(const std::string &filename) const { + FILE *fp = fopen(filename.c_str(), "ab"); + CHECK(fp) << "Unable to open file: " << filename; + if (fwrite(reinterpret_cast<const char *>(data()), 1, size(), fp) != size()) { + fclose(fp); + LOG(FATAL) << "Write file error: " << filename; + } + fclose(fp); +} + +void BinaryTable::LoadFromFile(const std::string &filename, + const size_t &offset, + const size_t &size) { + // open the file in read-only mode FILE *fp = fopen(filename.c_str(), "rb"); CHECK(fp) << "Unable to open file: " << filename; - fseek(fp, 0L, SEEK_END); - size_t file_size = ftell(fp); - LOG(INFO) << "file size " << file_size; - - // load data. - fseek(fp, 0L, SEEK_SET); - Require(file_size); - if (fread(reinterpret_cast<char *>(&bytes_[0]), 1, file_size, fp) != - file_size) { + // if `size` is 0, read from `offset` to the end of the file; otherwise read `size` bytes starting at `offset` + size_t buffer_size = size; + if (size == 0) { + fseek(fp, 0L, SEEK_END); + buffer_size = ftell(fp) - offset; + } + fseek(fp, offset, SEEK_SET); + Require(buffer_size); + // read `buffer_size` bytes into `bytes_` + if (fread(reinterpret_cast<char *>(&bytes_[0]), 1, buffer_size, fp) != + buffer_size) { fclose(fp); LOG(FATAL) << "Read file error: " << filename; } fclose(fp); - // Set readonly.
is_mutable_mode_ = false; } diff --git a/lite/model_parser/naive_buffer/naive_buffer.h b/lite/model_parser/naive_buffer/naive_buffer.h index 9be2be954328e757e79a880f34b49c3f0cf77c7a..5be17856a25aabfed81ae88d80e788c8dd2be4bc 100644 --- a/lite/model_parser/naive_buffer/naive_buffer.h +++ b/lite/model_parser/naive_buffer/naive_buffer.h @@ -61,8 +61,12 @@ struct BinaryTable { /// Serialize the table to a binary buffer. void SaveToFile(const std::string& filename) const; + void AppendToFile(const std::string& filename) const; - void LoadFromFile(const std::string& filename); + // void LoadFromFile(const std::string& filename); + void LoadFromFile(const std::string& filename, + const size_t& offset = 0, + const size_t& size = 0); void LoadFromMemory(const char* buffer, size_t buffer_size); }; diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt old mode 100644 new mode 100755 index ccc9c825db8a8a5030c6481ee0e33b8f324f4d11..61d568426525efc7fe2bd0109882fc149b92d3d2 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -135,6 +135,8 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS}) add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS ${op_DEPS}) add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS}) +add_operator(one_hot basic SRCS one_hot_op.cc DEPS ${op_DEPS}) + if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc DEPS fc_op memory diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index 6faa9eb225c76735460227b77387d0b0e8157525..76c257c6d34f0a82a920eaf49c1ef88efbd0daf4 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -82,7 +82,20 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.variance = scope->FindVar(op_desc.Input("Variance").front())->GetMutable<lite::Tensor>(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<lite::Tensor>(); - param_.is_test = op_desc.GetAttr<int>("is_test"); + + auto is_test_type = op_desc.GetAttrType("is_test"); + switch (is_test_type) { + case OpDescAPI::AttrType::INT: + param_.is_test = op_desc.GetAttr<int>("is_test"); + break; + case OpDescAPI::AttrType::BOOLEAN: + param_.is_test = op_desc.GetAttr<bool>("is_test"); + break; + default: + LOG(FATAL) << "Unsupported attribute type: the type of attribute " "`is_test` in BatchNormOp should be int or bool."; + } + if (op_desc.HasAttr("use_global_stats")) { param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats"); } diff --git a/lite/operators/batch_norm_op_test.cc b/lite/operators/batch_norm_op_test.cc index 574bb4cfd316b05bf08086d865f4eb7de7dd03a3..b79037c0bc9c3e9188eaf0e54b3f958960ab0893 100644 --- a/lite/operators/batch_norm_op_test.cc +++ b/lite/operators/batch_norm_op_test.cc @@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { desc.SetInput("Mean", {"mean"}); desc.SetInput("Variance", {"variance"}); desc.SetOutput("Y", {"y"}); - desc.SetAttr("is_test", static_cast<int>(1)); + desc.SetAttr("is_test", static_cast<bool>(true)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f); @@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) { desc.SetOutput("VarianceOut", {"variance_out"}); desc.SetOutput("SavedMean", {"saved_mean"}); desc.SetOutput("SavedVariance", {"saved_variance"}); - desc.SetAttr("is_test", static_cast<int>(0)); + desc.SetAttr("is_test", static_cast<bool>(false)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f); diff --git a/lite/operators/fake_quantize_range_abs_max.cc b/lite/operators/fake_quantize_range_abs_max.cc index a8ce3f75a59fec5b032c60f51177f428bd15fe0d..ebf7e41f4b1af6f6961da07fe95caece19fa59f5 100644 --- a/lite/operators/fake_quantize_range_abs_max.cc +++ b/lite/operators/fake_quantize_range_abs_max.cc @@ -23,3 +23,5 @@ namespace operators {} // namespace operators REGISTER_LITE_OP(fake_quantize_range_abs_max, paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); +REGISTER_LITE_OP(fake_quantize_abs_max, + paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h index 726731595a9c4b7cd2e30db911230cc2f00b5b92..f68d1e20f6e60bb5aa99a2402ea8c9f88aa18470 100644 --- a/lite/operators/fake_quantize_range_abs_max.h +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); - auto in_scale = op_desc.Input("InScale").front(); + if (op_desc.HasInput("InScale")) { + auto in_scale = op_desc.Input("InScale").front(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + } auto out = op_desc.Output("Out").front(); auto out_scale = op_desc.Output("OutScale").front(); param_.x = scope->FindVar(x)->GetMutable(); - param_.in_scale = scope->FindVar(in_scale)->GetMutable(); param_.out = scope->FindVar(out)->GetMutable(); param_.out_scale = scope->FindVar(out_scale)->GetMutable(); diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index 3cddde38b291f189649175a43c994d4fcfcabb9b..ec449cd4bdc33f191c33fc04f215ad672b283215 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h @@ -37,15 +37,6 @@ class FcOpLite : public OpLite { bool InferShape() const override; - /* - bool Run() override { - CHECK(kernel_); - kernel_->Run(); - return true; - } - */ - - // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/one_hot_op.cc b/lite/operators/one_hot_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..023cdc23aeb8329736b7438af2c52cbfa899c75c --- /dev/null +++ b/lite/operators/one_hot_op.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/one_hot_op.h" +#include "lite/core/op_registry.h" + +#include "lite/backends/fpga/KD/debugger.hpp" + +namespace paddle { +namespace lite { +namespace operators { + +bool OneHotOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + return true; +} + +bool OneHotOp::InferShape() const { + CHECK_OR_FALSE(param_.Out); + // TODO(Superjomn) Enable data sharing. 
+ auto out_dims = param_.X->dims(); + + out_dims[out_dims.size() - 1] = param_.depth; + param_.Out->Resize(out_dims); + return true; +} + +bool OneHotOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.X = + scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); + param_.Out = + scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>(); + + if (opdesc.HasInput("depth_tensor")) { + auto depth_tensor = opdesc.Input("depth_tensor").front(); + param_.depth_tensor = + scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>(); + } + + CHECK(param_.X); + CHECK(param_.Out); + param_.depth = opdesc.GetAttr<int>("depth"); + param_.dtype = opdesc.GetAttr<int>("dtype"); + + if (opdesc.HasAttr("allow_out_of_range")) { + param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range"); + } + + auto out_lod = param_.Out->mutable_lod(); + *out_lod = param_.X->lod(); + // param_.allow_out_of_range = opdesc.GetAttr<bool>("allow_out_of_range"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp); diff --git a/lite/operators/one_hot_op.h b/lite/operators/one_hot_op.h new file mode 100755 index 0000000000000000000000000000000000000000..4a0613952520279699a0f4a56d002483de325241 --- /dev/null +++ b/lite/operators/one_hot_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
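For readers unfamiliar with the op registered above, here is a small self-contained sketch of what one_hot computes. It is illustrative only, independent of the Lite kernels, and simply mirrors the `depth` and `allow_out_of_range` attributes read in `AttachImpl`.

```cpp
// Reference semantics of one_hot (illustrative, not Lite kernel code).
// Each integer label maps to a float row of length `depth`, with 1.0 at
// the label index; with allow_out_of_range, out-of-range labels produce
// an all-zero row.
#include <vector>

std::vector<std::vector<float>> OneHotReference(const std::vector<int>& labels,
                                                int depth,
                                                bool allow_out_of_range) {
  std::vector<std::vector<float>> rows;
  rows.reserve(labels.size());
  for (int label : labels) {
    std::vector<float> row(depth, 0.0f);
    if (label >= 0 && label < depth) {
      row[label] = 1.0f;
    } else if (!allow_out_of_range) {
      return {};  // a real kernel would report an error here instead
    }
    rows.push_back(row);
  }
  return rows;
}

// Example: labels = {1, 3}, depth = 4
//   -> {{0, 1, 0, 0}, {0, 0, 0, 1}}  (matches InferShape: last dim == depth)
```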
+ +#pragma once +#include <string> +#include <vector> +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class OneHotOp : public OpLite { + public: + OneHotOp() {} + explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "one_hot"; } + + private: + mutable OneHotParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h old mode 100644 new mode 100755 index 9aba4a1f3e7b96abedb2f4d835f99072bf4b7f4e..9d752f4b725947afae400dfb489a3265c0e27bb9 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1133,7 +1133,15 @@ struct GridSamplerParam { lite::Tensor* out{}; lite::Tensor* grid{}; }; - -} // namespace operators -} // namespace lite -} // namespace paddle +/// --------------------- attention operators -------------- +struct OneHotParam { + lite::Tensor* X{}; + lite::Tensor* depth_tensor{nullptr}; + lite::Tensor* Out{}; + int depth{-1}; + int dtype{}; + bool allow_out_of_range{false}; +}; +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index 1d5adaa6cca8986b2fb302c1f480730512b458b5..6d879385a27e834b3fa27835ee94edc599f5564c 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -192,7 +192,7 @@ class FcOPTest : public arena::TestCase { fill_data_rand(bin.data(), -1.f, 1.f, bdims_.production()); SetCommonTensor(input_, dims_, din.data()); - SetCommonTensor(weight_, wdims_, win.data()); + SetCommonTensor(weight_, wdims_, win.data(), {}, true); if (padding_weights_) { std::vector<float> win_padding(wdims_padding_.production()); for (int64_t i = 0; i < wdims_[0]; ++i) { @@ -203,15 +203,15 @@ class FcOPTest : public arena::TestCase { SetCommonTensor(weight_padding_, wdims_padding_, win_padding.data()); } if (flag_bias) { - SetCommonTensor(bias_, bdims_, bin.data()); + SetCommonTensor(bias_, bdims_, bin.data(), {}, true); } } }; -void TestFCMain(Place place, - float abs_error, - bool with_relu = false, - bool padding = false) { +void TestFC2D(Place place, + float abs_error, + bool with_relu = false, - bool padding = false) { for (auto& m : {1, 3, 16}) { for (auto& n : {1, 4, 16, 128, 256, 1024}) { for (auto& k : {1, 16, 128, 1024}) { @@ -242,9 +242,35 @@ void TestFCMain(Place place, } } +void TestFCHelper(Place place, + float abs_error, + std::vector<int64_t> xdims, + std::vector<int64_t> wdims, + std::vector<int64_t> bdims, + int in_num_col_dims) { + std::unique_ptr<arena::TestCase> tester(new FcOPTest(place, + "def", + DDim(xdims), + DDim(wdims), + DDim(bdims), + in_num_col_dims, + false, + false)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); +} + +void TestFCnD(Place place, float abs_error) { + TestFCHelper(place, abs_error, {2, 3, 4}, {4, 5}, {5}, 2); + TestFCHelper(place, abs_error, {2, 3, 4}, {12, 5}, {5}, 1); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {5, 6}, {6}, 3); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {20, 6}, {6}, 2); + TestFCHelper(place, abs_error, {2, 3, 4, 5}, {60, 6}, {6}, 1); +} + TEST(FcOP, precision) { Place place; - float
abs_error = 6e-5; + float abs_error = 1e-4; #if defined(LITE_WITH_NPU) place = TARGET(kNPU); abs_error = 2e-1; // Using fp16 in NPU @@ -256,7 +282,9 @@ TEST(FcOP, precision) { #else return; #endif - TestFCMain(place, abs_error); + + TestFC2D(place, abs_error); + TestFCnD(place, abs_error); } #ifdef LITE_WITH_X86 @@ -264,7 +292,7 @@ TEST(FcOP, padding_and_parallel) { Place place(TARGET(kX86)); float abs_error = 1e-4; x86::SetNumThreads(4); - TestFCMain(place, abs_error, true, true); + TestFC2D(place, abs_error, true, true); } #endif diff --git a/lite/tests/math/conv_compute_test.cc b/lite/tests/math/conv_compute_test.cc index df238ceae9e39541fb954d9262832d01cd9d3b7f..ceb35ffb6e4c728904d1f63f96f16434a561e904 100644 --- a/lite/tests/math/conv_compute_test.cc +++ b/lite/tests/math/conv_compute_test.cc @@ -307,7 +307,7 @@ void test_conv_fp32(const std::vector& input_dims, #endif // LITE_WITH_ARM // TODO(chenjiaoAngel): fix multi-threds, diff: 3x3 depthwise conv -#if 1 /// 3x3dw +#if 1 // 3x3dw TEST(TestConv3x3DW, test_conv3x3_depthwise) { if (FLAGS_basic_test) { for (auto& stride : {1, 2}) { @@ -449,7 +449,7 @@ TEST(TestConv3x3s1, test_conv_3x3s1) { dims.push_back(DDim({batch, cin, h, h})); } } - if (cin == 1 && cout ==1) { + if (cin == 1 && cout == 1) { continue; } const float leakey_relu_scale = 8.88; diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc index 25879a15184965b128bfa100a2b41a17aa842860..8eab3109418540671f324ae0e46bd7b8d2b7a7db 100644 --- a/lite/tests/math/gemv_int8_compute_test.cc +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -285,7 +285,7 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) { paddle::lite::DeviceInfo::Init(); #endif LOG(INFO) << "run basic sgemm test"; - for (auto& m : {1, 3, 8, 32, 397}) { + for (auto& m : {1, 3, 8, 32}) { // ,397 for (auto& n : {1, 3, 13, 141, 512, 789}) { for (auto& tra : {false}) { for (auto& has_bias : {false, true}) { diff --git a/lite/tools/build.sh b/lite/tools/build.sh index e1610b60d3b1b104699ab175bca3bb3cf81bd40b..7bb330b28bc51ca4a241831bd320cb25474a74cd 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -14,7 +14,7 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4} # global variables BUILD_EXTRA=OFF -BUILD_JAVA=ON +BUILD_JAVA=OFF BUILD_PYTHON=OFF BUILD_DIR=$(pwd) OPTMODEL_DIR="" @@ -62,17 +62,17 @@ function prepare_thirdparty { fi } -function build_model_optimize_tool { +function build_opt { cd $workspace prepare_thirdparty - mkdir -p build.model_optimize_tool - cd build.model_optimize_tool + mkdir -p build.opt + cd build.opt cmake .. -DWITH_LITE=ON \ -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON \ -DWITH_TESTING=OFF \ -DLITE_BUILD_EXTRA=ON \ -DWITH_MKL=OFF - make model_optimize_tool -j$NUM_PROC + make opt -j$NUM_PROC } function make_tiny_publish_so { @@ -395,7 +395,7 @@ function main { shift ;; build_optimize_tool) - build_model_optimize_tool + build_opt shift ;; cuda) diff --git a/lite/tools/build_fpga.sh b/lite/tools/build_fpga.sh index f8c186e92fc3ba23e5e09b6a139202d028e58fc6..ab10798fe7da34ddd88b2fab0bcc0e5f4b8ce233 100755 --- a/lite/tools/build_fpga.sh +++ b/lite/tools/build_fpga.sh @@ -2,12 +2,16 @@ build_dir=build_fpga mkdir -p ${build_dir} -cd ${build_dir} -GEN_CODE_PATH_PREFIX=lite/gen_code -mkdir -p ./${GEN_CODE_PATH_PREFIX} -touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc +root_dir=$(pwd) +build_dir=${build_dir} +# in build directory +# 1. 
Prepare gen_code file +GEN_CODE_PATH_PREFIX=${build_dir}/lite/gen_code +mkdir -p ${GEN_CODE_PATH_PREFIX} +touch ${GEN_CODE_PATH_PREFIX}/__generated_code__.cc +cd ${build_dir} cmake .. \ -DWITH_GPU=OFF \ -DWITH_MKL=OFF \ @@ -19,8 +23,9 @@ cmake .. \ -DLITE_WITH_OPENMP=ON \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ -DWITH_TESTING=OFF \ - -DARM_TARGET_OS=armlinux - -make -j8 + -DARM_TARGET_OS=armlinux \ + -DLITE_BUILD_EXTRA=ON \ + -DLITE_WITH_PROFILE=OFF +make -j42 cd - diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index a0273efe13512e38e157dda76401f8946f79880f..1960dc1e1506f9742cdd9be41d5448c646c026af 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -519,7 +519,7 @@ function test_model_optimize_tool_compile { cd $workspace cd build cmake .. -DWITH_LITE=ON -DLITE_ON_MODEL_OPTIMIZE_TOOL=ON -DWITH_TESTING=OFF -DLITE_BUILD_EXTRA=ON - make model_optimize_tool -j$NUM_CORES_FOR_COMPILE + make opt -j$NUM_CORES_FOR_COMPILE } function _test_paddle_code_generator { diff --git a/mobile/src/fpga/KD/pes/conv_pe.hpp b/mobile/src/fpga/KD/pes/conv_pe.hpp old mode 100644 new mode 100755 index 5ef89e920e60cd2ef1c57e1f342a342a4149563f..388672a99325c2d04d87c90fa5a6b556b676a820 --- a/mobile/src/fpga/KD/pes/conv_pe.hpp +++ b/mobile/src/fpga/KD/pes/conv_pe.hpp @@ -29,7 +29,6 @@ namespace zynqmp { class ConvPE : public PE { public: bool init() { - std::cout << "Conv init" << std::endl; return true; }