diff --git a/CMakeLists.txt b/CMakeLists.txt index 377e58d3ac7c37271d2a813b22912528c556164b..d40491f3eecbea5d4da5817c07be9cb27b8ce25e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,7 @@ lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF) lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF) lite_option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF) lite_option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) -lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF IF LITE_WITH_PROFILE) +lite_option(LITE_WITH_PRECISION_PROFILE "Enable precision profile in profile mode ON in lite" OFF) lite_option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF) lite_option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index d38c78f62fa2bed4f4483355de0683f1f5b7656b..0d60c578685cd3d3f3adbeac9fc75d1cdcc78c51 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -152,9 +152,10 @@ endif() if (LITE_WITH_PROFILE) add_definitions("-DLITE_WITH_PROFILE") - if (LITE_WITH_PRECISION_PROFILE) - add_definitions("-DLITE_WITH_PRECISION_PROFILE") - endif() +endif() + +if (LITE_WITH_PRECISION_PROFILE) + add_definitions("-DLITE_WITH_PRECISION_PROFILE") endif() if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) diff --git a/cmake/lite.cmake b/cmake/lite.cmake index 265de3fbf68542f1b1525257887cbfaa4d1c4d62..780cdea445cf10897ee71c85a939a64406b59c96 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -307,6 +307,9 @@ function(add_kernel TARGET device level) if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() if ("${device}" STREQUAL "Host") @@ -322,16 +325,11 @@ function(add_kernel TARGET device level) set(arm_kernels 
"${arm_kernels};${TARGET}" CACHE INTERNAL "") endif() if ("${device}" STREQUAL "X86") - if (NOT LITE_WITH_X86) + if (NOT LITE_WITH_X86 OR LITE_ON_MODEL_OPTIMIZE_TOOL) foreach(src ${args_SRCS}) file(APPEND ${fake_kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") endforeach() return() - elseif (LITE_ON_MODEL_OPTIMIZE_TOOL) - foreach(src ${args_SRCS}) - file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n") - endforeach() - return() endif() set(x86_kernels "${x86_kernels};${TARGET}" CACHE INTERNAL "") endif() @@ -434,11 +432,13 @@ function(add_operator TARGET level) ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if ("${level}" STREQUAL "extra" AND (NOT LITE_BUILD_EXTRA)) return() endif() + if ("${level}" STREQUAL "train" AND (NOT LITE_WITH_TRAIN)) + return() + endif() foreach(src ${args_SRCS}) if(LITE_BUILD_TAILOR) diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index a39c0a02681f16578ae81c74d83979fe0c57e6c6..5b8a420b2a6b127ebbd6ce4005a426b03b527c0c 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -192,7 +192,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) add_dependencies(publish_inference publish_inference_cxx_lib) if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug") add_custom_command(TARGET publish_inference_cxx_lib POST_BUILD - COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a) + COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a + COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.so) endif() endif() else() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 49e60df246b98e7ee73f884a590c4d4bd91efab1..284f4e19ad14e917f6e142fc4c182a539ed93e79 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -8,11 +8,12 @@ if (LITE_ON_TINY_PUBLISH) set(CMAKE_CXX_FLAGS_RELEASE "-Os -DNDEBUG") set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG") endif() -set(light_lib_DEPS 
light_api paddle_api paddle_api_light optimizer) -if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux")) + +set(light_lib_DEPS light_api paddle_api paddle_api_light) +if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH_BM OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux")) #full api dynamic library - add_library(paddle_full_api_shared SHARED "") - target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc) + lite_cc_library(paddle_full_api_shared SHARED SRCS paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc + DEPS paddle_api paddle_api_light paddle_api_full) add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto) target_link_libraries(paddle_full_api_shared framework_proto) if(LITE_WITH_X86) @@ -27,13 +28,13 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE endif(LITE_WITH_CUDA) #light api dynamic library - lite_cc_library(paddle_light_api_shared MODULE - SRCS light_api_shared.cc - DEPS ${light_lib_DEPS} - ARM_DEPS ${arm_kernels} - CV_DEPS paddle_cv_arm - NPU_DEPS ${npu_kernels}) - + lite_cc_library(paddle_light_api_shared SHARED SRCS paddle_api.cc light_api.cc light_api_impl.cc + DEPS ${light_lib_DEPS} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + ) + add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) target_link_libraries(paddle_light_api_shared ${light_lib_DEPS} ${arm_kernels} ${npu_kernels}) set(LINK_MAP_FILE "${PADDLE_SOURCE_DIR}/lite/core/lite.map") set(LINK_FLAGS "-Wl,--version-script ${LINK_MAP_FILE}") @@ -262,7 +263,10 @@ if (NOT LITE_ON_TINY_PUBLISH) CV_DEPS paddle_cv_arm NPU_DEPS ${npu_kernels} CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels}) + FPGA_DEPS ${fpga_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS 
${npu_kernels} + BM_DEPS ${bm_kernels}) # The final inference library for just MobileConfig. bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api) target_link_libraries(paddle_api_full ${cuda_deps}) @@ -311,7 +315,7 @@ add_dependencies(opt_base supported_kernel_op_info_h framework_proto all_kernel_ if (LITE_ON_MODEL_OPTIMIZE_TOOL) message(STATUS "Compiling opt") lite_cc_binary(opt SRCS opt.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc - DEPS gflags kernel op optimizer mir_passes utils) + DEPS gflags kernel op optimizer mir_passes utils ${host_kernels}) add_dependencies(opt op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h) endif(LITE_ON_MODEL_OPTIMIZE_TOOL) diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index d53de7bf2ed00fed70bbd1f70729a051e5d7203b..0843faf0d6b060a5b76a850de069b1dbf714da19 100644 --- a/lite/api/benchmark.cc +++ b/lite/api/benchmark.cc @@ -44,7 +44,10 @@ DEFINE_string(input_shape, "set input shapes according to the model, " "separated by colon and comma, " "such as 1,3,244,244"); -DEFINE_string(input_img_path, "", "the path of input image"); +DEFINE_string(input_img_path, + "", + "the path of input image, if not set " + "input_img_path, the input of model will be 1.0."); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(power_mode, @@ -57,16 +60,11 @@ DEFINE_int32(power_mode, DEFINE_int32(threads, 1, "threads num"); DEFINE_string(result_filename, "result.txt", - "save benchmark " - "result to the file"); + "save the inference time to the file."); DEFINE_bool(run_model_optimize, false, "if set true, apply model_optimize_tool to " "model and use optimized model to test. "); -DEFINE_bool(is_quantized_model, - false, - "if set true, " - "test the performance of the quantized model. 
"); namespace paddle { namespace lite_api { @@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) { std::vector vaild_places = { Place{TARGET(kARM), PRECISION(kFloat)}, }; - if (FLAGS_is_quantized_model) { - vaild_places.insert(vaild_places.begin(), - Place{TARGET(kARM), PRECISION(kInt8)}); - } config.set_valid_places(vaild_places); auto predictor = lite_api::CreatePaddlePredictor(config); @@ -181,8 +175,8 @@ void Run(const std::vector& input_shape, int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_model_dir == "" || FLAGS_result_filename == "") { - LOG(INFO) << "please run ./benchmark_bin --help to obtain usage."; + if (FLAGS_model_dir == "") { + LOG(INFO) << "Please run ./benchmark_bin --help to obtain usage."; exit(0); } diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 1bc33e6a988d387353eb30474dadffac6ddd105e..d96947d9d1019d062f7ccc1b52d48cd957117a11 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -19,6 +19,7 @@ #include #include #include +#include "lite/api/paddle_use_passes.h" #include "lite/utils/io.h" namespace paddle { @@ -298,6 +299,9 @@ void Predictor::Build(const std::shared_ptr &desc, inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); inner_places.emplace_back( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + + // Analysis whether the modle is quantized. 
+ // For quantized model, add place(arm, int8) to inner_places const std::vector quant_dequant_op = { "fake_quantize_abs_max", "fake_quantize_range_abs_max", @@ -320,7 +324,8 @@ void Predictor::Build(const std::shared_ptr &desc, } } if (is_quantized_model) { - inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); + inner_places.insert(inner_places.begin(), + Place{TARGET(kARM), PRECISION(kInt8)}); } Program program(*desc.get(), scope_, inner_places); diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index b641973a15b2e6abc1cf4c999d759271f7522638..f61e2f35241bb2a361e665f37fb58e3bc5226090 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -13,13 +13,9 @@ // limitations under the License. #include "lite/api/light_api.h" +#include #include "paddle_use_kernels.h" // NOLINT #include "paddle_use_ops.h" // NOLINT -#ifndef LITE_ON_TINY_PUBLISH -#include "lite/api/paddle_use_passes.h" -#endif - -#include namespace paddle { namespace lite { diff --git a/lite/api/light_api_shared.cc b/lite/api/light_api_shared.cc deleted file mode 100644 index cfe3d9de09a646e33c4a116bb3cd087d28aa24c2..0000000000000000000000000000000000000000 --- a/lite/api/light_api_shared.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "lite/api/paddle_api.h" - -namespace paddle { -namespace lite_api { - -void RunModel() { - // 1. Set MobileConfig - MobileConfig mobile_config; - - // 2. 
Create PaddlePredictor by MobileConfig - std::shared_ptr mobile_predictor = - CreatePaddlePredictor(mobile_config); -} - -} // namespace lite_api -} // namespace paddle diff --git a/lite/api/opt.cc b/lite/api/opt.cc index b8497199684cb4f6d4cc602291be5762eb93f7f9..12003050af864da7d88d335553d71007cf5ed9c5 100644 --- a/lite/api/opt.cc +++ b/lite/api/opt.cc @@ -23,6 +23,7 @@ #include "kernel_src_map.h" // NOLINT #include "lite/api/cxx_api.h" #include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_passes.h" #include "lite/core/op_registry.h" diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index cd68ca5187146fcf4d88b2008fc44533b3e1cf10..e48686b913cc5b07f87db0a503ce7081bbe7d058 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -54,7 +54,8 @@ enum class TargetType : int { kXPU = 9, kBM = 10, kAny = 6, // any target - NUM = 11, // number of fields. + kMLU = 11, + NUM = 12, // number of fields. 
}; enum class PrecisionType : int { kUnk = 0, @@ -98,7 +99,8 @@ enum class ActivationType : int { kTanh = 6, kSwish = 7, kExp = 8, - NUM = 9, + kAbs = 9, + NUM = 10, }; static size_t PrecisionTypeLength(PrecisionType type) { diff --git a/lite/api/test_classify_lite_bm.cc b/lite/api/test_classify_lite_bm.cc index 7da7dc03745aa623e35dec5b344e16de03cf5aca..b2507e28adbe050e4715e0c28a433a259607e7a9 100644 --- a/lite/api/test_classify_lite_bm.cc +++ b/lite/api/test_classify_lite_bm.cc @@ -36,7 +36,8 @@ void TestModel(const std::vector& valid_places) { predictor.Build(FLAGS_model_dir, "", "", valid_places, passes); auto* input_tensor = predictor.GetInput(0); - input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + input_tensor->Resize(DDim( + std::vector({1, 3, FLAGS_im_height, FLAGS_im_width}))); auto* data = input_tensor->mutable_data(); auto item_size = input_tensor->dims().production(); if (FLAGS_input_img_txt_path.empty()) { @@ -67,15 +68,13 @@ void TestModel(const std::vector& valid_places) { << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 << " ms in average."; - auto* out = predictor.GetOutput(0); - ASSERT_EQ(out->dims().size(), 2); - ASSERT_EQ(out->dims()[0], 1); - ASSERT_EQ(out->dims()[1], 1000); - - auto* out_data = out->data(); + auto out = predictor.GetOutputs(); FILE* fp = fopen("result.txt", "wb"); - for (int i = 0; i < out->numel(); i++) { - fprintf(fp, "%f\n", out_data[i]); + for (int i = 0; i < out.size(); i++) { + auto* out_data = out[i]->data(); + for (int j = 0; j < out[i]->numel(); j++) { + fprintf(fp, "%f\n", out_data[j]); + } } fclose(fp); } diff --git a/lite/api/transform_test.cc b/lite/api/transform_test.cc index 896b47a97fb20e6935764e12fbe9ebd646a4f816..e1c315f4a63ffd3ed8f51fa4b73ac88b50835cab 100644 --- a/lite/api/transform_test.cc +++ b/lite/api/transform_test.cc @@ -13,7 +13,9 @@ // limitations under the License. 
#include +#ifdef PADDLE_WITH_TESTING #include +#endif #include #include #include "lite/api/cxx_api.h" diff --git a/lite/backends/cuda/math/utils.h b/lite/backends/cuda/math/utils.h index b6aa9c7d160ad6c8b60b132e4a2bbd7ae1e0b9ff..78aa689ff767e8a454dec3aa48a97ecefafdbe7a 100644 --- a/lite/backends/cuda/math/utils.h +++ b/lite/backends/cuda/math/utils.h @@ -29,6 +29,7 @@ enum class BinaryOperation { kADD = 0, kMUL = 1, kDIV = 2, + kSUB = 3, }; template @@ -41,6 +42,7 @@ __device__ __forceinline__ float binary_calc(float x, if (type == BinaryOperation::kADD) return x + y; if (type == BinaryOperation::kMUL) return x * y; if (type == BinaryOperation::kDIV) return x / y; + if (type == BinaryOperation::kSUB) return x - y; } template diff --git a/lite/backends/host/target_wrapper.cc b/lite/backends/host/target_wrapper.cc index e00bf125e1abb745e0f219455c3e534467c3c919..1854675bad19ca084d171c18dae7535ce8aa641b 100644 --- a/lite/backends/host/target_wrapper.cc +++ b/lite/backends/host/target_wrapper.cc @@ -34,7 +34,7 @@ void* TargetWrapper::Malloc(size_t size) { return r; } void TargetWrapper::Free(void* ptr) { - ptr=Malloc(1); + ptr = Malloc(1); if (ptr) { free(static_cast(ptr)[-1]); } diff --git a/lite/backends/opencl/cl_context.cc b/lite/backends/opencl/cl_context.cc index f0105e060f03df3e4d49c358cf314730cdd16393..153c0620035377afac065e8049a9ebbc0a6f0c15 100644 --- a/lite/backends/opencl/cl_context.cc +++ b/lite/backends/opencl/cl_context.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "lite/backends/opencl/cl_context.h" +#include #include #include #include @@ -35,8 +36,10 @@ cl::Program &CLContext::GetProgram(const std::string &file_name, STL::stringstream program_key_ss; program_key_ss << file_name << options; std::string program_key = program_key_ss.str(); - auto it = programs_.find(program_key); - if (it != programs_.end()) { + + auto &programs = CLRuntime::Global()->programs(); + auto it = programs.find(program_key); + if (it != programs.end()) { VLOG(3) << " --- program -> " << program_key << " has been built --- "; return *(it->second); } @@ -47,14 +50,15 @@ cl::Program &CLContext::GetProgram(const std::string &file_name, CLRuntime::Global()->BuildProgram(program.get(), options); VLOG(3) << " --- end build program -> " << program_key << " --- "; - programs_[program_key] = std::move(program); + programs[program_key] = std::move(program); - return *(programs_[program_key]); + return *(programs[program_key]); } void CLContext::AddKernel(const std::string &kernel_name, const std::string &file_name, - const std::string &options) { + const std::string &options, + const std::string &time_stamp) { cl_int status{CL_SUCCESS}; VLOG(3) << " --- to get program " << file_name << " --- "; auto program = GetProgram(file_name, options); @@ -64,25 +68,30 @@ void CLContext::AddKernel(const std::string &kernel_name, new cl::Kernel(program, kernel_name.c_str(), &status)); CL_CHECK_FATAL(status); VLOG(3) << " --- end create kernel --- "; - kernels_.emplace_back(std::move(kernel)); + + auto &kernels = CLRuntime::Global()->kernels(); + auto &kernel_offset_map = CLRuntime::Global()->kernel_offset(); + kernels.emplace_back(std::move(kernel)); STL::stringstream kernel_key; - kernel_key << kernel_name << options; - kernel_offset_[kernel_key.str()] = kernels_.size() - 1; + kernel_key << kernel_name << options << time_stamp; + kernel_offset_map[kernel_key.str()] = kernels.size() - 1; } cl::Kernel &CLContext::GetKernel(const int index) { - VLOG(3) << " --- 
kernel count: " << kernels_.size() << " --- "; - CHECK(static_cast(index) < kernels_.size()) + auto &kernels = CLRuntime::Global()->kernels(); + VLOG(3) << " --- kernel count: " << kernels.size() << " --- "; + CHECK(static_cast(index) < kernels.size()) << "The index must be less than the size of kernels."; - CHECK(kernels_[index] != nullptr) + CHECK(kernels[index] != nullptr) << "The target kernel pointer cannot be null."; - return *(kernels_[index]); + return *(kernels[index]); } cl::Kernel &CLContext::GetKernel(const std::string &name) { - auto it = kernel_offset_.find(name); - CHECK(it != kernel_offset_.end()) << "Cannot find the kernel function: " - << name; + auto &kernel_offset_map = CLRuntime::Global()->kernel_offset(); + auto it = kernel_offset_map.find(name); + CHECK(it != kernel_offset_map.end()) << "Cannot find the kernel function: " + << name; return GetKernel(it->second); } @@ -121,14 +130,53 @@ cl::NDRange CLContext::DefaultWorkSize(const CLImage &image) { } } +cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size, + size_t max_work_size, + int divisor) { + int preferred_lws = 0; +#if 1 + auto gws0 = global_work_size[0]; + auto gws1 = global_work_size[1]; + auto gws2 = global_work_size[2]; +#else + auto gws2 = global_work_size[0]; + auto gws1 = global_work_size[1]; + auto gws0 = global_work_size[2]; +#endif + if (divisor > 1) { + max_work_size /= divisor; + } + if (preferred_lws > 0 && preferred_lws <= max_work_size) { + max_work_size = preferred_lws; + } + while (gws1 > max_work_size && max_work_size > 0) { + gws1 = gws1 % 2 == 0 ? gws1 / 2 : 1; + } + while (gws2 * gws1 > max_work_size && max_work_size > 0) { + gws2 = gws2 % 2 == 0 ? gws2 / 2 : 1; + } + while (gws0 * gws1 * gws2 > max_work_size && max_work_size > 0) { + gws0 = gws0 % 2 == 0 ? 
gws0 / 2 : 1; + } +#if 1 + return cl::NDRange{static_cast(gws0), + static_cast(gws1), + static_cast(gws2)}; +#else + return cl::NDRange{static_cast(gws2), + static_cast(gws1), + static_cast(gws0)}; +#endif +} + cl::NDRange CLContext::LocalWorkSize(cl::NDRange global_work_size, size_t max_work_size) { int preferred_lws = 0; int divisor = 2; - auto tmp0 = global_work_size[0]; - auto tmp1 = global_work_size[1]; - auto tmp2 = global_work_size[2]; + auto gws0 = global_work_size[0]; + auto gws1 = global_work_size[1]; + auto gws2 = global_work_size[2]; if (divisor > 1) { max_work_size /= divisor; @@ -136,18 +184,18 @@ cl::NDRange CLContext::LocalWorkSize(cl::NDRange global_work_size, if (preferred_lws > 0 && preferred_lws <= max_work_size) { max_work_size = preferred_lws; } - while (tmp1 > max_work_size && max_work_size > 0) { - tmp1 = tmp1 % 2 == 0 ? tmp1 / 2 : 1; + while (gws1 > max_work_size && max_work_size > 0) { + gws1 = gws1 % 2 == 0 ? gws1 / 2 : 1; } - while (tmp2 * tmp1 > max_work_size && max_work_size > 0) { - tmp2 = tmp2 % 2 == 0 ? tmp2 / 2 : 1; + while (gws2 * gws1 > max_work_size && max_work_size > 0) { + gws2 = gws2 % 2 == 0 ? gws2 / 2 : 1; } - while (tmp0 * tmp1 * tmp2 > max_work_size && max_work_size > 0) { - tmp0 = tmp0 % 2 == 0 ? tmp0 / 2 : 1; + while (gws0 * gws1 * gws2 > max_work_size && max_work_size > 0) { + gws0 = gws0 % 2 == 0 ? 
gws0 / 2 : 1; } - return cl::NDRange{static_cast(tmp0), - static_cast(tmp1), - static_cast(tmp2)}; + return cl::NDRange{static_cast(gws0), + static_cast(gws1), + static_cast(gws2)}; } } // namespace lite diff --git a/lite/backends/opencl/cl_context.h b/lite/backends/opencl/cl_context.h index 1964c4bf56b55841ba735c79b2f7a17dc1ed451e..b12473ccf5b4238f4ee95b7848a0842ee5b2ffe0 100644 --- a/lite/backends/opencl/cl_context.h +++ b/lite/backends/opencl/cl_context.h @@ -36,7 +36,8 @@ class CLContext { void AddKernel(const std::string &kernel_name, const std::string &file_name, - const std::string &options = ""); + const std::string &options = "", + const std::string &time_stamp = ""); cl::Kernel &GetKernel(const int index); @@ -46,10 +47,11 @@ class CLContext { cl::NDRange LocalWorkSize(cl::NDRange global_work_size, size_t max_work_size); - private: - std::unordered_map> programs_; - std::vector> kernels_; - std::map kernel_offset_; + cl::NDRange LocalWorkSizeTurn(cl::NDRange global_work_size, + size_t max_work_size, + int divitor = 2); + // cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size, + // size_t max_work_size); }; } // namespace lite diff --git a/lite/backends/opencl/cl_runtime.cc b/lite/backends/opencl/cl_runtime.cc index 63c9954f9181e9252c4d14f57b6ed29107965fe3..8405fc967239e851705feb96f517b3980192ebef 100644 --- a/lite/backends/opencl/cl_runtime.cc +++ b/lite/backends/opencl/cl_runtime.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "lite/backends/opencl/cl_runtime.h" #include +#include #include #include #include "lite/utils/cp_logging.h" @@ -29,10 +30,26 @@ CLRuntime* CLRuntime::Global() { CLRuntime::~CLRuntime() { if (command_queue_ != nullptr) { + command_queue_->flush(); command_queue_->finish(); } - // For controlling the destruction order: + + for (size_t kidx = 0; kidx < kernels_.size(); ++kidx) { + clReleaseKernel(kernels_[kidx]->get()); + kernels_[kidx].reset(); + } + kernels_.clear(); + kernel_offset_.clear(); + + for (auto& p : programs_) { + clReleaseProgram(p.second->get()); + } + programs_.clear(); + + // For controlling the destruction order + command_queue_&& clReleaseCommandQueue(command_queue_->get()); command_queue_.reset(); + context_&& clReleaseContext(context_->get()); context_.reset(); device_.reset(); platform_.reset(); @@ -73,14 +90,14 @@ cl::CommandQueue& CLRuntime::command_queue() { return *command_queue_; } -std::unique_ptr CLRuntime::CreateProgram( +std::shared_ptr CLRuntime::CreateProgram( const cl::Context& context, std::string file_name) { auto cl_file = opencl_kernels_files.find(file_name); std::string content(cl_file->second.begin(), cl_file->second.end()); cl::Program::Sources sources; sources.push_back(content); auto prog = - std::unique_ptr(new cl::Program(context, sources, &status_)); + std::shared_ptr(new cl::Program(context, sources, &status_)); VLOG(4) << "OpenCL kernel file name: " << file_name; VLOG(4) << "Program source size: " << content.size(); CL_CHECK_FATAL(status_); diff --git a/lite/backends/opencl/cl_runtime.h b/lite/backends/opencl/cl_runtime.h index 1a5ededeff37d9f6820af6a49dc22c669620734b..36e5d64b906ff5c91b2b5cb5e97855d7dff511c4 100644 --- a/lite/backends/opencl/cl_runtime.h +++ b/lite/backends/opencl/cl_runtime.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include #include #include +#include #include #include "lite/backends/opencl/cl_include.h" #include "lite/backends/opencl/cl_utility.h" @@ -42,7 +43,7 @@ class CLRuntime { cl::CommandQueue& command_queue(); - std::unique_ptr CreateProgram(const cl::Context& context, + std::shared_ptr CreateProgram(const cl::Context& context, std::string file_name); std::unique_ptr CreateEvent(const cl::Context& context); @@ -57,6 +58,12 @@ class CLRuntime { std::map& GetDeviceInfo(); + std::unordered_map>& programs() { + return programs_; + } + std::vector>& kernels() { return kernels_; } + std::map& kernel_offset() { return kernel_offset_; } + private: CLRuntime() = default; @@ -98,6 +105,12 @@ class CLRuntime { std::shared_ptr command_queue_{nullptr}; + std::unordered_map> programs_{}; + + std::vector> kernels_{}; + + std::map kernel_offset_{}; + cl_int status_{CL_SUCCESS}; bool initialized_{false}; diff --git a/lite/backends/opencl/cl_utility.h b/lite/backends/opencl/cl_utility.h index b7f14c15e61ba050220ef0819fa9c3d13a7b8606..de01f896a6eb461eb24023a77935bba07de029e7 100644 --- a/lite/backends/opencl/cl_utility.h +++ b/lite/backends/opencl/cl_utility.h @@ -32,7 +32,7 @@ const char* opencl_error_to_str(cl_int error); __FILE__, \ __LINE__); \ } - +#ifndef LITE_SHUTDOWN_LOG #define CL_CHECK_FATAL(err_code__) \ if (err_code__ != CL_SUCCESS) { \ LOG(FATAL) << string_format( \ @@ -42,5 +42,8 @@ const char* opencl_error_to_str(cl_int error); __FILE__, \ __LINE__); \ } +#else +#define CL_CHECK_FATAL(err_code__) +#endif } // namespace lite } // namespace paddle diff --git a/lite/core/context.h b/lite/core/context.h index 978fb5d67a2fae8025fa725ed1f717aa3df611c0..88fe00d0f2aab41cfd3e5562d29f0a8a82598428 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -52,6 +52,7 @@ using XPUContext = Context; using OpenCLContext = Context; using FPGAContext = Context; using BMContext = Context; +using MLUContext = Context; template <> class Context { diff --git 
a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index f5a7837b53650e08f9632b499a4c2ab1faeaeedf..4393832931c95ca20e34ca3b3d2fb4501274b15f 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -26,7 +26,8 @@ namespace mir { void ConvBNFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; - std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; + std::vector conv_type_cases{ + "conv2d", "depthwise_conv2d", "conv2d_transpose"}; // start fuse using params for (auto conv_has_bias : conv_has_bias_cases) { for (auto conv_type : conv_type_cases) { diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 0f5bb64e10dd61c3edf4ddd32569a2d365651cdf..150a6e68d8a924ebfa96fdffb99e28b230689a48 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -103,10 +103,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { std::string conv_weight_name = matched.at("conv_weight")->arg()->name; auto conv_weight_t = scope->FindVar(conv_weight_name)->GetMutable(); - CHECK_EQ(static_cast(bn_scale_t->data_size()), - static_cast(conv_weight_t->dims()[0])) - << "The BN bias's size should be equal to the size of the first " - << "dim size of the conv weights"; + if (conv_type_ == "conv2d_transpose") { + CHECK_EQ(static_cast(bn_scale_t->data_size()), + static_cast(conv_weight_t->dims()[1])) + << "The BN bias's size should be equal to the size of the first " + << "dim size of the conv weights"; + } else { + CHECK_EQ(static_cast(bn_scale_t->data_size()), + static_cast(conv_weight_t->dims()[0])) + << "The BN bias's size should be equal to the size of the first " + << "dim size of the conv weights"; + } size_t weight_num = conv_weight_t->data_size(); bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? 
true : false; bool is_weight_quantization = @@ -153,12 +160,29 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { // compute new conv_weight for int8 auto weight_scale = conv_op_desc->GetAttr>("weight_scale"); - for (unsigned int i = 0; i < h; ++i) { - weight_scale[i] *= fabsf(alpha_data[i]); - if (alpha_data[i] < 0.f) { - auto ptr_row = conv_weight_d + i * w; - for (unsigned int j = 0; j < w; ++j) { - ptr_row[j] *= -1; + if (conv_type_ == "conv2d_transpose") { + int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * + conv_weight_t->dims()[3]; + int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; + for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) { + for (unsigned int i = 0; i < h; ++i) { + weight_scale[i] *= fabsf(alpha_data[i]); + if (alpha_data[i] < 0.f) { + auto ptr_row = conv_weight_d + k * c_size + i * hw; + for (unsigned int j = 0; j < hw; ++j) { + ptr_row[j] *= -1; + } + } + } + } + } else { + for (unsigned int i = 0; i < h; ++i) { + weight_scale[i] *= fabsf(alpha_data[i]); + if (alpha_data[i] < 0.f) { + auto ptr_row = conv_weight_d + i * w; + for (unsigned int j = 0; j < w; ++j) { + ptr_row[j] *= -1; + } } } } @@ -176,9 +200,23 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } else { // compute new conv_weight auto conv_weight_d = conv_weight_t->mutable_data(); - for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels - for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels - conv_weight_d[i * w + j] *= alpha_data[i]; + if (conv_type_ == "conv2d_transpose") { + int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * + conv_weight_t->dims()[3]; + int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; + for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) { + for (unsigned int i = 0; i < h; ++i) { + auto ptr_row = conv_weight_d + k * c_size + i * hw; + for (unsigned int j = 0; j < hw; ++j) { + ptr_row[j] *= 
alpha_data[i]; + } + } + } + } else { + for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels + for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels + conv_weight_d[i * w + j] *= alpha_data[i]; + } } } } diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index ab81f3d809507dd340056c97a39998c908a75dc7..80a033c75f2e23efa091375ee2a9f78e3ff40d71 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { fuser(graph.get()); } - // delete quant_dequant_node - for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) { - fusion::DeleteQuantDequantOpFuser fuser(op_type); - fuser(graph.get()); - } + // process quant_dequant_node + fusion::DeleteQuantDequantOpFuser dqd_fuser; + dqd_fuser(graph.get()); } } // namespace mir diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index 7797864a2e4b75f52fd7da93ea81613a2175f423..a3a98b871fb4b6f8230299cda978b0f1f8faa779 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, auto* output_scale_node = matched.at("output_scale_node"); auto* output_act_node = matched.at("output_act_node"); - // obtain values, save values and relink node + // obtain scale, save attrs and relink node int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); int range = ((1 << (bit_length - 1)) - 1); auto* scope = quant_node->stmt()->op()->scope(); @@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph, ->GetMutable(); float scale_value = scale_tensor->data()[0] / range; + auto in_act_name = input_act_node->arg()->name; + auto out_act_name = output_act_node->arg()->name; auto outlinks = output_act_node->outlinks; 
for (auto* quantized_node : outlinks) { - auto* op_desc = quantized_node->stmt()->mutable_op_info(); - op_desc->SetAttr("bit_length", bit_length); - op_desc->SetAttr("input_scale", scale_value); + // save input scale in quantized op by input argname + index + auto op_desc = *quantized_node->stmt()->mutable_op_info(); + std::string argname; + int index; + op_desc.GetInputArgname(out_act_name, &argname); + op_desc.GetInputIndex(out_act_name, &index); + op_desc.SetAttr(argname + std::to_string(index) + "_input_scale", + scale_value); + op_desc.SetAttr("input_scale", scale_value); // save it for now + op_desc.SetAttr("bit_length", bit_length); + op_desc.UpdateAllInputs(out_act_name, in_act_name); + quantized_node->stmt()->ResetOp(op_desc, graph->valid_places()); IR_NODE_LINK_TO(input_act_node, quantized_node) } @@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, auto* dequant_op = matched.at("dequant_op"); auto* dequant_op_out = matched.at("dequant_op_out"); - // obtain input_scale and weight_scale + // obtain weight_scale from max_range auto* scope = quantized_op->stmt()->op()->scope(); auto& valid_places = quantized_op->stmt()->op()->valid_places(); int bit_length = quantized_op->stmt()->op_info()->GetAttr("bit_length"); int range = ((1 << (bit_length - 1)) - 1); - float input_scale = - quantized_op->stmt()->op_info()->GetAttr("input_scale"); float max_range = dequant_op->stmt()->op_info()->GetAttr("max_range"); float whole_weight_scale = static_cast(range * range) / max_range / range; - // max_range = range * range / max(abs(weight)) - // weight_scale = range * range / (range * range / max(abs(weight))) / range - // = max(abs(weight)) / range + // As: max_range = range * range / max(abs(weight)) + // So: whole_weight_scale + // = range * range / (range * range / max(abs(weight))) / range + // = max(abs(weight)) / range // set op desc cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); @@ -153,7 +163,7 @@ void 
DequantOpFuser::InsertNewNode(SSAGraph* graph, // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should // be Cout. weight_scale_size = quantized_weight_t->dims()[0]; - } else if (quantized_op_type_ == "mul") { + } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") { op_desc.SetInput("X", {quantized_op_input->arg()->name}); op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); // Fc weight: Cin * Cout, the weight_scale_size should be Cout. @@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph, weight_scale.push_back(whole_weight_scale); } op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("input_scale", input_scale); op_desc.SetAttr("weight_scale", weight_scale); // change the weight from the float type to int8 type. @@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() { ->assert_is_op_output(quantized_op_type_) ->assert_is_op_input(dequant_op_type, "X") ->AsIntermediate(); + // The scale var_node of input activation is deleted in DeleteQuantOpFuser auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale") ->assert_is_op_input(dequant_op_type) ->AsIntermediate(); @@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, auto* dequant_op = matched.at("dequant_op"); auto* dequant_op_out = matched.at("dequant_op_out"); - // obtain input_scale and weight_scale + // obtain input weight_scale from fake_dequant op auto* scope = quantized_op->stmt()->op()->scope(); auto& valid_places = quantized_op->stmt()->op()->valid_places(); - float input_scale = - quantized_op->stmt()->op_info()->GetAttr("input_scale"); std::vector weight_scale; std::vector quant_bits = @@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph, // set op desc cpp::OpDesc op_desc = *quantized_op->stmt()->op_info(); - op_desc.SetInput("Input", {quantized_op_input->arg()->name}); - op_desc.SetOutput("Output", {dequant_op_out->arg()->name}); - + if (quantized_op_type_ 
== "conv2d" || + quantized_op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {quantized_op_input->arg()->name}); + op_desc.SetOutput("Output", {dequant_op_out->arg()->name}); + } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") { + op_desc.SetInput("X", {quantized_op_input->arg()->name}); + op_desc.SetOutput("Out", {dequant_op_out->arg()->name}); + } op_desc.SetAttr("enable_int8", true); - op_desc.SetAttr("input_scale", input_scale); op_desc.SetAttr("weight_scale", weight_scale); // change the weight from the float type to int8 type. @@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { void DeleteQuantDequantOpFuser::BuildPattern() { std::string quant_dequant_op_type = "fake_quantize_dequantize_moving_average_abs_max"; - if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { - auto* input_scale_node = - VarNode("input_scale_node") - ->assert_is_op_input(quant_dequant_op_type, "InScale"); - auto* input_act_node = VarNode("input_act_node") - ->assert_is_op_input(quant_dequant_op_type, "X"); - auto* quant_dequant_node = - OpNode("quant_dequant_node", quant_dequant_op_type) - ->assert_is_op(quant_dequant_op_type); - auto* output_scale_node = - VarNode("output_scale_node") - ->assert_is_op_output(quant_dequant_op_type, "OutScale"); - auto* output_act_node = - VarNode("output_act_node") - ->assert_is_op_output(quant_dequant_op_type, "Out"); - auto* quantized_node = OpNode("quantized_node", quantized_op_type_) - ->assert_is_op(quantized_op_type_); - - quant_dequant_node->LinksFrom({input_scale_node, input_act_node}); - output_scale_node->LinksFrom({quant_dequant_node}); - output_act_node->LinksFrom({quant_dequant_node}); - quantized_node->LinksFrom({output_act_node}); - } else if (quantized_op_type_ == "elementwise_add") { - auto* input_scale_left_node = - VarNode("input_scale_left_node") - ->assert_is_op_input(quant_dequant_op_type, "InScale"); - auto* 
input_act_left_node = - VarNode("input_act_left_node") - ->assert_is_op_input(quant_dequant_op_type, "X"); - auto* quant_dequant_left_node = - OpNode("quant_dequant_left_node", quant_dequant_op_type) - ->assert_is_op(quant_dequant_op_type); - auto* output_scale_left_node = - VarNode("output_scale_left_node") - ->assert_is_op_output(quant_dequant_op_type, "OutScale"); - auto* output_act_left_node = - VarNode("output_act_left_node") - ->assert_is_op_output(quant_dequant_op_type, "Out") - ->assert_is_op_input(quantized_op_type_, "X"); - quant_dequant_left_node->LinksFrom( - {input_scale_left_node, input_act_left_node}); - output_scale_left_node->LinksFrom({quant_dequant_left_node}); - output_act_left_node->LinksFrom({quant_dequant_left_node}); - - auto* input_scale_right_node = - VarNode("input_scale_right_node") - ->assert_is_op_input(quant_dequant_op_type, "InScale"); - auto* input_act_right_node = - VarNode("input_act_right_node") - ->assert_is_op_input(quant_dequant_op_type, "X"); - auto* quant_dequant_right_node = - OpNode("quant_dequant_right_node", quant_dequant_op_type) - ->assert_is_op(quant_dequant_op_type); - auto* output_scale_right_node = - VarNode("output_scale_right_node") - ->assert_is_op_output(quant_dequant_op_type, "OutScale"); - auto* output_act_right_node = - VarNode("output_act_right_node") - ->assert_is_op_output(quant_dequant_op_type, "Out") - ->assert_is_op_input(quantized_op_type_, "Y"); - quant_dequant_right_node->LinksFrom( - {input_scale_right_node, input_act_right_node}); - output_scale_right_node->LinksFrom({quant_dequant_right_node}); - output_act_right_node->LinksFrom({quant_dequant_right_node}); - - auto* quantized_node = OpNode("quantized_node", quantized_op_type_) - ->assert_is_op(quantized_op_type_); - quantized_node->LinksFrom({output_act_left_node, output_act_right_node}); - } else { - LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_; - } - VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:" - << 
quantized_op_type_; + auto* input_scale_node = + VarNode("input_scale_node") + ->assert_is_op_input(quant_dequant_op_type, "InScale"); + auto* input_act_node = + VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X"); + auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type) + ->assert_is_op(quant_dequant_op_type); + auto* output_scale_node = + VarNode("output_scale_node") + ->assert_is_op_output(quant_dequant_op_type, "OutScale"); + auto* output_act_node = + VarNode("output_act_node") + ->assert_is_op_output(quant_dequant_op_type, "Out"); + + quant_dequant_node->LinksFrom({input_scale_node, input_act_node}); + output_scale_node->LinksFrom({quant_dequant_node}); + output_act_node->LinksFrom({quant_dequant_node}); } void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { - if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { - auto* input_scale_node = matched.at("input_scale_node"); - auto* input_act_node = matched.at("input_act_node"); - auto* quant_dequant_node = matched.at("quant_dequant_node"); - auto* output_scale_node = matched.at("output_scale_node"); - auto* output_act_node = matched.at("output_act_node"); - auto* quantized_node = matched.at("quantized_node"); - - // obtain values, save values and relink node - int bit_length = - quant_dequant_node->stmt()->op_info()->GetAttr("bit_length"); - int range = ((1 << (bit_length - 1)) - 1); - auto* scope = quant_dequant_node->stmt()->op()->scope(); - auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) - ->GetMutable(); - float scale_value = scale_tensor->data()[0] / range; - - auto* op_desc = quantized_node->stmt()->mutable_op_info(); - op_desc->SetAttr("bit_length", bit_length); - op_desc->SetAttr("input_scale", scale_value); - op_desc->SetInput("X", {input_act_node->arg()->name}); - IR_NODE_LINK_TO(input_act_node, quantized_node) - auto update_op_desc = *quantized_node->stmt()->mutable_op_info(); - 
quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places()); - - // delete nodes and edges - std::unordered_set nodes2rm = {input_scale_node, - quant_dequant_node, - output_scale_node, - output_act_node}; - GraphSafeRemoveNodes(graph, nodes2rm); - } else if (quantized_op_type_ == "elementwise_add") { - auto* input_scale_left_node = matched.at("input_scale_left_node"); - auto* input_act_left_node = matched.at("input_act_left_node"); - auto* quant_dequant_left_node = matched.at("quant_dequant_left_node"); - auto* output_scale_left_node = matched.at("output_scale_left_node"); - auto* output_act_left_node = matched.at("output_act_left_node"); - - auto* input_scale_right_node = matched.at("input_scale_right_node"); - auto* input_act_right_node = matched.at("input_act_right_node"); - auto* quant_dequant_right_node = matched.at("quant_dequant_right_node"); - auto* output_scale_right_node = matched.at("output_scale_right_node"); - auto* output_act_right_node = matched.at("output_act_right_node"); - - auto* quantized_node = matched.at("quantized_node"); - - // obtain values, save values and relink node - int bit_length = - quant_dequant_left_node->stmt()->op_info()->GetAttr("bit_length"); - int range = ((1 << (bit_length - 1)) - 1); - auto* scope = quant_dequant_left_node->stmt()->op()->scope(); - auto* left_scale_tensor = - scope->FindVar(output_scale_left_node->arg()->name) - ->GetMutable(); - float left_scale_value = left_scale_tensor->data()[0] / range; - auto* right_scale_tensor = - scope->FindVar(output_scale_right_node->arg()->name) - ->GetMutable(); - float right_scale_value = right_scale_tensor->data()[0] / range; - - auto* op_desc = quantized_node->stmt()->mutable_op_info(); - op_desc->SetAttr("bit_length", bit_length); - op_desc->SetAttr("x_input_scale", left_scale_value); - op_desc->SetAttr("y_input_scale", right_scale_value); - op_desc->SetInput("X", {input_act_left_node->arg()->name}); - op_desc->SetInput("Y", {input_act_right_node->arg()->name}); - 
IR_NODE_LINK_TO(input_act_left_node, quantized_node) - IR_NODE_LINK_TO(input_act_right_node, quantized_node) - auto update_op_desc = *quantized_node->stmt()->mutable_op_info(); - quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places()); - - // delete nodes and edges - std::unordered_set nodes2rm = {input_scale_left_node, - quant_dequant_left_node, - output_scale_left_node, - output_act_left_node, - input_scale_right_node, - quant_dequant_right_node, - output_scale_right_node, - output_act_right_node}; - GraphSafeRemoveNodes(graph, nodes2rm); - } else { - LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_; + auto* input_scale_node = matched.at("input_scale_node"); + auto* input_act_node = matched.at("input_act_node"); + auto* quant_dequant_node = matched.at("quant_dequant_node"); + auto* output_scale_node = matched.at("output_scale_node"); + auto* output_act_node = matched.at("output_act_node"); + auto input_act_name = input_act_node->arg()->name; + auto output_act_name = output_act_node->arg()->name; + + // Get scale value from scale var node + int bit_length = + quant_dequant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_dequant_node->stmt()->op()->scope(); + auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; + + auto quantized_nodes = output_act_node->outlinks; + for (auto* quantized_node : quantized_nodes) { + // Save quantization info in op_info attr + auto op_info = *quantized_node->stmt()->op_info(); + std::string argname; + int index; + op_info.GetInputArgname(output_act_name, &argname); + op_info.GetInputIndex(output_act_name, &index); + op_info.SetAttr(argname + std::to_string(index) + "_input_scale", + scale_value); + op_info.SetAttr("input_scale", scale_value); // Save it for now + op_info.SetAttr("bit_length", bit_length); + + op_info.UpdateAllInputs(output_act_name, 
input_act_name); + quantized_node->stmt()->ResetOp(op_info, graph->valid_places()); + IR_NODE_LINK_TO(input_act_node, quantized_node); } + // delete nodes and edges + std::unordered_set nodes2rm = { + input_scale_node, quant_dequant_node, output_scale_node, output_act_node}; + GraphSafeRemoveNodes(graph, nodes2rm); } cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h index bef9f4d9573d049700736c166cd0d31b668f7eff..ac3ac112b3aa504bc075125f2f13292073ca9444 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.h +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h @@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase { }; /* The pattern like "fake_quantize_dequantize_moving_average_abs_max + - * pooled/elementwise_add" can be deteted by this fuser. The fuser - * extract the input_scale form fake_quant_dequant_op and save into - * the quantized_op. Besides, the fuser delete fake_quant_dequant_op in - * the graph. + * quantized_op" can be detected by this fuser. The fuser modifies the input + * scale for the quantized_op and deletes the fake_quant_dequant_op. 
*/ - class DeleteQuantDequantOpFuser : public FuseBase { public: - explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type) - : quantized_op_type_(quantized_op_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; - - private: - std::string quantized_op_type_{}; }; } // namespace fusion diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index c76e369466a9b998b2ad6fde67b97117649fddc0..a9ccd1b9ae9a5d45f8d0e5638b3aab1d73d1903c 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -22,6 +22,61 @@ namespace paddle { namespace lite { +bool OpLite::InferShape() { + // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ + // InferShapeWithCache will be applied. + if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) { + return this->InferShapeWithCache(); + } else { + // otherwise, InferShapeImpl is applied directly. + return this->InferShapeImpl(); + } +} +bool OpLite::InferShapeWithCache() { + // 1. Get vector of current input tensors + auto *current_inputs = param_.input_tensor_ptrs(); + // 2. Get hash value of current inputs shape and lod + size_t new_hash = 0; + for (auto iter = current_inputs->begin(); iter != current_inputs->end(); + iter++) { + // combined dims value into new_hash value. + auto &element_dims = (*iter)->dims(); + for (int i = 0; i < element_dims.size(); i++) { + new_hash = + lite::hash_combine(new_hash, static_cast(element_dims[i])); + } + // combine lod value into new_hash value. + auto &emement_lods = (*iter)->lod(); + for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end(); + lod_iter++) { + for (int i = 0; i < lod_iter->size(); i++) { + new_hash = + lite::hash_combine(new_hash, static_cast(lod_iter->at(i))); + } + } + } + // 3. 
infer shapes of output tensors + if (new_hash == io_shape_lod_hash_ && new_hash != 0) { + // if current hash value is consistent with io_shape_lod_hash_, + // previous outputs shape and lod are reused. + auto *current_outputs = param_.output_tensor_ptrs(); + for (int i = 0; i < current_outputs->size(); i++) { + current_outputs->at(i)->Resize(last_output_shapes[i]); + current_outputs->at(i)->set_lod(last_output_lods[i]); + } + } else { + // otherwise, current hash value is changed, InferShapeImpl will apply. + io_shape_lod_hash_ = new_hash; + this->InferShapeImpl(); + auto *current_outputs = param_.output_tensor_ptrs(); + for (int i = 0; i < current_outputs->size(); i++) { + last_output_shapes[i] = current_outputs->at(i)->dims(); + last_output_lods[i] = current_outputs->at(i)->lod(); + } + } + return true; +} + std::vector> OpLite::CreateKernels( const std::vector &places, const std::string &kernel_type) { std::vector> kernels; diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 77d8091b4b16cfbce2efc3d549f916a9136c61ab..1cdc33825cb4ffb758b46ac4b9bee968b3fca055 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -24,6 +25,7 @@ #include "lite/core/kernel.h" #include "lite/core/scope.h" #include "lite/model_parser/cpp/op_desc.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { @@ -64,8 +66,8 @@ class OpLite : public Registry { // Check the shape. virtual bool CheckShape() const { return true; } // Inference the outputs' shape. - virtual bool InferShape() const { return true; } - virtual bool SmartInferShape() { return this->InferShape(); } + virtual bool InferShapeImpl() const { return true; } + virtual bool InferShape(); // Run this operator. 
virtual bool Run(); // Indicate whether the Op runs only once or not @@ -151,10 +153,16 @@ class OpLite : public Registry { std::vector valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; std::unique_ptr op_info_; - std::vector last_input_shapes; - std::vector last_output_shapes; - std::vector>> last_output_lods; - std::vector>> last_input_lods; + + std::vector last_output_shapes{}; + std::vector>> last_output_lods{}; + size_t io_shape_lod_hash_{}; + mutable operators::ParamBase param_; + + private: + // Infer Shape according to memory, if current input shapes are consistent + // with that of previous inputs, output shapes of last time will be reused. + bool InferShapeWithCache(); }; /* @@ -217,6 +225,32 @@ class OpInfo : public cpp::OpDesc { return false; } + // For the input variable name, find the index of the corresponding + // input argname + bool GetInputIndex(const std::string &value_name, int *out) const { + for (auto &item : inputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = it - item.second.begin(); + return true; + } + } + return false; + } + + // For the output variable name, find the index of the corresponding + // output argname + bool GetOutputIndex(const std::string &value_name, int *out) const { + for (auto &item : outputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = it - item.second.begin(); + return true; + } + } + return false; + } + void UpdateAllInputs(const std::string &from, const std::string &to) { for (auto &item : inputs_) { for (auto &var : item.second) { diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 4b6d3282ed300654c612325ff9c53c153ccea30a..fe1dff3c99c1d2413888e78c89c999caea0ab030 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -107,6 +107,9 @@ std::list> KernelRegistry::Create( case TARGET(kBM): { CREATE_KERNEL(kBM); } 
break; + case TARGET(kMLU): { + CREATE_KERNEL(kMLU); + } break; default: CHECK(false) << "not supported kernel target " << TargetToStr(target); } @@ -139,6 +142,15 @@ KernelRegistry::KernelRegistry() INIT_FOR(kCUDA, kInt64, kNCHW); INIT_FOR(kCUDA, kInt64, kNHWC); + INIT_FOR(kMLU, kFloat, kNHWC); + INIT_FOR(kMLU, kFloat, kNCHW); + INIT_FOR(kMLU, kFP16, kNHWC); + INIT_FOR(kMLU, kFP16, kNCHW); + INIT_FOR(kMLU, kInt8, kNHWC); + INIT_FOR(kMLU, kInt8, kNCHW); + INIT_FOR(kMLU, kInt16, kNHWC); + INIT_FOR(kMLU, kInt16, kNCHW); + INIT_FOR(kHost, kFloat, kNCHW); INIT_FOR(kHost, kAny, kNCHW); INIT_FOR(kHost, kFloat, kNHWC); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 6f8f1e8bc6662a7b22fd8f4c3b9683eb6f4da139..3c41c1fd8af240401c3edf0343433f8d8d9c85db 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -268,7 +268,32 @@ class KernelRegistry final { DATALAYOUT(kAny)> *, // KernelRegistryForTarget * // + DATALAYOUT(kAny)> *, // + + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; KernelRegistry(); diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h index f0bfed326c8307b17965f64b1cc52f2fc134b74c..68698e79517481df745d72c6116f94f0c92cb7b7 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -18,6 +18,7 @@ * of each kernel. 
*/ #pragma once +#include #include #include #include "lite/core/program.h" diff --git a/lite/core/program.cc b/lite/core/program.cc index 580389fbad54c0de8efd65ef78c9b69fd3e72893..ff900c0e23be9a06313babba51e3ce364295231a 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -20,7 +20,7 @@ #include "lite/operators/conditional_block_op.h" #include "lite/operators/subgraph_op.h" #include "lite/operators/while_op.h" -#ifdef LITE_WITH_PROFILE +#ifdef LITE_WITH_PRECISION_PROFILE #include "lite/core/profile/precision_profiler.h" #endif @@ -136,12 +136,10 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { } void RuntimeProgram::Run() { -#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler(); std::string precision_profiler_summary = inst_precision_profiler.GetSummaryHeader(); -#endif #endif for (auto& inst : instructions_) { @@ -149,21 +147,19 @@ void RuntimeProgram::Run() { if (inst.is_feed_fetch_op()) continue; #endif inst.Run(); -#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE #ifndef LITE_WITH_FPGA precision_profiler_summary += inst_precision_profiler.GetInstPrecision(&inst); #endif #endif // LITE_WITH_PRECISION_PROFILE -#endif // LITE_WITH_PROFILE } #ifdef LITE_WITH_PROFILE LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 0); +#endif #ifdef LITE_WITH_PRECISION_PROFILE LOG(INFO) << "\n" << precision_profiler_summary; -#endif // LITE_WITH_PRECISION_PROFILE -#endif // LITE_WITH_PROFILE +#endif } void Program::Build(const cpp::ProgramDesc& prog) { @@ -286,8 +282,7 @@ void Instruction::Run() { return; } - // op_->InferShape(); - op_->SmartInferShape(); + op_->InferShape(); kernel_->Launch(); has_run_ = true; } diff --git a/lite/demo/cxx/train_demo/README.md b/lite/demo/cxx/train_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..56f4513d45676a1deb51bfb93096db156ddd0449 --- /dev/null +++ 
b/lite/demo/cxx/train_demo/README.md @@ -0,0 +1,191 @@ + +# Introduction + 我们都知道,PaddleLite可以做移动端预测,事实上PaddleLite支持在移动端做模型训练。本文给出使用PaddleLite做训练的例子,这一例子对应的任务是“波士顿房价预测”,又称作“fit-a-line”。 + + 你可以通过book库中的 +[文档](https://paddlepaddle.org.cn/documentation/docs/zh/user_guides/simple_case/fit_a_line/README.cn.html) +和 +[源码](https://github.com/PaddlePaddle/book/tree/develop/01.fit_a_line) +进一步了解“波士顿房价预测”这一任务的定义及其建模过程, +其使用线性回归(Linear Regression) +模型做建模。本文主要介绍如何将其迁移至Paddle-Lite进行训练。 + +注:这是一篇使用C++ API做模型训练的教程,其他API暂时不支持训练功能。 + +# Requirements + +- 一部安卓手机,用于运行训练程序 +- 装了Paddle (version: 1.7.0) 的python + +# Quick start + +## Step1 build paddle-lite + +请按照[paddle-lite官方文档](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#paddlelite) 的教程编译full_publish的paddle-lite lib。以Linux上编译为例,其具体的命令为: + +```shell +## 配置环境 +wget -c https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz --no-check-certificate +tar xzf cmake-3.10.3-Linux-x86_64.tar.gz +export PATH=${PWD}'/cmake-3.10.3-Linux-x86_64/bin':$PATH + +wget https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip +unzip android-ndk-r17c-linux-x86_64.zip +export NDK_ROOT=/opt/android-ndk-r17c + +## 编译 +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite +./lite/tools/build.sh \ + --arm_os=android \ + --arm_abi=armv7 \ + --build_extra=ON \ + --arm_lang=gcc \ + --android_stl=c++_static \ + --build_train=ON full_publish +``` + +产物: + +```shell +Paddle-Lite/build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/cxx/lib/libpaddle_full_api_shared.so +``` + +## Step2 编译lr_trainer + +```shell +cd Paddle-Lite/lite/demo/cxx/train_demo/cplus_train/ +sh run_build.sh /path/to/your/Paddle-Lite/build.lite.android.armv7.gcc/ /path/to/your/android-ndk-r17c +``` + +产物: +```shell +bin/ +`-- demo_trainer +``` + +## Step3 download model and run it! 
+ +在你的笔记本电脑上,用usb连接到手机,开启开发者模式,在任意目录下执行: + +```shell +local_path=/data/local/tmp/linear_regression +adb shell "mkdir "${local_path} + +# download model and push to mobile +wget http://paddle-tar.bj.bcebos.com/paddle-lite/lite_lr_model.tar.gz +tar -zxvf lite_lr_model.tar.gz +adb push lite_lr_model/housing.data ${local_path} +adb push lite_lr_model/model_dir ${local_path} + +# push lib and executable file to mobile +adb push libpaddle_full_api_shared.so ${local_path} +adb push demo_trainer ${local_path} +adb shell chmod +x ${local_path}/demo_trainer + +# run it! +adb shell "export LD_LIBRARY_PATH="${local_path}" && export LIBRARY_PATH="${local_path}" && cd "${local_path}" && ./demo_trainer true" +``` + +期望结果: + +``` +sample 0: Loss: 564.317 +sample 1: Loss: 463.9 +sample 2: Loss: 1197.54 +sample 3: Loss: 1093.83 +sample 4: Loss: 1282.76 +sample 5: Loss: 792.097 +sample 6: Loss: 491.776 +sample 7: Loss: 698.496 +sample 8: Loss: 248.445 +sample 9: Loss: 325.135 +``` + +# 更多细节 +上面提到的模型是直接下载得到的,如果你想自己生成,可以执行以下命令: + +```shell +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite/lite/demo/cxx/train_demo/ +python train.py --save_model +``` + +产物: + +```shell +model_dir/ +|-- fc_0.b_0 +|-- fc_0.w_0 +|-- learning_rate_0 +`-- __model__ + +md5sum fc_0.w_0: 2c7b3649b2a9cf7bcd19f8b256ce795d +``` + +如果你想生成自己的模型用于训练,可以参考`train.py`中保存模型的方式。 + +# 与Paddle训练结果做校对 + +## 前10个Loss值 + +为了验证paddle与lite的一致性,我们控制模型参数一致、数据一致、batch size = 1的情况下,训练10个batch, 记录了二者的loss值。 + +python + paddle 命令: + +```shell + fluid train.py --num_steps=10 --batch_size=1 +``` + +python + paddle 结果: + +```shell +Train cost, Step 0, Cost 564.317017 +Train cost, Step 1, Cost 463.900238 +Train cost, Step 2, Cost 1197.537354 +Train cost, Step 3, Cost 1093.833008 +Train cost, Step 4, Cost 1282.760254 +Train cost, Step 5, Cost 792.097351 +Train cost, Step 6, Cost 491.775848 +Train cost, Step 7, Cost 698.496033 +Train cost, Step 8, Cost 248.444885 +Train cost, Step 9, Cost 325.135132 +``` + +c++ 与 
paddle-lite命令: +``` +./demo_trainer true +``` + +c++ 与 paddle-lite结果: +``` +sample 0: Loss: 564.317 +sample 1: Loss: 463.9 +sample 2: Loss: 1197.54 +sample 3: Loss: 1093.83 +sample 4: Loss: 1282.76 +sample 5: Loss: 792.097 +sample 6: Loss: 491.776 +sample 7: Loss: 698.496 +sample 8: Loss: 248.445 +sample 9: Loss: 325.135 +``` + +## Loss 曲线 + +控制训练时的batch size为20,每个epoch对训练数据做全局shuffle,训练100个epoch后,paddle和lite的loss曲线对比如下。 + +![lr_loss](image/lr_loss.png) + +如果想复现上述效果,paddle+python的运行命令为: + +``` +git clone https://github.com/PaddlePaddle/book.git +cd book/01.fit_a_line +python train.py +``` + +lite + c++的运行命令为: +``` +./demo_trainer false +``` diff --git a/lite/demo/cxx/train_demo/cplus_train/CMakeLists.txt b/lite/demo/cxx/train_demo/cplus_train/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b41808352a186e8ed434c0cf9364a9cae7d3928e --- /dev/null +++ b/lite/demo/cxx/train_demo/cplus_train/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 2.8) +set (CMAKE_CXX_STANDARD 11) + +# Project's name + +if(NOT DEFINED LITE_ROOT) + message(FATAL_ERROR "please set LITE_ROOT with + -DLITE_ROOT=/path/to/your/build.lite.android.armv7.gcc/") +endif() + +project(demo_trainer) +# Set the output folder where your program will be created +set(CMAKE_BINARY_DIR ${CMAKE_SOURCE_DIR}/bin) +set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}) +set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR}) + +# The following folder will be included +include_directories("include") +include_directories("${LITE_ROOT}/inference_lite_lib.android.armv7/cxx/include") + +add_executable(demo_trainer ${PROJECT_SOURCE_DIR}/demo_trainer.cc ${PROJECT_SOURCE_DIR}/data_reader.cc) + +TARGET_LINK_LIBRARIES(demo_trainer +"${LITE_ROOT}/inference_lite_lib.android.armv7/cxx/lib/libpaddle_full_api_shared.so") diff --git a/lite/demo/cxx/train_demo/cplus_train/data_reader.cc b/lite/demo/cxx/train_demo/cplus_train/data_reader.cc new file mode 100644 index 
0000000000000000000000000000000000000000..4546e2e5fecc17321e8126485022b4ac30876747 --- /dev/null +++ b/lite/demo/cxx/train_demo/cplus_train/data_reader.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "include/data_reader.h" +#include + +using std::string; +using std::vector; + +int FEATURE_NUM = 13; +float rate = 0.8; + +int get_samples(string line, vector* feature, float* label) { + std::istringstream reader(line); + std::vector numbers; + do { + // read as many numbers as possible. + for (float number; reader >> number;) { + numbers.push_back(number); + } + // consume and discard token from stream. 
+ if (reader.fail()) { + reader.clear(); + std::string token; + reader >> token; + } + } while (!reader.eof()); + + assert(numbers.size() == FEATURE_NUM + 1); + for (int i = 0; i < FEATURE_NUM; i++) { + feature->push_back(numbers[i]); + } + *label = numbers[FEATURE_NUM]; + return 0; +} + +int normalize(const vector>& origin_features, + vector>* features, + float rate) { + int inf = std::numeric_limits::max(); + vector min_vec(FEATURE_NUM, static_cast(inf)); + vector max_vec(FEATURE_NUM, -(static_cast(inf))); + vector sum_vec(FEATURE_NUM, 0); + vector avg_vec(FEATURE_NUM, 0); + + for (int i = 0; i < origin_features.size(); i++) { + for (int j = 0; j < FEATURE_NUM; j++) { + min_vec[j] = min(min_vec[j], origin_features[i][j]); + max_vec[j] = max(max_vec[j], origin_features[i][j]); + sum_vec[j] += origin_features[i][j]; + } + } + + for (int i = 0; i < FEATURE_NUM; i++) { + avg_vec[i] = sum_vec[i] / origin_features.size(); + } + + for (int i = 0; i < origin_features.size() * rate - 1; i++) { + vector feat; + for (int j = 0; j < FEATURE_NUM; j++) { + feat.push_back((origin_features[i][j] - avg_vec[j]) / + (max_vec[j] - min_vec[j])); + } + features->push_back(feat); + } +} + +int read_samples(const string fname, + vector>* features, + vector* labels) { + fstream fin; + fin.open(fname); + if (!static_cast(fin)) { + return 1; + } + vector> origin_features; + vector lines; + string line; + while (getline(fin, line)) { + lines.push_back(line); + } + fin.close(); + + for (int i = 0; i < lines.size(); i++) { + vector feat; + float lbl = 0; + get_samples(lines[i], &feat, &lbl); + origin_features.push_back(feat); + if (i < lines.size() * rate - 1) { + labels->push_back(lbl); + } + } + + cout << "finish read fata" << endl; + normalize(origin_features, features, rate); + assert(features->size() == labels->size()); + return 0; +} diff --git a/lite/demo/cxx/train_demo/cplus_train/demo_trainer.cc b/lite/demo/cxx/train_demo/cplus_train/demo_trainer.cc new file mode 100644 index 
0000000000000000000000000000000000000000..f035078fff35c4b2c0b41d0de84d2621c550d14e
--- /dev/null
+++ b/lite/demo/cxx/train_demo/cplus_train/demo_trainer.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cassert>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+#include "include/data_reader.h"
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+// Thin wrapper around a Paddle-Lite predictor loaded from "model_dir".
+class LRModel {
+ public:
+  // Builds the predictor for ARM float kernels from the model in "model_dir".
+  void InitModel() {
+    // 1. Set CxxConfig
+    CxxConfig config;
+    config.set_model_dir("model_dir");
+    std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kFloat)}};
+    config.set_valid_places(valid_places);
+    predictor_ = CreatePaddlePredictor<CxxConfig>(config);
+  }
+
+  // Feeds one batch (input 0: features, input 1: labels), runs the
+  // predictor and returns the first value of output 0.
+  float Predict(const vector<vector<float>>& features,
+                const vector<float>& labels) {
+    // Create Tensor
+    assert(features.size() == labels.size());
+    int batch_size = features.size();
+    std::unique_ptr<Tensor> input_tensor(std::move(predictor_->GetInput(0)));
+    input_tensor->Resize(shape_t({batch_size, FEATURE_NUM}));
+    auto* data = input_tensor->mutable_data<float>();
+    for (int i = 0; i < batch_size; i++) {
+      for (int j = 0; j < FEATURE_NUM; j++) {
+        data[FEATURE_NUM * i + j] = features[i][j];
+      }
+    }
+    std::unique_ptr<Tensor> y_tensor(std::move(predictor_->GetInput(1)));
+    y_tensor->Resize(shape_t({batch_size, 1}));
+    auto* y_data = y_tensor->mutable_data<float>();
+    for (int i = 0; i < batch_size; i++) {
+      y_data[i] = labels[i];
+    }
+    predictor_->Run();
+    std::unique_ptr<const Tensor> output_tensor(
+        std::move(predictor_->GetOutput(0)));
+    return output_tensor->data<float>()[0];
+  }
+
+ private:
+  std::shared_ptr<PaddlePredictor> predictor_;
+};
+
+// Applies one random permutation to `features` and `labels` together so
+// that each sample keeps its label. Always returns 0.
+int shuffle(vector<vector<float>>* features, vector<float>* labels) {
+  assert(features->size() == labels->size());
+  vector<int> index;
+  for (size_t i = 0; i < features->size(); i++) {
+    index.push_back(i);
+  }
+  // std::random_shuffle is deprecated in C++14 and removed in C++17.
+  static std::mt19937 rng;
+  std::shuffle(index.begin(), index.end(), rng);
+
+  vector<vector<float>> tmp_features;
+  vector<float> tmp_labels;
+
+  for (size_t i = 0; i < index.size(); i++) {
+    tmp_features.push_back((*features)[index[i]]);
+    tmp_labels.push_back((*labels)[index[i]]);
+  }
+
+  features->swap(tmp_features);
+  labels->swap(tmp_labels);
+  return 0;
+}
+
+// Usage: ./demo_trainer is_small ("true": 10 single-sample steps,
+// "false": 100 epochs with batch size 20 over the shuffled data).
+int main(int argc, char* argv[]) {
+  if (argc < 2) {
+    cerr << "usage: ./demo_trainer is_small" << endl;
+    cerr << " if is_small is true, the batch size is set to 1, " << endl;
+    cerr << " and it will only runs for 10 steps." << endl;
+    return 1;
+  }
+  string is_small = argv[1];
+  vector<vector<float>> features;
+  vector<float> labels;
+  if (read_samples("housing.data", &features, &labels) != 0) {
+    cerr << "failed to read samples from housing.data" << endl;
+    return 1;
+  }
+  cout << "sample count: " << features.size() << " " << endl;
+
+  std::shared_ptr<LRModel> local_model(new LRModel());
+  local_model->InitModel();
+
+  const int sample_count = static_cast<int>(features.size());
+  if (is_small == "true") {
+    cout << "small mode" << endl;
+    // Run at most 10 steps and never index past the loaded samples.
+    for (int i = 0; i < 10 && i < sample_count; i++) {
+      vector<vector<float>> batch_feature;
+      vector<float> batch_label;
+      batch_feature.push_back(features[i]);
+      batch_label.push_back(labels[i]);
+      auto loss = local_model->Predict(batch_feature, batch_label);
+      cout << "sample " << i << ": " << loss << endl;
+    }
+  } else if (is_small == "false") {
+    // shuffle
+    cout << "full model" << endl;
+    int epoch = 100;
+    int batch_size = 20;
+    int step = 0;
+    for (int i = 0; i < epoch; i++) {
+      shuffle(&features, &labels);
+      // Walk the data in batch_size strides; the last batch may be shorter.
+      for (int start_idx = 0; start_idx < sample_count;
+           start_idx += batch_size) {
+        int end_idx = min(start_idx + batch_size, sample_count);
+        auto batch_feature = vector<vector<float>>(features.begin() + start_idx,
+                                                   features.begin() + end_idx);
+        auto batch_label =
+            vector<float>(labels.begin() + start_idx, labels.begin() + end_idx);
+        auto loss = local_model->Predict(batch_feature, batch_label);
+        if (step % 10 == 0) {
+          std::cout << "batch: " << i << ", step: " << step
+                    << ", Loss: " << loss << endl;
+        }
+        step += 1;
+      }
+    }
+  } else {
+    cerr << "wrong arg for is_small: " << is_small << endl;
+    return 1;
+  }
+  return 0;
+}
diff --git a/lite/demo/cxx/train_demo/cplus_train/include/data_reader.h b/lite/demo/cxx/train_demo/cplus_train/include/data_reader.h
new file mode 100644
index 0000000000000000000000000000000000000000..050e929c9135ac939dac747e2e4a2490397a4c3d
--- /dev/null
+++ b/lite/demo/cxx/train_demo/cplus_train/include/data_reader.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+using std::string;
+using std::vector;
+using std::cerr;
+using std::cout;
+using std::endl;
+using std::min;
+using std::max;
+using std::fstream;
+
+// Number of feature values per sample; defined in data_reader.cc.
+extern int FEATURE_NUM;
+
+// Parses one text line into FEATURE_NUM features followed by one label.
+int get_samples(string line, vector<float>* feature, float* label);
+// Loads the samples of `fname` into `features` and `labels`.
+// Returns 0 on success and 1 when the file cannot be opened.
+int read_samples(const string fname,
+                 vector<vector<float>>* features,
+                 vector<float>* labels);
diff --git a/lite/demo/cxx/train_demo/cplus_train/run_build.sh b/lite/demo/cxx/train_demo/cplus_train/run_build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4fb444ebd1ecda40db2d69c24016cb78bacdc0ad
--- /dev/null
+++ b/lite/demo/cxx/train_demo/cplus_train/run_build.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Configures and builds the demo for Android (armeabi-v7a) with the NDK.
+# Usage: ./run_build.sh <LITE_ROOT> <NDK_ROOT>
+set -e
+
+if [ "$#" -lt 2 ]; then
+  echo "usage: $0 <LITE_ROOT> <NDK_ROOT>" >&2
+  exit 1
+fi
+
+LITE_ROOT=$1
+NDK_ROOT=$2
+
+rm -rf build
+mkdir build
+cd build
+
+cmake .. \
+      -DLITE_ROOT="${LITE_ROOT}" \
+      -DNDK_ROOT="${NDK_ROOT}" \
+      -DCMAKE_TOOLCHAIN_FILE="${NDK_ROOT}/build/cmake/android.toolchain.cmake" \
+      -DANDROID_TOOLCHAIN=gcc \
+      -DANDROID_ABI="armeabi-v7a" \
+      -DANDROID_PLATFORM=android-23 \
+      -DANDROID=true \
+      -DANDROID_STL=c++_static
+make
+cd ..
+# ./bin/demo_trainer
diff --git a/lite/demo/cxx/train_demo/image/lr_loss.png b/lite/demo/cxx/train_demo/image/lr_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..626cb57ecd5d4cf50fd4d0b8aaadcc29146ca19b
Binary files /dev/null and b/lite/demo/cxx/train_demo/image/lr_loss.png differ
diff --git a/lite/demo/cxx/train_demo/train.py b/lite/demo/cxx/train_demo/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..37825a5cc472990664f68cb38dbf7ee7859286b8
--- /dev/null
+++ b/lite/demo/cxx/train_demo/train.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import print_function + +import sys +import argparse + +import math +import numpy + +import paddle +import paddle.fluid as fluid + + +def parse_args(): + parser = argparse.ArgumentParser("fit_a_line") + parser.add_argument( + '--save_model', + action='store_true', + help="Whether to save main program") + parser.add_argument( + '--num_steps', + type=int, + default=1000000000000, + help="train steps") + parser.add_argument( + '--num_epochs', type=int, default=100, help="number of epochs.") + parser.add_argument( + '--batch_size', type=int, default=20, help="batch size.") + parser.add_argument( + '--shuffle', + action='store_true', + help="Whether to shuffle train data.") + args = parser.parse_args() + return args + +# For training test cost +def train_test(executor, program, reader, feeder, fetch_list): + accumulated = 1 * [0] + count = 0 + for data_test in reader(): + outs = executor.run( + program=program, feed=feeder.feed(data_test), fetch_list=fetch_list) + accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)] + count += 1 + return [x_d / count for x_d in accumulated] + + +def main(): + if args.shuffle: + print("doing shuffle") + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=args.batch_size) + else: + train_reader = paddle.batch( + paddle.dataset.uci_housing.train(), batch_size=args.batch_size) + + # feature vector of length 13 + x = fluid.data(name='x', shape=[None, 13], dtype='float32') + y = fluid.data(name='y', shape=[None, 1], dtype='float32') + + main_program = fluid.default_main_program() + startup_program = fluid.default_startup_program() + + main_program.random_seed = 90 + startup_program.random_seed = 90 + + y_predict = fluid.layers.fc(input=x, size=1, act=None) + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_loss = fluid.layers.mean(cost) + + test_program = main_program.clone(for_test=True) + + sgd_optimizer = 
fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer.minimize(avg_loss) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + num_epochs = args.num_epochs + + # main train loop. + feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) + exe.run(startup_program) + if args.save_model: + fluid.io.save_persistables(exe, "model_dir") + + # add feed and fetch op + feeded_var_names = ['x', 'y'] + fetch_var_names = ['mean_0.tmp_0'] + fluid.io.prepend_feed_ops(main_program, feeded_var_names) + fluid.io.append_fetch_ops(main_program, fetch_var_names) + with open("model_dir/__model__", "wb") as f: + f.write(main_program.desc.serialize_to_string()) + + with open("debug_main_program", "w") as f: + f.write(str(main_program)) + print("train model saved to model_dir") + return + + train_prompt = "Train cost" + step = 0 + for pass_id in range(num_epochs): + for data_train in train_reader(): + avg_loss_value, = exe.run( + main_program, + feed=feeder.feed(data_train), + fetch_list=[avg_loss]) + print("%s, Step %d, Cost %f" % + (train_prompt, step, avg_loss_value[0])) + if step == args.num_steps - 1: + return + step += 1 + + if math.isnan(float(avg_loss_value[0])): + sys.exit("got NaN loss, training failed.") + + +if __name__ == '__main__': + args = parse_args() + main() diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 7550d770145d92ebd343f96a82c6f34d72c91ea5..a3b1c3680e283a4425fe22209c443ce7cd958267 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -106,13 +106,12 @@ add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math # 4. 
training kernels add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) -if(LITE_WITH_TRAIN) - add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) - add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) -endif() + +add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) diff --git a/lite/kernels/arm/sequence_pool_compute.cc b/lite/kernels/arm/sequence_pool_compute.cc index 8fcbb8cffe72935e4df503c3c1748ddb68247fb7..53fa5477036757fa70135569129fee115eb52047 100644 --- a/lite/kernels/arm/sequence_pool_compute.cc +++ b/lite/kernels/arm/sequence_pool_compute.cc @@ -59,7 +59,8 @@ void SequencePoolCompute::Run() { for (int i = 0; i <= batch_size; i++) { offset_new[i] = i; } - (output->mutable_lod())->push_back(offset_new); + output->mutable_lod()->clear(); + output->mutable_lod()->push_back(offset_new); } } // namespace arm diff --git 
a/lite/kernels/bm/bridges/CMakeLists.txt b/lite/kernels/bm/bridges/CMakeLists.txt index 75375f493fe9b6b1f436ef679a7ea8bd80e5ad0a..fe5c39380ef440a98cf0544177bb781644344eb2 100644 --- a/lite/kernels/bm/bridges/CMakeLists.txt +++ b/lite/kernels/bm/bridges/CMakeLists.txt @@ -30,6 +30,8 @@ lite_cc_library(subgraph_bridge_conv_transpose_op_bm SRCS conv_transpose_op.cc D lite_cc_library(subgraph_bridge_reduce_full_op_bm SRCS reduce_full_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_squeeze_op_bm SRCS squeeze_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_cast_op_bm SRCS cast_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_fill_constant_op_bm SRCS fill_constant_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_assign_value_op_bm SRCS assign_value_op.cc DEPS ${bm_subgraph_bridge_deps}) set(bm_subgraph_bridges subgraph_bridge_registry @@ -58,4 +60,6 @@ set(bm_subgraph_bridges subgraph_bridge_reduce_full_op_bm subgraph_bridge_squeeze_op_bm subgraph_bridge_cast_op_bm + subgraph_bridge_fill_constant_op_bm + subgraph_bridge_assign_value_op_bm CACHE INTERNAL "bm_subgraph_bridges") diff --git a/lite/kernels/bm/bridges/act_op.cc b/lite/kernels/bm/bridges/act_op.cc index 091743157995ab1a00e798a6ac560454d4b22ae7..1739dd4185ebcff6a35e2f75c5f8c84ceebd2f0a 100644 --- a/lite/kernels/bm/bridges/act_op.cc +++ b/lite/kernels/bm/bridges/act_op.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include #include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/npu/bridges/registry.h" @@ -35,16 +36,14 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output_var_name = op_info->Output("Out").front(); auto output = scope->FindVar(output_var_name)->GetMutable(); auto output_dims = output->dims(); - const int64_t* x_shape_data = const_cast(&x_dims.data()[0]); - const int64_t* output_shape_data = - const_cast(&output_dims.data()[0]); + bool x_is_const = !graph->HasNode(x_var_name); std::vector i_x_shape_data(x_dims.size()); std::vector i_output_shape_data(output_dims.size()); for (size_t i = 0; i < x_dims.size(); i++) { - i_x_shape_data[i] = static_cast(x_shape_data[i]); + i_x_shape_data[i] = x_dims[i]; } for (size_t i = 0; i < output_dims.size(); i++) { - i_output_shape_data[i] = static_cast(output_shape_data[i]); + i_output_shape_data[i] = output_dims[i]; } float alpha = 0.f; int active_type_id = 0; @@ -59,6 +58,15 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { LOG(FATAL) << "[BM] unsupport act type"; return FAILED; } + const float* x_data = const_cast(x->mutable_data()); + if (x_is_const) { + bm_add_const_tensor(graph->GetCompilerHandle(), + static_cast(x_var_name.c_str()), + const_cast(&i_x_shape_data[0]), + x_dims.size(), + static_cast(DTYPE_FP32), + static_cast(x_data)); + } if (op_type == "relu" || op_type == "leaky_relu") { add_relu_layer(graph->GetCompilerHandle(), const_cast(&i_x_shape_data[0]),
diff --git a/lite/kernels/bm/bridges/assign_value_op.cc b/lite/kernels/bm/bridges/assign_value_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c59cdf29def68c58cd33ee44b24816e0dbccd32e
--- /dev/null
+++ b/lite/kernels/bm/bridges/assign_value_op.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+// Registers the output of an assign_value op as a constant FP32 tensor whose
+// elements come from the op's "fp32_values" attribute, then marks the output
+// variable as a graph node.
+int AssignValueConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast(ctx);
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+
+  auto output_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(output_var_name)->GetMutable();
+  auto output_dims = output->dims();
+  std::vector i_output_shape_data(output_dims.size());
+  int buffer_size = 1;
+  for (size_t i = 0; i < output_dims.size(); i++) {
+    i_output_shape_data[i] = static_cast(output_dims[i]);
+    buffer_size *= i_output_shape_data[i];
+  }
+  auto fp32_values = op_info->GetAttr>("fp32_values");
+  float* assign_data =
+      reinterpret_cast(malloc(buffer_size * sizeof(float)));
+  CHECK(assign_data != nullptr);
+  CHECK_EQ(buffer_size, fp32_values.size());
+  // Copy the attribute values into the buffer; without this the constant
+  // tensor would be built from uninitialized memory.
+  for (int i = 0; i < buffer_size; i++) {
+    assign_data[i] = fp32_values[i];
+  }
+
+  // NOTE(review): assign_data is never freed here -- confirm whether
+  // bm_add_const_tensor copies the buffer or keeps the pointer.
+  bm_add_const_tensor(graph->GetCompilerHandle(),
+                      static_cast(output_var_name.c_str()),
+                      const_cast(i_output_shape_data.data()),
+                      output_dims.size(),
+                      static_cast(DTYPE_FP32),
+                      reinterpret_cast(assign_data));
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(assign_value,
+                         kBM,
+                         paddle::lite::subgraph::bm::AssignValueConverter);
diff --git
a/lite/kernels/bm/bridges/conv_op.cc b/lite/kernels/bm/bridges/conv_op.cc index e4dff107024c02dcfe25afe37723b7d2418369b5..2a0903191b82bbc1c409f59f2eb19bd2ffd5ddac 100644 --- a/lite/kernels/bm/bridges/conv_op.cc +++ b/lite/kernels/bm/bridges/conv_op.cc @@ -39,6 +39,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto filter_var_name = op_info->Input("Filter").front(); auto filter = scope->FindVar(filter_var_name)->GetMutable(); auto filter_dims = filter->dims(); + CHECK_EQ(input_dims.size(), 4); CHECK_EQ(output_dims.size(), 4); CHECK_EQ(filter_dims.size(), 4); @@ -90,6 +91,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { dilations[1], static_cast(has_bias)); graph->AddNode(output_var_name); + LOG(INFO) << output_var_name << input_dims << " " << output_dims; return SUCCESS; }
diff --git a/lite/kernels/bm/bridges/elementwise_ops.cc b/lite/kernels/bm/bridges/elementwise_ops.cc index 3006a8b6fdaef5a250af1b2e764aff9f2913898e..de5a62dae068620c378cd566abfde556c0f63102 100644 --- a/lite/kernels/bm/bridges/elementwise_ops.cc +++ b/lite/kernels/bm/bridges/elementwise_ops.cc @@ -65,6 +65,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output_dims = output->dims(); const int64_t* output_shape_data = const_cast(&output_dims.data()[0]); + LOG(INFO) << x_dims << " " << output_dims; std::vector i_output_shape_data(output_dims.size()); for (size_t i = 0; i < output_dims.size(); i++) { i_output_shape_data[i] = static_cast(output_shape_data[i]);
diff --git a/lite/kernels/bm/bridges/fill_constant_op.cc b/lite/kernels/bm/bridges/fill_constant_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..835ccf0eb4e72ad945b7a24643190ff49f0b5723
--- /dev/null
+++ b/lite/kernels/bm/bridges/fill_constant_op.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include "lite/kernels/bm/bridges/graph.h"
+#include "lite/kernels/bm/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace bm {
+
+// Registers the output of a fill_constant op as a constant FP32 tensor with
+// every element set to the op's "value" attribute, then marks the output
+// variable as a graph node.
+int FillConstantConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast(ctx);
+  auto scope = op->scope();
+  auto op_info = op->op_info();
+
+  auto output_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(output_var_name)->GetMutable();
+  auto output_dims = output->dims();
+  std::vector i_output_shape_data(output_dims.size());
+  int buffer_size = 1;
+  for (size_t i = 0; i < output_dims.size(); i++) {
+    i_output_shape_data[i] = static_cast(output_dims[i]);
+    // Accumulate the element count so the buffer covers the whole tensor.
+    buffer_size *= i_output_shape_data[i];
+  }
+  float* const_data =
+      reinterpret_cast(malloc(buffer_size * sizeof(float)));
+  CHECK(const_data != nullptr);
+  auto value = op_info->GetAttr("value");
+  for (int i = 0; i < buffer_size; i++) {
+    const_data[i] = value;
+  }
+  // NOTE(review): const_data is never freed here -- confirm whether
+  // bm_add_const_tensor copies the buffer or keeps the pointer.
+  bm_add_const_tensor(graph->GetCompilerHandle(),
+                      static_cast(output_var_name.c_str()),
+                      const_cast(i_output_shape_data.data()),
+                      output_dims.size(),
+                      static_cast(DTYPE_FP32),
+                      reinterpret_cast(const_data));
+  graph->AddNode(output_var_name);
+  return SUCCESS;
+}
+
+}  // namespace bm
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(fill_constant,
+                         kBM,
+                         paddle::lite::subgraph::bm::FillConstantConverter);
diff --git a/lite/kernels/bm/bridges/mul_op.cc b/lite/kernels/bm/bridges/mul_op.cc index 06ec177bceb883758c42d45c9b07006a83b3c9f6..35e1aac7660683ff5544a7d72574167359b29fdb 100644 --- a/lite/kernels/bm/bridges/mul_op.cc +++ b/lite/kernels/bm/bridges/mul_op.cc @@ -29,7 +29,6 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto op_info = op->op_info(); auto op_type = op_info->Type(); auto unique_op_name = lite::subgraph::bm::UniqueName(op_type); - // only support y is const // input auto x_var_name = op_info->Input("X").front(); auto x = scope->FindVar(x_var_name)->GetMutable(); @@ -61,6 +60,12 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto y_var_name = op_info->Input("Y").front(); auto y = scope->FindVar(y_var_name)->GetMutable(); auto y_dims = y->dims(); + bool y_is_const = !graph->HasNode(y_var_name); + CHECK_EQ(y_dims.size(), 2); + int i_y_shape_data[2]; + for (size_t i = 0; i < 2; i++) { + i_y_shape_data[i] = y_dims[i]; + } // output auto output_var_name = op_info->Output("Out").front(); auto output = scope->FindVar(output_var_name)->GetMutable(); @@ -71,20 +76,39 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) { for (size_t i = 0; i < output_dims.size(); i++) { i_output_shape_data[i] = static_cast(output_shape_data[i]); } - add_fc_layer(graph->GetCompilerHandle(), - const_cast(&i_x_reshape_shape_data[0]), - 2, - static_cast(unique_op_reshape_name.c_str()), - const_cast(&i_output_shape_data[0]), - output_dims.size(), - static_cast(output_var_name.c_str()), - static_cast(unique_op_name.c_str()), - i_x_reshape_shape_data[1], - i_output_shape_data[1], - static_cast(y->mutable_data()), - nullptr, - 0, - 0); + if (y_is_const) { + add_fc_layer(graph->GetCompilerHandle(), + const_cast(&i_x_reshape_shape_data[0]), + 2, + static_cast(unique_op_reshape_name.c_str()), + const_cast(&i_output_shape_data[0]), +
output_dims.size(), + static_cast(output_var_name.c_str()), + static_cast(unique_op_name.c_str()), + i_x_reshape_shape_data[1], + i_output_shape_data[1], + static_cast(y->mutable_data()), + nullptr, + 0, + 0); + } else { + add_fc_weight_layer( + graph->GetCompilerHandle(), + const_cast(&i_x_reshape_shape_data[0]), + 2, + static_cast(unique_op_reshape_name.c_str()), + const_cast(&i_output_shape_data[0]), + output_dims.size(), + static_cast(output_var_name.c_str()), + static_cast(unique_op_name.c_str()), + const_cast(&i_y_shape_data[0]), + 2, + static_cast(y_var_name.c_str()), + i_x_reshape_shape_data[1], + nullptr, + 0, + 0); + } graph->AddNode(output_var_name); return SUCCESS; } diff --git a/lite/kernels/bm/bridges/paddle_use_bridges.h b/lite/kernels/bm/bridges/paddle_use_bridges.h index 8dbbb53d810952743228d96d60d7927965d2d527..48b708fed4889a9a0f515d2b9d76232cfb532be5 100644 --- a/lite/kernels/bm/bridges/paddle_use_bridges.h +++ b/lite/kernels/bm/bridges/paddle_use_bridges.h @@ -51,3 +51,5 @@ USE_SUBGRAPH_BRIDGE(reduce_mean, kBM); USE_SUBGRAPH_BRIDGE(squeeze, kBM); USE_SUBGRAPH_BRIDGE(squeeze2, kBM); USE_SUBGRAPH_BRIDGE(cast, kBM); +USE_SUBGRAPH_BRIDGE(fill_constant, kBM); +USE_SUBGRAPH_BRIDGE(assign_value, kBM); diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc index 338939f019fb8da37d0b0a234e2c8b408e5a9ad0..45f9bd16c433ef94daf242eae9c2168ac1424147 100644 --- a/lite/kernels/bm/subgraph_compute.cc +++ b/lite/kernels/bm/subgraph_compute.cc @@ -35,7 +35,7 @@ int SubgraphEngine::BuildDeviceProgram() { graph.CreateCompilerHandle(); auto& ctx = this->ctx_->template As(); for (auto& inst : origin_program_) { - auto op = inst.op(); + auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); op->InferShape(); diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 3fb3136bfc0787f9d8e539039811d25559919f4e..0fb3c2ea7aa66b313411ac9d97c9918eb2ca8d2f 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ 
b/lite/kernels/cuda/CMakeLists.txt @@ -8,6 +8,8 @@ add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_de add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps}) add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps}) add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps}) @@ -45,6 +47,8 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_ #nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda) +nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda) nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda) nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) @@ -61,7 +65,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc #nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) #nv_test(attention_padding_mask_compute_cuda_test SRCS 
attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda) nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda) -#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda) +#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda) #nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda) if(LITE_BUILD_EXTRA) diff --git a/lite/kernels/cuda/abs_compute.cu b/lite/kernels/cuda/abs_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..4f00aacc0cdcea07329e3836c7068f419d26f90c --- /dev/null +++ b/lite/kernels/cuda/abs_compute.cu @@ -0,0 +1,71 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/abs_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void AbsKernel(const int num, const T* input, T* output); + +template <> +__global__ void AbsKernel(const int num, + const float* input, + float* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = fabsf(input[index]); + } +} + +template <> +__global__ void AbsKernel(const int num, + const double* input, + double* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = fabs(input[index]); + } +} + +void AbsCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->data(); + auto output = param.Out->mutable_data(TARGET(kCUDA)); + + const int threads = 512; + const int blocks = (num + threads - 1) / threads; + AbsKernel<<>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + abs, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::AbsCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/abs_compute.h b/lite/kernels/cuda/abs_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..d1f8a0cc5ac52e01cc8ea920bdad62ef46fd0640 --- /dev/null +++ b/lite/kernels/cuda/abs_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class AbsCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~AbsCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/abs_compute_test.cc b/lite/kernels/cuda/abs_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bfbcae56fa51fc59c9917aa112fa5320c2759a9a --- /dev/null +++ b/lite/kernels/cuda/abs_compute_test.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/cuda/abs_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(abs, normal) { + AbsCompute abs_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 3, w = 3; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 1.5; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + abs_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + abs_kernel.SetContext(std::move(ctx)); + abs_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], std::fabs(x_cpu_data[i]), 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_compute.cu b/lite/kernels/cuda/elementwise_compute.cu index 64759f86f5df85f9855b9c1f186bbc9c039a044c..02b7c8f7d9e829b100e6c96aca2a8cee3ca74ef1 100644 --- a/lite/kernels/cuda/elementwise_compute.cu +++ b/lite/kernels/cuda/elementwise_compute.cu @@ -152,6 +152,18 @@ void ElementwiseAddComputeNHWC::Run() { if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); } +void ElementwiseSubCompute::Run() { + ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false) + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +void ElementwiseSubComputeNHWC::Run() { + ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false) + cudaError_t error 
= cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + void ElementwiseMulCompute::Run() { ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false) cudaError_t error = cudaGetLastError(); @@ -204,6 +216,17 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); +REGISTER_LITE_KERNEL(elementwise_sub, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ElementwiseSubCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kCUDA, kFloat, @@ -224,6 +247,26 @@ REGISTER_LITE_KERNEL(elementwise_add, DATALAYOUT(kNHWC))}) .Finalize(); +REGISTER_LITE_KERNEL(elementwise_sub, + kCUDA, + kFloat, + kNHWC, + paddle::lite::kernels::cuda::ElementwiseSubComputeNHWC, + nhwc_format) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kCUDA, kFloat, diff --git a/lite/kernels/cuda/elementwise_compute.h b/lite/kernels/cuda/elementwise_compute.h index 986a4db2272d9a6607090babd937747f861f49c7..bc9ffd5d27c7b030f397d1b631a155cae5f34678 100644 --- a/lite/kernels/cuda/elementwise_compute.h +++ b/lite/kernels/cuda/elementwise_compute.h @@ -38,6 +38,24 @@ class ElementwiseAddComputeNHWC virtual ~ElementwiseAddComputeNHWC() = default; }; +class ElementwiseSubCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + + void Run() override; + virtual ~ElementwiseSubCompute() = default; +}; + +class ElementwiseSubComputeNHWC + : public KernelLite { + public: + 
using param_t = operators::ElementwiseParam; + + void Run() override; + virtual ~ElementwiseSubComputeNHWC() = default; +}; + class ElementwiseMulCompute : public KernelLite { public: diff --git a/lite/kernels/cuda/tanh_compute.cu b/lite/kernels/cuda/tanh_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..4f9e2729a7fa0f300308b9f1afcf35e852d11223 --- /dev/null +++ b/lite/kernels/cuda/tanh_compute.cu @@ -0,0 +1,56 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/tanh_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +__global__ void TanhKernel(const int num, const T* input, T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { + output[index] = tanh(input[index]); + } +} + +void TanhCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + int num = static_cast(param.X->numel()); + auto input = param.X->data(); + auto output = param.Out->mutable_data(TARGET(kCUDA)); + + const int threads = 512; + const int blocks = (num + threads - 1) / threads; + TanhKernel<<>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + tanh, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::TanhCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/tanh_compute.h b/lite/kernels/cuda/tanh_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b23b27882cc2ef6f5c8e15ba49fbdd5316cbfa3e --- /dev/null +++ b/lite/kernels/cuda/tanh_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class TanhCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~TanhCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/tanh_compute_test.cc b/lite/kernels/cuda/tanh_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bc8f25df0bed46254c56d8ec1080e45062bada2 --- /dev/null +++ b/lite/kernels/cuda/tanh_compute_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/cuda/tanh_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(tanh, fp32) { + TanhCompute tanh_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 3, w = 3; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 1.5; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + tanh_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + tanh_kernel.SetContext(std::move(ctx)); + tanh_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], tanh(x_cpu_data[i]), 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index f337e518abd071ac262ce9ee47beae1600cc57d1..a52428aa097099150139de82627d5770c9b9071c 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -5,6 +5,3 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kerne add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps}) - -#lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) 
-#lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any) diff --git a/lite/kernels/host/multiclass_nms_compute.cc b/lite/kernels/host/multiclass_nms_compute.cc index d68f296e3fc4ffbe69d18d556a125e0478ed47cf..5a09fca72b4bb30ac67b1186cf90c58a5f9a1dd4 100644 --- a/lite/kernels/host/multiclass_nms_compute.cc +++ b/lite/kernels/host/multiclass_nms_compute.cc @@ -370,6 +370,7 @@ void MulticlassNmsCompute::Run() { } } else { outs->Resize({static_cast(num_kept), out_dim}); + outs->mutable_data(); int offset = 0; int* oindices = nullptr; for (int i = 0; i < n; ++i) { diff --git a/lite/kernels/host/multiclass_nms_compute_test.cc b/lite/kernels/host/multiclass_nms_compute_test.cc deleted file mode 100644 index 83fb717042515a7a06fe0c014fca7482ad6c8684..0000000000000000000000000000000000000000 --- a/lite/kernels/host/multiclass_nms_compute_test.cc +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/kernels/host/multiclass_nms_compute.h" -#include -#include -#include -#include - -namespace paddle { -namespace lite { -namespace kernels { -namespace host { - -template -static bool sort_score_pair_descend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; -} - -template -void get_max_score_index(const dtype* scores, - int num, - float threshold, - int top_k, - std::vector>* score_index_vec) { - //! Generate index score pairs. - for (int i = 0; i < num; ++i) { - if (scores[i] > threshold) { - score_index_vec->push_back(std::make_pair(scores[i], i)); - } - } - - //! Sort the score pair according to the scores in descending order - std::stable_sort(score_index_vec->begin(), - score_index_vec->end(), - sort_score_pair_descend); - - //! Keep top_k scores if needed. - if (top_k > -1 && top_k < score_index_vec->size()) { - score_index_vec->resize(top_k); - } -} - -template -dtype bbox_size(const dtype* bbox, bool normalized = true) { - if (bbox[2] < bbox[0] || bbox[3] < bbox[1]) { - // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. - return dtype(0.); - } else { - const dtype width = bbox[2] - bbox[0]; - const dtype height = bbox[3] - bbox[1]; - - if (normalized) { - return width * height; - } else { - // If bbox is not within range [0, 1]. 
- return (width + 1) * (height + 1); - } - } -} - -template -dtype jaccard_overlap(const dtype* bbox1, const dtype* bbox2) { - if (bbox2[0] > bbox1[2] || bbox2[2] < bbox1[0] || bbox2[1] > bbox1[3] || - bbox2[3] < bbox1[1]) { - return dtype(0.); - } else { - const dtype inter_xmin = std::max(bbox1[0], bbox2[0]); - const dtype inter_ymin = std::max(bbox1[1], bbox2[1]); - const dtype inter_xmax = std::min(bbox1[2], bbox2[2]); - const dtype inter_ymax = std::min(bbox1[3], bbox2[3]); - - const dtype inter_width = inter_xmax - inter_xmin; - const dtype inter_height = inter_ymax - inter_ymin; - const dtype inter_size = inter_width * inter_height; - - const dtype bbox1_size = bbox_size(bbox1); - const dtype bbox2_size = bbox_size(bbox2); - - return inter_size / (bbox1_size + bbox2_size - inter_size); - } -} - -template -void apply_nms_fast(const dtype* bboxes, - const dtype* scores, - int num, - float score_threshold, - float nms_threshold, - float eta, - int top_k, - std::vector* indices) { - // Get top_k scores (with corresponding indices). - std::vector> score_index_vec; - get_max_score_index(scores, num, score_threshold, top_k, &score_index_vec); - - // Do nms. 
- float adaptive_threshold = nms_threshold; - indices->clear(); - - while (score_index_vec.size() != 0) { - const int idx = score_index_vec.front().second; - bool keep = true; - - for (int k = 0; k < indices->size(); ++k) { - if (keep) { - const int kept_idx = (*indices)[k]; - float overlap = - jaccard_overlap(bboxes + idx * 4, bboxes + kept_idx * 4); - keep = overlap <= adaptive_threshold; - } else { - break; - } - } - - if (keep) { - indices->push_back(idx); - } - - score_index_vec.erase(score_index_vec.begin()); - - if (keep && eta < 1 && adaptive_threshold > 0.5) { - adaptive_threshold *= eta; - } - } -} - -template -void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param, - int class_num, - const std::vector& priors, - bool share_location, - std::vector* result) { - int background_id = param.background_label; - int keep_topk = param.keep_top_k; - int nms_topk = param.nms_top_k; - float conf_thresh = param.score_threshold; - float nms_thresh = param.nms_threshold; - float nms_eta = param.nms_eta; - const dtype* bbox_data = param.bboxes->data(); - const dtype* conf_data = param.scores->data(); - dtype* out = param.out->mutable_data(); - (*result).clear(); - - int num_kept = 0; - std::vector>> all_indices; - int64_t conf_offset = 0; - int64_t bbox_offset = 0; - for (int i = 0; i < priors.size(); ++i) { - std::map> indices; - int num_det = 0; - int num_priors = priors[i]; - - int conf_idx = class_num * conf_offset; - int bbox_idx = - share_location ? 
bbox_offset * 4 : bbox_offset * 4 * class_num; - - for (int c = 0; c < class_num; ++c) { - if (c == background_id) { - // Ignore background class - continue; - } - - const dtype* cur_conf_data = conf_data + conf_idx + c * num_priors; - const dtype* cur_bbox_data = bbox_data + bbox_idx; - - if (!share_location) { - cur_bbox_data += c * num_priors * 4; - } - - apply_nms_fast(cur_bbox_data, - cur_conf_data, - num_priors, - conf_thresh, - nms_thresh, - nms_eta, - nms_topk, - &(indices[c])); - num_det += indices[c].size(); - } - - if (keep_topk > -1 && num_det > keep_topk) { - std::vector>> score_index_pairs; - - for (auto it = indices.begin(); it != indices.end(); ++it) { - int label = it->first; - const std::vector& label_indices = it->second; - - for (int j = 0; j < label_indices.size(); ++j) { - int idx = label_indices[j]; - float score = conf_data[conf_idx + label * num_priors + idx]; - score_index_pairs.push_back( - std::make_pair(score, std::make_pair(label, idx))); - } - } - - // Keep top k results per image. - std::stable_sort(score_index_pairs.begin(), - score_index_pairs.end(), - sort_score_pair_descend>); - score_index_pairs.resize(keep_topk); - // Store the new indices. - std::map> new_indices; - - for (int j = 0; j < score_index_pairs.size(); ++j) { - int label = score_index_pairs[j].second.first; - int idx = score_index_pairs[j].second.second; - new_indices[label].push_back(idx); - } - - all_indices.push_back(new_indices); - num_kept += keep_topk; - } else { - all_indices.push_back(indices); - num_kept += num_det; - } - conf_offset += num_priors; - bbox_offset += num_priors; - } - - if (num_kept == 0) { - (*result).clear(); - (*result).resize(1); - (*result)[0] = -1; - return; - } else { - (*result).resize(num_kept * 6); - } - - int count = 0; - - conf_offset = 0; - bbox_offset = 0; - for (int i = 0; i < priors.size(); ++i) { - int num_priors = priors[i]; - int conf_idx = class_num * conf_offset; - int bbox_idx = - share_location ? 
bbox_offset * 4 : bbox_offset * 4 * class_num; - - for (auto it = all_indices[i].begin(); it != all_indices[i].end(); ++it) { - int label = it->first; - std::vector& indices = it->second; - const dtype* cur_conf_data = conf_data + conf_idx + label * num_priors; - const dtype* cur_bbox_data = bbox_data + bbox_idx; - - if (!share_location) { - cur_bbox_data += label * num_priors * 4; - } - - for (int j = 0; j < indices.size(); ++j) { - int idx = indices[j]; - (*result)[count * 6] = label; - (*result)[count * 6 + 1] = cur_conf_data[idx]; - - for (int k = 0; k < 4; ++k) { - (*result)[count * 6 + 2 + k] = cur_bbox_data[idx * 4 + k]; - } - - ++count; - } - } - conf_offset += num_priors; - bbox_offset += num_priors; - } -} - -TEST(multiclass_nms_host, init) { - MulticlassNmsCompute multiclass_nms; - ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat)); - ASSERT_EQ(multiclass_nms.target(), TARGET(kHost)); -} - -TEST(multiclass_nms_host, retrive_op) { - auto multiclass_nms = - KernelRegistry::Global().Create( - "multiclass_nms"); - ASSERT_FALSE(multiclass_nms.empty()); - ASSERT_TRUE(multiclass_nms.front()); -} - -TEST(multiclass_nms_host, compute) { - MulticlassNmsCompute multiclass_nms; - operators::MulticlassNmsParam param; - lite::Tensor bbox, conf, out; - std::vector out_ref; - - for (std::vector priors : {std::vector({2, 2, 2})}) { - int N = priors.size(); - for (bool share_location : {true}) { - for (int class_num : {1, 4, 10}) { - DDim* bbox_dim; - DDim* conf_dim; - int M = priors[0]; - if (share_location) { - bbox_dim = new DDim({N, M, 4}); - } else { - bbox_dim = new DDim({class_num, M, 4}); - } - conf_dim = new DDim({N, class_num, M}); - bbox.Resize(*bbox_dim); - conf.Resize(*conf_dim); - for (int background_id : {0}) { - for (int keep_topk : {1, 5, 10}) { - for (int nms_topk : {1, 5, 10}) { - for (float nms_eta : {1.0, 0.99, 0.9}) { - for (float nms_thresh : {0.5, 0.7}) { - for (float conf_thresh : {0.5, 0.7}) { - auto* conf_data = conf.mutable_data(); - 
auto* bbox_data = bbox.mutable_data(); - for (int i = 0; i < bbox_dim->production(); ++i) { - bbox_data[i] = i * 1. / bbox_dim->production(); - } - for (int i = 0; i < conf_dim->production(); ++i) { - conf_data[i] = i * 1. / conf_dim->production(); - } - param.bboxes = &bbox; - param.scores = &conf; - param.out = &out; - param.background_label = background_id; - param.keep_top_k = keep_topk; - param.nms_top_k = nms_topk; - param.score_threshold = conf_thresh; - param.nms_threshold = nms_thresh; - param.nms_eta = nms_eta; - multiclass_nms.SetParam(param); - multiclass_nms.Run(); - auto* out_data = out.mutable_data(); - out_ref.clear(); - multiclass_nms_compute_ref( - param, class_num, priors, share_location, &out_ref); - EXPECT_EQ(out.dims().production(), out_ref.size()); - if (out.dims().production() == out_ref.size()) { - auto* out_ref_data = out_ref.data(); - for (int i = 0; i < out.dims().production(); i++) { - EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); - } - } - } - } - } - } - } - } - delete bbox_dim; - delete conf_dim; - } - } - } -} - -} // namespace host -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def); diff --git a/lite/kernels/npu/bridges/conv_op.cc b/lite/kernels/npu/bridges/conv_op.cc index 637b6eea5c99f9ab2a43d4bd442a3a720dced96a..f21e5618b0d8b2e0e7ed4aec0b1bc9b16c4877d9 100644 --- a/lite/kernels/npu/bridges/conv_op.cc +++ b/lite/kernels/npu/bridges/conv_op.cc @@ -220,6 +220,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { act_op->set_attr_mode(CvtActMode(act_type)); if (act_type == "leaky_relu") { act_op->set_attr_negative_slope(leaky_relu_alpha); + } else if (act_type == "relu6") { + act_op->set_attr_coef(6.f); } } diff --git a/lite/kernels/npu/bridges/paddle_use_bridges.h b/lite/kernels/npu/bridges/paddle_use_bridges.h index 6c406302212640ec41d0701f530c0c1f32229539..2b41c36b3c0f7dc0a56049fdb3a154370883836c 100644 --- 
a/lite/kernels/npu/bridges/paddle_use_bridges.h +++ b/lite/kernels/npu/bridges/paddle_use_bridges.h @@ -18,6 +18,7 @@ USE_SUBGRAPH_BRIDGE(sigmoid, kNPU); USE_SUBGRAPH_BRIDGE(relu, kNPU); USE_SUBGRAPH_BRIDGE(tanh, kNPU); USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU); +USE_SUBGRAPH_BRIDGE(relu6, kNPU); USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU); USE_SUBGRAPH_BRIDGE(softsign, kNPU); USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU); diff --git a/lite/kernels/npu/bridges/pool_op.cc b/lite/kernels/npu/bridges/pool_op.cc index e30a286961c376ac94de78a5a8f9f8a776af062a..51f67a1c6f0122c1140aeb762b448a928bd16692 100644 --- a/lite/kernels/npu/bridges/pool_op.cc +++ b/lite/kernels/npu/bridges/pool_op.cc @@ -99,10 +99,8 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { ksize); // ceil mode - int ceil_mode = 0; - if (op_info->HasAttr("ceil_mode")) { - ceil_mode = op_info->GetAttr("ceil_mode") ? 1 : 0; - } + bool ceil_mode = + op_info->HasAttr("ceil_mode") && op_info->GetAttr("ceil_mode"); // Pooling node auto pool_node = graph->Add(out_name); @@ -112,12 +110,14 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { pool_op->set_attr_pad_mode(pad_mode); pool_op->set_attr_global_pooling(global_pooling); pool_op->set_attr_window(ge::AttrValue::LIST_INT(ksize.begin(), ksize.end())); - pool_op->set_attr_pad(ge::AttrValue::LIST_INT{ - paddings[0], paddings[1], paddings[2], paddings[3]}); + pool_op->set_attr_pad( + ge::AttrValue::LIST_INT(paddings.begin(), paddings.end())); pool_op->set_attr_stride( ge::AttrValue::LIST_INT(strides.begin(), strides.end())); - pool_op->set_attr_ceil_mode(ceil_mode); - // pool_op->set_attr_data_mode(data_mode); + if (ceil_mode) { + pool_op->set_attr_ceil_mode(1); + pool_op->set_attr_data_mode(0); + } return REBUILD_WHEN_SHAPE_CHANGED; } diff --git a/lite/kernels/npu/subgraph_compute.cc b/lite/kernels/npu/subgraph_compute.cc index d7b14a9319951eb827cbc9d346ee8e59e9571aee..1baa5a0de44d71356cabd505fb0cdfe388a0bae3 100644 --- 
a/lite/kernels/npu/subgraph_compute.cc +++ b/lite/kernels/npu/subgraph_compute.cc @@ -35,7 +35,7 @@ int SubgraphEngine::BuildDeviceProgram() { subgraph::npu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); for (auto& inst : origin_program_) { - auto op = inst.op(); + auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); op->InferShape(); @@ -44,10 +44,8 @@ int SubgraphEngine::BuildDeviceProgram() { return subgraph::FAILED; } auto kernel = inst.kernel(); - status |= - bridges.Select(op_type, TARGET(kNPU))(reinterpret_cast(&graph), - const_cast(op), - const_cast(kernel)); + status |= bridges.Select(op_type, TARGET(kNPU))( + reinterpret_cast(&graph), op, const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { return subgraph::FAILED; } diff --git a/lite/kernels/opencl/activation_buffer_compute.cc b/lite/kernels/opencl/activation_buffer_compute.cc index c662aa89fb257aded70119ea14494111398f0529..03ccdac99e5f11e1c056374463f7a8068dbd4f56 100644 --- a/lite/kernels/opencl/activation_buffer_compute.cc +++ b/lite/kernels/opencl/activation_buffer_compute.cc @@ -32,8 +32,10 @@ class ReluCompute std::string doc() const override { return "Relu using cl::Buffer, kFloat"; } void PrepareForRun() override { auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/relu_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/relu_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -46,7 +48,7 @@ class ReluCompute auto* x_buf = param.X->data(); auto* out_buf = param.Out->mutable_data(TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); VLOG(4) << TargetToStr(param.X->target()); VLOG(4) << TargetToStr(param.Out->target()); @@ -74,6 +76,7 @@ class ReluCompute private: std::string 
kernel_func_name_{"relu"}; std::string build_options_{"-DCL_DTYPE_float -DRELU"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; @@ -87,8 +90,10 @@ class SigmoidCompute } void PrepareForRun() override { auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/sigmoid_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/sigmoid_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -101,7 +106,7 @@ class SigmoidCompute auto* x_buf = param.X->data(); auto* out_buf = param.Out->mutable_data(TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); VLOG(4) << TargetToStr(param.X->target()); VLOG(4) << TargetToStr(param.Out->target()); @@ -129,6 +134,7 @@ class SigmoidCompute private: std::string kernel_func_name_{"sigmoid"}; std::string build_options_{"-DCL_DTYPE_float -DSIGMOID"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/activation_image_compute.cc b/lite/kernels/opencl/activation_image_compute.cc index dbe487ba91d00c2de4c08edf140526d727bac6b5..a99e588eccd79eb35a5e7c0f3da73471849ab581 100644 --- a/lite/kernels/opencl/activation_image_compute.cc +++ b/lite/kernels/opencl/activation_image_compute.cc @@ -37,11 +37,12 @@ class ActivationComputeImageDefault } void PrepareForRun() override { - auto& context = ctx_->As(); act_param_ = param_.get_mutable(); int act_type = static_cast(act_param_->active_type); +#ifndef LITE_SHUTDOWN_LOG VLOG(1) << "ActivationTypeToStr(act_param_->active_type):" << ActivationTypeToStr(act_param_->active_type); +#endif switch (act_type) { case 1: kernel_func_name_ = "relu"; @@ -71,41 +72,70 @@ class ActivationComputeImageDefault LOG(FATAL) << "This act type:" << act_type
<< " doesn't support."; return; } +#ifndef LITE_SHUTDOWN_LOG VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - context.cl_context()->AddKernel( - kernel_func_name_, "image/activation_kernel.cl", build_options_); - } - - void Run() override { - auto& param = *param_.get_mutable(); - const auto& x_dims = param.X->dims(); - auto* x_img = param.X->data(); - auto image_shape = InitImageDimInfoWith(x_dims); - auto* out_img = param.Out->mutable_data( - image_shape["width"], image_shape["height"]); - const auto& y_dims = param.Out->dims(); // useless: check dim only +#endif auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); + context.cl_context()->AddKernel(kernel_func_name_, + "image/activation_kernel.cl", + build_options_, + time_stamp_); + STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } - int arg_idx = 0; - cl_int status = kernel.setArg(arg_idx, *x_img); + void ReInitWhenNeeded() override { + act_param_ = param_.get_mutable(); + auto x_dims = act_param_->X->dims(); + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + x_img_shape_ = default_convertor.InitImageDimInfoWith( + act_param_->X->dims()); // w, h + out_img_shape_ = default_convertor.InitImageDimInfoWith( + act_param_->Out->dims()); // w, h + + // compute global work size + GetGlobalWorkSize(); + } + } + + void GetGlobalWorkSize() { + global_work_size_ = + cl::NDRange{static_cast(x_img_shape_[0]), + static_cast(x_img_shape_[1])}; + } + + void Run() override { + auto* x_img = act_param_->X->data(); + auto* out_img = act_param_->Out->mutable_data( + out_img_shape_[0], 
out_img_shape_[1]); + + auto kernel = kernel_; + cl_int status; + status = kernel.setArg(0, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_img); + status = kernel.setArg(1, *out_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, threshold_); + status = kernel.setArg(2, threshold_); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, scale_); + status = kernel.setArg(3, scale_); CL_CHECK_FATAL(status); #ifndef LITE_SHUTDOWN_LOG - VLOG(4) << TargetToStr(param.X->target()); - VLOG(4) << TargetToStr(param.Out->target()); - VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " - << image_shape["height"]; + const auto& x_dims = act_param_->X->dims(); + const auto& y_dims = act_param_->Out->dims(); // useless: check dim only + VLOG(4) << TargetToStr(act_param_->X->target()); + VLOG(4) << TargetToStr(act_param_->Out->target()); + VLOG(4) << "x_img_shape_(w,h):" << x_img_shape_[0] << " " + << x_img_shape_[1]; VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " @@ -115,13 +145,12 @@ class ActivationComputeImageDefault VLOG(4) << "kernel func name:" << kernel_func_name_; #endif - auto global_work_size = - cl::NDRange{static_cast(image_shape["width"]), - static_cast(image_shape["height"])}; + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, - global_work_size, + global_work_size_, cl::NullRange, nullptr, event_.get()); @@ -131,10 +160,20 @@ class ActivationComputeImageDefault private: param_t* act_param_{nullptr}; + DDim x_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + DDim last_x_dims_; std::string kernel_func_name_{}; float threshold_{6.f}; float scale_{1.f}; + 
cl::Kernel kernel_; + bool first_epoch_for_reinit_{true}; + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; } // namespace opencl diff --git a/lite/kernels/opencl/bilinear_interp_image_compute.cc b/lite/kernels/opencl/bilinear_interp_image_compute.cc index 7e32010c0b5ff5cedad8b0da7ce7233fbf73da6f..53f260789e12a94dc39f785df12a8e988d08bcbe 100644 --- a/lite/kernels/opencl/bilinear_interp_image_compute.cc +++ b/lite/kernels/opencl/bilinear_interp_image_compute.cc @@ -43,8 +43,10 @@ class BilinearInterpImageCompute bilinear_interp_param_ = param_.get_mutable(); auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/bilinear_interp_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/bilinear_interp_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -103,7 +105,7 @@ class BilinearInterpImageCompute #endif STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int arg_idx = 0; @@ -159,6 +161,7 @@ class BilinearInterpImageCompute param_t* bilinear_interp_param_{nullptr}; std::string kernel_func_name_{"bilinear_interp"}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/concat_buffer_compute.cc b/lite/kernels/opencl/concat_buffer_compute.cc index 010e7726170ab1f40adc2fcb56a66835ac7d2bd2..414f62ff0c4f86f29756b933817de2a7682ecd4c 100644 --- a/lite/kernels/opencl/concat_buffer_compute.cc +++ b/lite/kernels/opencl/concat_buffer_compute.cc @@ -38,8 +38,10 @@ class ConcatCompute : public 
KernelLiteAddKernel( - kernel_func_name_, "buffer/concat_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/concat_kernel.cl", + build_options_, + time_stamp_); auto axis = concat_param_->axis; auto inputs = concat_param_->x; @@ -88,7 +90,7 @@ class ConcatCompute : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto inputs = param.x; int arg_idx = 0; @@ -177,6 +179,7 @@ class ConcatCompute : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/concat_image_compute.cc b/lite/kernels/opencl/concat_image_compute.cc index 95e64025662a4b87cd68c211ccc0b0fb7b84a9f2..60d1ac628ab1474d7e82f1861067bca838548569 100644 --- a/lite/kernels/opencl/concat_image_compute.cc +++ b/lite/kernels/opencl/concat_image_compute.cc @@ -40,8 +40,10 @@ class ConcatComputeImage : public KernelLiteAddKernel( - kernel_func_name_, "image/concat_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/concat_kernel.cl", + build_options_, + time_stamp_); auto axis = concat_param_->axis; auto inputs = concat_param_->x; @@ -117,7 +119,7 @@ class ConcatComputeImage : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto inputs = param.x; int arg_idx = 0; @@ -251,6 +253,7 @@ class ConcatComputeImage : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/conv_buffer_compute.cc b/lite/kernels/opencl/conv_buffer_compute.cc index 65477e89c7d00408bf4d639138dea936a61a3d70..4c118e1263c0d3c23eb223b01b98a8d9a53bac0e 100644 --- a/lite/kernels/opencl/conv_buffer_compute.cc +++ b/lite/kernels/opencl/conv_buffer_compute.cc @@ -114,8 +114,10 @@ void 
ConvCompute::PrepareForRun() { } for (size_t i = 0; i < kernel_func_names_.size(); i++) { - context.cl_context()->AddKernel( - kernel_func_names_[i], kernel_func_paths_[i], build_options_[i]); + context.cl_context()->AddKernel(kernel_func_names_[i], + kernel_func_paths_[i], + build_options_[i], + time_stamp_); } } @@ -153,7 +155,7 @@ void ConvCompute::GemmlikeConv2d() { auto& context = ctx_->As(); std::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; + kernel_key << kernel_func_names_[0] << build_options_[0] << time_stamp_; auto img2col_kernel = context.cl_context()->GetKernel(kernel_key.str()); int n_threads = c_in * h_out * w_out; @@ -218,7 +220,7 @@ void ConvCompute::GemmlikeConv2d() { int n = h_out * w_out; VLOG(4) << "m = " << m << " n = " << n << " k = " << k; kernel_key.str(""); - kernel_key << kernel_func_names_[1] << build_options_[1]; + kernel_key << kernel_func_names_[1] << build_options_[1] << time_stamp_; auto gemm_kernel = context.cl_context()->GetKernel(kernel_key.str()); GemmBatched( gemm_kernel, col_buf, filter_buf, bias_buf, output_buf, bs, m, n, k); @@ -249,7 +251,8 @@ void ConvCompute::Conv2d1x1() { auto& context = ctx_->As(); std::stringstream kernel_key; - kernel_key << kernel_func_names_.front() << build_options_.front(); + kernel_key << kernel_func_names_.front() << build_options_.front() + << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); GemmBatched(kernel, x_d, filter_d, bias_d, output_d, batch_size, m, n, k); diff --git a/lite/kernels/opencl/conv_buffer_compute.h b/lite/kernels/opencl/conv_buffer_compute.h index 44ada55d92352edf3c64cd653e832b26718cdd2f..3dabe906f128ef96fb03dfa82ab3847febaeeed5 100644 --- a/lite/kernels/opencl/conv_buffer_compute.h +++ b/lite/kernels/opencl/conv_buffer_compute.h @@ -21,6 +21,7 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/tensor.h" +#include "lite/kernels/opencl/image_helper.h" #include 
"lite/operators/op_params.h" namespace paddle { @@ -55,6 +56,7 @@ class ConvCompute std::vector kernel_func_names_{}; std::vector kernel_func_paths_{}; std::vector build_options_{}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/conv_image_compute.cc b/lite/kernels/opencl/conv_image_compute.cc index a409690a2e37750c88395a592565b9e968e62845..aadd7010cca2ec03ea417e3b486d8c946d80fcab 100644 --- a/lite/kernels/opencl/conv_image_compute.cc +++ b/lite/kernels/opencl/conv_image_compute.cc @@ -369,15 +369,17 @@ void ConvImageCompute::PrepareForRun() { build_options_.push_back(build_options_single); for (size_t i = 0; i < kernel_func_names_.size(); i++) { - context.cl_context()->AddKernel( - kernel_func_names_[i], kernel_func_paths_[i], build_options_[i]); + context.cl_context()->AddKernel(kernel_func_names_[i], + kernel_func_paths_[i], + build_options_[i], + time_stamp_); } VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," << global_work_size_[1] << "," << global_work_size_[2] << "}"; std::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; + kernel_key << kernel_func_names_[0] << build_options_[0] << time_stamp_; kernel_ = context.cl_context()->GetKernel(kernel_key.str()); VLOG(4) << "kernel_key: " << kernel_key.str(); VLOG(4) << "kernel ready ... 
" << kernel_key.str(); @@ -388,18 +390,43 @@ void ConvImageCompute::PrepareForRun() { VLOG(4) << "max_work_group_size: " << max_work_group_size; - if (max_work_group_size > 0 && use_lws) { - // local_work_size_ = context.cl_context()->LocalWorkSizeConv1x1( - // global_work_size_, max_work_group_size); - local_work_size_ = context.cl_context()->LocalWorkSize(global_work_size_, - max_work_group_size); - + if (max_work_group_size > 0 && use_lws_) { + double min_turn_time = DBL_MAX; + cl::NDRange best_local_work_size = context.cl_context()->LocalWorkSize( + global_work_size_, max_work_group_size); + cl::NDRange last_local_work_size = cl::NDRange{ + static_cast(0), static_cast(0), static_cast(0)}; + if (use_turn_) { + for (size_t i = 1; i < 15; i++) { + if (kernel_h == 1 && kernel_w == 1) { + // todo use diff logics + local_work_size_ = context.cl_context()->LocalWorkSizeTurn( + global_work_size_, max_work_group_size, i); + } else { + local_work_size_ = context.cl_context()->LocalWorkSizeTurn( + global_work_size_, max_work_group_size, i); + } + if (last_local_work_size[0] == local_work_size_[0] && + last_local_work_size[1] == local_work_size_[1] && + last_local_work_size[2] == local_work_size_[2]) { + // skiped turned lws + continue; + } + auto turn_time = this->Turn(5); + if (min_turn_time > turn_time) { + min_turn_time = turn_time; + best_local_work_size = local_work_size_; + } + last_local_work_size = local_work_size_; + } + } + local_work_size_ = best_local_work_size; VLOG(4) << "local_work_size_[3D]: {" << local_work_size_[0] << "," << local_work_size_[1] << "," << local_work_size_[2] << "}"; } } -void ConvImageCompute::Conv2d1x1opt() { +void ConvImageCompute::Conv2d1x1opt(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -431,23 +458,6 @@ void ConvImageCompute::Conv2d1x1opt() { int input_c = input_dims[1]; auto dilations = *param.dilations; -// const std::vector& default_work_size = 
-// DefaultWorkSize(output_dims, -// DDim(std::vector{ -// static_cast(out_image_shape["width"]), -// static_cast(out_image_shape["height"])})); - -// int c_block = default_work_size[0]; -// int w = default_work_size[1]; -// int nh = default_work_size[2]; - -// int maped_w = maptofactor(w, 4); - -// auto global_work_size_ = -// cl::NDRange{static_cast(default_work_size.data()[0]), -// static_cast(maped_w), -// static_cast(default_work_size.data()[2])}; - #ifndef LITE_SHUTDOWN_LOG // VLOG(4) << "out_image: " << out_image; VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," @@ -541,73 +551,12 @@ void ConvImageCompute::Conv2d1x1opt() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); - -#ifdef PROFILE_CONV_KERNEL - bool use_profile = false; - auto GetCurrentUS = []() -> double { - struct timeval time; - gettimeofday(&time, NULL); - return 1e+6 * time.tv_sec + time.tv_usec; - }; - double start = GetCurrentUS(); - - if (use_profile) { - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); - } else { - int count = 50; - double sumtime = 0; - if (!use_profile) { - count = 1; - } - for (size_t i = 0; i < count; i++) { - start = GetCurrentUS(); - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); - if (use_profile) { - event_->wait(); - double duration = GetCurrentUS() - start; - sumtime += duration; - } - } - - auto dims_string = [](DDimLite dims) -> std::string { - std::ostringstream stream; - stream << "[" << dims[0] << "," << dims[1] << "," << dims[2] << "," - << dims[3] << "]"; - return stream.str(); - }; - if (use_profile) { - 
// LOG(INFO) << "input: " << input_dims; - // LOG(INFO) << "filter: " << filter_dims; - // LOG(INFO) << "output: " << output_dims; - - std::cout << std::setw(25) << std::left << dims_string(input_dims) - << std::setw(25) << std::left << dims_string(filter_dims) - << std::setw(25) << std::left << dims_string(output_dims) - << std::setw(25) << std::left << sumtime / count << std::endl; - } else { - dims_string(input_dims); - } + if (is_turn) { + event_->wait(); } -#endif } -void ConvImageCompute::Conv2d3x3() { +void ConvImageCompute::Conv2d3x3(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -767,9 +716,13 @@ void ConvImageCompute::Conv2d3x3() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); + + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::Conv2d3x3opt() { +void ConvImageCompute::Conv2d3x3opt(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -890,9 +843,12 @@ void ConvImageCompute::Conv2d3x3opt() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::Conv2d5x5() { +void ConvImageCompute::Conv2d5x5(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1018,9 +974,12 @@ void ConvImageCompute::Conv2d5x5() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::Conv2d5x5opt() { +void ConvImageCompute::Conv2d5x5opt(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1134,9 +1093,12 @@ void ConvImageCompute::Conv2d5x5opt() { event_.get()); CL_CHECK_FATAL(status); 
context.cl_wait_list()->emplace(out_image, event_); + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::Conv2d7x7() { +void ConvImageCompute::Conv2d7x7(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1262,8 +1224,12 @@ void ConvImageCompute::Conv2d7x7() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); + + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::Conv2d7x7opt() { +void ConvImageCompute::Conv2d7x7opt(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1374,8 +1340,12 @@ void ConvImageCompute::Conv2d7x7opt() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(out_image, event_); + + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::DepthwiseConv2d3x3s1() { +void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1454,9 +1424,13 @@ void ConvImageCompute::DepthwiseConv2d3x3s1() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(output_img, event_); + + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::DepthwiseConv2d3x3() { +void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1548,9 +1522,13 @@ void ConvImageCompute::DepthwiseConv2d3x3() { event_.get()); CL_CHECK_FATAL(status); context.cl_wait_list()->emplace(output_img, event_); + + if (is_turn) { + event_->wait(); + } } -void ConvImageCompute::DepthwiseConv2d() { +void ConvImageCompute::DepthwiseConv2d(bool is_turn) { auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); const auto& param = *param_.get_mutable(); @@ -1683,8 +1661,22 @@ void 
ConvImageCompute::DepthwiseConv2d() { context.cl_wait_list()->emplace(out_image, event_); } -void ConvImageCompute::Run() { (this->*impl_)(); } -#undef PROFILE_CONV_KERNEL +void ConvImageCompute::Run() { (this->*impl_)(false); } + +double ConvImageCompute::Turn(int times) { + auto GetCurrentUS = []() -> double { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; + }; + auto start = GetCurrentUS(); + for (size_t i = 0; i < times; i++) { + (this->*impl_)(true); + } + auto time_diff = (GetCurrentUS() - start) / times; + return time_diff; +} + } // namespace opencl } // namespace kernels } // namespace lite diff --git a/lite/kernels/opencl/conv_image_compute.h b/lite/kernels/opencl/conv_image_compute.h index c30c271498737acf3b831d7799af1b5b316e95de..6f293a0d7dd90e55bedd63c214ba38799a591080 100644 --- a/lite/kernels/opencl/conv_image_compute.h +++ b/lite/kernels/opencl/conv_image_compute.h @@ -22,40 +22,42 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/tensor.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" namespace paddle { namespace lite { namespace kernels { namespace opencl { - class ConvImageCompute : public KernelLite { public: using param_t = operators::ConvParam; - using kernel_t = void (ConvImageCompute::*)(); + using kernel_t = void (ConvImageCompute::*)(bool); void PrepareForRun() override; void Run() override; + double Turn(int times = 5); private: - void Conv2d1x1opt(); - void Conv2d3x3(); - void Conv2d3x3opt(); - void Conv2d5x5(); - void Conv2d5x5opt(); - void Conv2d7x7(); - void Conv2d7x7opt(); - void DepthwiseConv2d3x3s1(); - void DepthwiseConv2d3x3(); - void DepthwiseConv2d(); + void Conv2d1x1opt(bool is_turn = false); + void Conv2d3x3(bool is_turn = false); + void Conv2d3x3opt(bool is_turn = false); + void Conv2d5x5(bool is_turn = false); + void Conv2d5x5opt(bool is_turn = false); + void Conv2d7x7(bool is_turn = 
false); + void Conv2d7x7opt(bool is_turn = false); + void DepthwiseConv2d3x3s1(bool is_turn = false); + void DepthwiseConv2d3x3(bool is_turn = false); + void DepthwiseConv2d(bool is_turn = false); kernel_t impl_; std::vector kernel_func_names_{}; std::vector kernel_func_paths_{}; std::vector build_options_{}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; Tensor filter_gpu_image_; Tensor bias_gpu_image_; @@ -72,7 +74,8 @@ class ConvImageCompute : public KernelLite(1), static_cast(1), static_cast(1)}; - bool use_lws{true}; + bool use_lws_{true}; + bool use_turn_{false}; }; } // namespace opencl diff --git a/lite/kernels/opencl/depthwise_conv2d_buffer_compute.cc b/lite/kernels/opencl/depthwise_conv2d_buffer_compute.cc index 0c88509926041411eddac66bea08b5d3a08d6a3c..afe2aa1c66c04d2bdf180a77362e5d6f1271c1f6 100644 --- a/lite/kernels/opencl/depthwise_conv2d_buffer_compute.cc +++ b/lite/kernels/opencl/depthwise_conv2d_buffer_compute.cc @@ -44,8 +44,10 @@ class DepthwiseConv2dCompute build_options_ += " -DRELU6"; } auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/depthwise_conv2d_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/depthwise_conv2d_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -67,7 +69,7 @@ class DepthwiseConv2dCompute param.output->mutable_data(TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); cl_int status; @@ -120,6 +122,7 @@ class DepthwiseConv2dCompute private: std::string kernel_func_name_{"depthwise_conv2d"}; std::string build_options_{"-DCL_DTYPE_float"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/dropout_image_compute.cc 
b/lite/kernels/opencl/dropout_image_compute.cc index 490e34a8868a3f625591a1c621aa297bb0639576..2be5af2ef0bf3e30d1c586d57ed6c3d40d625b14 100644 --- a/lite/kernels/opencl/dropout_image_compute.cc +++ b/lite/kernels/opencl/dropout_image_compute.cc @@ -40,8 +40,10 @@ class DropoutComputeImage2D : public KernelLiteAs(); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - context.cl_context()->AddKernel( - kernel_func_name_, "image/dropout_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/dropout_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -63,7 +65,7 @@ class DropoutComputeImage2D : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); cl_int status; @@ -101,6 +103,7 @@ class DropoutComputeImage2D : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/elementwise_add_buffer_compute.cc b/lite/kernels/opencl/elementwise_add_buffer_compute.cc index 3961ac7583917fdcd761614558c493e6917d3294..b70f7d1ee017566e399ac86d35df56bd4ba4d383 100644 --- a/lite/kernels/opencl/elementwise_add_buffer_compute.cc +++ b/lite/kernels/opencl/elementwise_add_buffer_compute.cc @@ -25,8 +25,10 @@ namespace opencl { void ElementwiseAddCompute::PrepareForRun() { auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/elementwise_add_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/elementwise_add_kernel.cl", + build_options_, + time_stamp_); ele_param_ = param_.get_mutable(); UpdateParams(); } @@ -39,7 +41,7 @@ void ElementwiseAddCompute::Run() { auto* out_buf = ele_param_->Out->template mutable_data( TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + 
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); #ifndef LITE_SHUTDOWN_LOG VLOG(4) << TargetToStr(ele_param_->X->target()); diff --git a/lite/kernels/opencl/elementwise_add_buffer_compute.h b/lite/kernels/opencl/elementwise_add_buffer_compute.h index 5a9266ee69b81416d5f4dea9a3eb38aaed7b4165..7dbe5d0e8d5172386418d547812bf4e6c269f043 100644 --- a/lite/kernels/opencl/elementwise_add_buffer_compute.h +++ b/lite/kernels/opencl/elementwise_add_buffer_compute.h @@ -16,6 +16,7 @@ #include #include #include "lite/core/kernel.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" @@ -46,6 +47,7 @@ class ElementwiseAddCompute param_t* ele_param_{nullptr}; std::string kernel_func_name_{"elementwise_add"}; std::string build_options_{"-DCL_DTYPE_float"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/elementwise_add_image_compute.cc b/lite/kernels/opencl/elementwise_add_image_compute.cc index 6d0ebf638f0a8967e27a657131e1cac89967ee0b..51d488d51b72dd9af8225b45a7ee56063312d055 100644 --- a/lite/kernels/opencl/elementwise_add_image_compute.cc +++ b/lite/kernels/opencl/elementwise_add_image_compute.cc @@ -23,44 +23,84 @@ namespace lite { namespace kernels { namespace opencl { -void ElementwiseAddImageCompute::PrepareForRun() { - ele_param_ = param_.get_mutable(); - auto* x = ele_param_->X; - auto* y = ele_param_->Y; - auto axis = ele_param_->axis; +void ElementwiseAddImageCompute::PrepareForRun() {} - if (y->dims().size() == 4) { - kernel_func_name_ = "elementwise_add"; // y: ImageDefault - } else if (y->dims().size() == 1) { - if (axis == x->dims().size() - 1) { - kernel_func_name_ = "width_add"; // y: ImageDefault - } else if (axis == x->dims().size() - 3) { - kernel_func_name_ = "channel_add"; // y: ImageFolder +void ElementwiseAddImageCompute::ReInitWhenNeeded() 
{ + ele_param_ = param_.get_mutable(); + auto x_dims = ele_param_->X->dims(); + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // choose kernel + auto* x = ele_param_->X; + auto* y = ele_param_->Y; + auto* out = ele_param_->Out; + auto axis = ele_param_->axis; + + if (y->dims().size() == 4) { + kernel_func_name_ = "elementwise_add"; // y: ImageDefault + } else if (y->dims().size() == 1) { + if (axis == x->dims().size() - 1) { + kernel_func_name_ = "width_add"; // y: ImageDefault + } else if (axis == x->dims().size() - 3) { + kernel_func_name_ = "channel_add"; // y: ImageFolder + } else { + LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis + << ", x->dims().size():" << x->dims().size() + << ", y->dims.size():" << y->dims().size(); + } } else { LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis << ", x->dims().size():" << x->dims().size() << ", y->dims.size():" << y->dims().size(); } - } else { - LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis - << ", x->dims().size():" << x->dims().size() - << ", y->dims.size():" << y->dims().size(); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/elementwise_add_kernel.cl", + build_options_, + time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + x_img_shape_ = default_convertor.InitImageDimInfoWith(x->dims()); // w, h + y_img_shape_ = default_convertor.InitImageDimInfoWith(y->dims()); + out_img_shape_ = + default_convertor.InitImageDimInfoWith(out->dims()); // w, h + + // compute global work size + GetGlobalWorkSize(); } - VLOG(1) << "kernel_func_name_:" << 
kernel_func_name_; +} - auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/elementwise_add_kernel.cl", build_options_); +void ElementwiseAddImageCompute::GetGlobalWorkSize() { + global_work_size_ = cl::NDRange{static_cast(x_img_shape_[0]), + static_cast(x_img_shape_[1])}; +#ifndef LITE_SHUTDOWN_LOG + VLOG(4) << "global_work_size:[2D]:" << x_img_shape_[0] << " " + << x_img_shape_[1]; +#endif } void ElementwiseAddImageCompute::Run() { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* x = ele_param_->X; auto* y = ele_param_->Y; auto* out = ele_param_->Out; auto axis = ele_param_->axis; + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + auto* x_img = x->data(); + auto* y_img = y->data(); + auto* out_img = out->mutable_data(out_img_shape_[0], + out_img_shape_[1]); #ifndef LITE_SHUTDOWN_LOG VLOG(4) << "x->target():" << TargetToStr(x->target()); @@ -70,75 +110,53 @@ void ElementwiseAddImageCompute::Run() { VLOG(4) << "y->dims():" << y->dims(); VLOG(4) << "out->dims():" << out->dims(); VLOG(4) << "axis:" << axis; -#endif - - paddle::lite::CLImageConverterDefault default_convertor; - auto x_img_shape = default_convertor.InitImageDimInfoWith(x->dims()); // w, h - auto x_img_width = x_img_shape[0]; - auto x_img_height = x_img_shape[1]; - auto out_img_shape = - default_convertor.InitImageDimInfoWith(out->dims()); // w, h - auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims()); - auto* x_img = x->data(); - auto* y_img = y->data(); - auto* out_img = out->mutable_data(out_img_shape[0], - out_img_shape[1]); - -#ifndef LITE_SHUTDOWN_LOG - VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height; - VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1]; - VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " " - << out_img_shape[1]; + VLOG(4) << "x_img_shape_[w,h]:" << x_img_shape_[0] << " " << x_img_shape_[1]; + VLOG(4) << "y_img_shape_[w,h]:" << 
y_img_shape_[0] << " " << y_img_shape_[1]; + VLOG(4) << "out_img_shape_[w,h]:" << out_img_shape_[0] << " " + << out_img_shape_[1]; #endif - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - int arg_idx = 0; - auto y_dims = y->dims(); + cl_int status; + auto kernel = kernel_; if (y_dims.size() == 4) { - cl_int status = kernel.setArg(arg_idx, *x_img); + status = kernel.setArg(0, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *y_img); + status = kernel.setArg(1, *y_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_img); + status = kernel.setArg(2, *out_img); CL_CHECK_FATAL(status); } else if (y_dims.size() == 1) { - if (axis == x->dims().size() - 1 || axis == x->dims().size() - 3) { - int tensor_w = x->dims()[x->dims().size() - 1]; + if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) { + const int tensor_w = x_dims[x_dims.size() - 1]; #ifndef LITE_SHUTDOWN_LOG VLOG(4) << "tensor_w:" << tensor_w; #endif - cl_int status = kernel.setArg(arg_idx, *x_img); + status = kernel.setArg(0, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *y_img); + status = kernel.setArg(1, *y_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_img); + status = kernel.setArg(2, *out_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(tensor_w)); + status = kernel.setArg(3, tensor_w); CL_CHECK_FATAL(status); } else { LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis - << ", x->dims().size():" << x->dims().size() - << ", y->dims.size():" << y->dims().size(); + << ", x->dims().size():" << x_dims.size() + << ", y->dims.size():" << y_dims.size(); } } else { LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis - << ", x->dims().size():" << x->dims().size() - << ", y->dims.size():" << y->dims().size(); + << ", x->dims().size():" << x_dims.size() + << ", 
y->dims.size():" << y_dims.size(); } - auto global_work_size = cl::NDRange{static_cast(x_img_width), - static_cast(x_img_height)}; -#ifndef LITE_SHUTDOWN_LOG - VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height; -#endif - auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, - global_work_size, + global_work_size_, cl::NullRange, nullptr, event_.get()); diff --git a/lite/kernels/opencl/elementwise_add_image_compute.h b/lite/kernels/opencl/elementwise_add_image_compute.h index 084f0fe7fb18f9abe3c6ef41f10a9e38e31a54fc..a92a1b448176628381a3c65b838f6bba529eb4e0 100644 --- a/lite/kernels/opencl/elementwise_add_image_compute.h +++ b/lite/kernels/opencl/elementwise_add_image_compute.h @@ -15,8 +15,10 @@ #include #include +#include #include "lite/backends/opencl/cl_half.h" #include "lite/core/kernel.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" @@ -34,6 +36,10 @@ class ElementwiseAddImageCompute void PrepareForRun() override; + void ReInitWhenNeeded() override; + + void GetGlobalWorkSize(); + void Run() override; std::string doc() const override { @@ -42,8 +48,21 @@ class ElementwiseAddImageCompute protected: param_t* ele_param_{nullptr}; + DDim last_x_dims_; + DDim x_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + DDim y_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + std::string kernel_func_name_{"elementwise_add"}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; + bool first_epoch_for_reinit_{true}; + cl::Kernel kernel_; + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; 
std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/elementwise_mul_image_compute.cc b/lite/kernels/opencl/elementwise_mul_image_compute.cc index aa6af2a29bfdedfb5fdd3114693514b6fad13a64..96dc2de1affba7c36be6c9c0e952b85be726fca8 100644 --- a/lite/kernels/opencl/elementwise_mul_image_compute.cc +++ b/lite/kernels/opencl/elementwise_mul_image_compute.cc @@ -71,8 +71,10 @@ class ElementwiseMulImageCompute VLOG(4) << "bias_dims.size():" << bias_dims.size(); auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/elementwise_mul_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/elementwise_mul_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -114,7 +116,7 @@ class ElementwiseMulImageCompute #endif STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto bias_dims = y->dims(); @@ -201,6 +203,7 @@ class ElementwiseMulImageCompute param_t* ele_param_{nullptr}; std::string kernel_func_name_{"elementwise_mul"}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/elementwise_sub_image_compute.cc b/lite/kernels/opencl/elementwise_sub_image_compute.cc index 0bc867d7f124582660b7a0a9a95d026d910fc2d3..b93167b99c064a2f9eb2256291adad99f3912baf 100644 --- a/lite/kernels/opencl/elementwise_sub_image_compute.cc +++ b/lite/kernels/opencl/elementwise_sub_image_compute.cc @@ -49,8 +49,10 @@ void ElementwiseSubImageCompute::PrepareForRun() { VLOG(1) << "kernel_func_name_:" << kernel_func_name_; auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/elementwise_sub_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + 
"image/elementwise_sub_kernel.cl", + build_options_, + time_stamp_); } void ElementwiseSubImageCompute::Run() { @@ -93,7 +95,7 @@ void ElementwiseSubImageCompute::Run() { #endif STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int arg_idx = 0; diff --git a/lite/kernels/opencl/elementwise_sub_image_compute.h b/lite/kernels/opencl/elementwise_sub_image_compute.h index 48386b083e5375f8943c04afb1da70a2ed207dbf..db3e1db9813bffd985a41abbac14e5c89e574397 100644 --- a/lite/kernels/opencl/elementwise_sub_image_compute.h +++ b/lite/kernels/opencl/elementwise_sub_image_compute.h @@ -17,6 +17,7 @@ #include #include "lite/backends/opencl/cl_half.h" #include "lite/core/kernel.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" @@ -44,6 +45,7 @@ class ElementwiseSubImageCompute param_t* ele_param_{nullptr}; std::string kernel_func_name_{"elementwise_sub"}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/fc_buffer_compute.cc b/lite/kernels/opencl/fc_buffer_compute.cc index dbdedd136ea6b8c6b06d02d4f6d893e4ea849e8a..0fb83db2fe76e27baf7a096395369cb92b995072 100644 --- a/lite/kernels/opencl/fc_buffer_compute.cc +++ b/lite/kernels/opencl/fc_buffer_compute.cc @@ -16,6 +16,7 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/replace_stl/stream.h" #include "lite/utils/string.h" @@ -30,74 +31,98 @@ class FcCompute public: using param_t = operators::FcParam; - void PrepareForRun() override { - const auto& param = *param_.get_mutable(); - const auto x_dims = param.input->dims(); 
- const auto w_dims = param.w->dims(); - - CHECK_GE(x_dims.size(), 2UL); - CHECK_GE(w_dims.size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); - - m_ = x_dims.Slice(0, param.in_num_col_dims).production(); - k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); - n_ = w_dims[1]; - CHECK_EQ(k_, static_cast(w_dims[0])); - VLOG(4) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2] - << " " << x_dims[3]; - VLOG(4) << "w_dims:" << w_dims[0] << " " << w_dims[1] << " " << w_dims[2] - << " " << w_dims[3]; - VLOG(4) << "m_: " << m_ << " n_: " << n_ << " k_: " << k_; + void PrepareForRun() override {} + void ReInitWhenNeeded() override { + fc_param_ = param_.get_mutable(); + const auto x_dims = fc_param_->input->dims(); + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute m,n,k + const auto w_dims = fc_param_->w->dims(); + CHECK_GE(x_dims.size(), 2UL); + CHECK_GE(w_dims.size(), 2UL); + CHECK_EQ(fc_param_->output->dims().size(), 2UL); + + m_ = x_dims.Slice(0, fc_param_->in_num_col_dims).production(); + k_ = x_dims.Slice(fc_param_->in_num_col_dims, x_dims.size()).production(); + n_ = w_dims[1]; + CHECK_EQ(k_, static_cast(w_dims[0])); + +#ifndef LITE_SHUTDOWN_LOG + VLOG(4) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2] + << " " << x_dims[3]; + VLOG(4) << "w_dims:" << w_dims[0] << " " << w_dims[1] << " " << w_dims[2] + << " " << w_dims[3]; + VLOG(4) << "m_: " << m_ << " n_: " << n_ << " k_: " << k_; +#endif + + // choose kernel + if (m_ == 1) { // gemv + kernel_func_name_ = "fc_gemv_1x4"; + } else { // gemm + kernel_func_name_ = "fc_gemm_4x4"; + } +#ifndef LITE_SHUTDOWN_LOG + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; +#endif + + if (fc_param_->activation_type == "relu") { + build_options_ += "-DRELU"; + } + + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + 
"buffer/fc_kernel.cl", + build_options_, + time_stamp_); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + + // compute global work size + GetGlobalWorkSize(); + } + } + + void GetGlobalWorkSize() { if (m_ == 1) { // gemv - kernel_func_name_ = "fc_gemv_1x4"; global_work_size_ = cl::NDRange{static_cast((n_ + 3) / 4)}; } else { // gemm - kernel_func_name_ = "fc_gemm_4x4"; global_work_size_ = cl::NDRange{static_cast((m_ + 3) / 4), static_cast((n_ + 3) / 4)}; } - VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - - if (param.activation_type == "relu") { - build_options_ += "-DRELU"; - } - auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/fc_kernel.cl", build_options_); } void Run() override { - const auto& param = *param_.get_mutable(); - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* x_buf = param.input->data(); - auto* w_buf = param.w->data(); - auto* bias_buf = param.bias->data(); + auto* x_buf = fc_param_->input->data(); + auto* w_buf = fc_param_->w->data(); + auto* bias_buf = fc_param_->bias->data(); auto* out_buf = - param.output->mutable_data(TARGET(kOpenCL)); - - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + fc_param_->output->mutable_data(TARGET(kOpenCL)); + auto kernel = kernel_; cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, *x_buf); + status = kernel.setArg(0, *x_buf); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *w_buf); + status = kernel.setArg(1, *w_buf); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *bias_buf); + status = kernel.setArg(2, *bias_buf); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_buf); + status = kernel.setArg(3, *out_buf); CL_CHECK_FATAL(status); - status = 
kernel.setArg(++arg_idx, static_cast(m_)); + status = kernel.setArg(4, static_cast(m_)); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(n_)); + status = kernel.setArg(5, static_cast(n_)); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(k_)); + status = kernel.setArg(6, static_cast(k_)); CL_CHECK_FATAL(status); + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, @@ -111,9 +136,14 @@ class FcCompute private: int m_, n_, k_; + param_t* fc_param_{nullptr}; std::string kernel_func_name_{}; std::string build_options_{"-DCL_DTYPE_float "}; + std::string time_stamp_{GetTimeStamp()}; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; cl::NDRange global_work_size_; + cl::Kernel kernel_; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/fusion_elementwise_add_activation_buffer_compute.cc b/lite/kernels/opencl/fusion_elementwise_add_activation_buffer_compute.cc index d76e00fa85d4ebb6da9d779e9c2b220a2fd731d9..730b70525e818512aea11e1f42c1282b125aae54 100644 --- a/lite/kernels/opencl/fusion_elementwise_add_activation_buffer_compute.cc +++ b/lite/kernels/opencl/fusion_elementwise_add_activation_buffer_compute.cc @@ -28,8 +28,10 @@ class FusionElementwiseAddActivationCompute : public ElementwiseAddCompute { void PrepareForRun() override { build_options_ += " -DRELU"; auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/elementwise_add_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/elementwise_add_kernel.cl", + build_options_, + time_stamp_); ele_param_ = param_.get_mutable(); UpdateParams(); auto act_t = static_cast(ele_param_)->act_type; diff --git a/lite/kernels/opencl/fusion_elementwise_add_activation_image_compute.cc b/lite/kernels/opencl/fusion_elementwise_add_activation_image_compute.cc index 
e5c0e29bddf5cd6c25ccf98f05aa7cb091a4be7e..8e687340943dcb0f1b68e4c9495cbab1ad703645 100644 --- a/lite/kernels/opencl/fusion_elementwise_add_activation_image_compute.cc +++ b/lite/kernels/opencl/fusion_elementwise_add_activation_image_compute.cc @@ -16,6 +16,7 @@ #include "lite/backends/opencl/cl_include.h" #include "lite/core/op_registry.h" #include "lite/kernels/opencl/elementwise_add_image_compute.h" +#include "lite/kernels/opencl/image_helper.h" namespace paddle { namespace lite { @@ -30,8 +31,10 @@ class FusionElementwiseAddActivationImageCompute void PrepareForRun() override { build_options_ += " -DRELU"; auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/elementwise_add_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/elementwise_add_kernel.cl", + build_options_, + time_stamp_); ele_param_ = param_.get_mutable(); auto act_t = static_cast(ele_param_)->act_type; VLOG(4) << "act: " << act_t; diff --git a/lite/kernels/opencl/grid_sampler_image_compute.cc b/lite/kernels/opencl/grid_sampler_image_compute.cc index 243737a81331a7159834d30ccfb2fab181baeebe..4fb13a61181ba282f7005ea158768ee18b94b7a0 100644 --- a/lite/kernels/opencl/grid_sampler_image_compute.cc +++ b/lite/kernels/opencl/grid_sampler_image_compute.cc @@ -39,96 +39,120 @@ class GridSamplerImageCompute : public KernelLiteAs(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/grid_sampler_kernel.cl", + build_options_, + time_stamp_); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + VLOG(4) << "kernel_key: " << kernel_key.str(); + } + + void ReInitWhenNeeded() override { grid_param_ = param_.get_mutable(); + auto x_dims = grid_param_->x->dims(); + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + 
last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + out_img_shape_ = + default_convertor.InitImageDimInfoWith(grid_param_->out->dims()); + + // compute global work size + GetGlobalWorkSize(); + } + } - auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_); - VLOG(4) << "kernel_func_name_:" << kernel_func_name_; + void GetGlobalWorkSize() { + auto default_work_size = + DefaultWorkSize(grid_param_->out->dims(), + DDim(std::vector{ + static_cast(out_img_shape_[0]), + static_cast(out_img_shape_[1])})); + global_work_size_ = + cl::NDRange{static_cast(default_work_size[0]), + static_cast(default_work_size[1]), + static_cast(default_work_size[2] / 4)}; +#ifndef LITE_SHUTDOWN_LOG + VLOG(4) << "default_work_size: " << default_work_size[0] << ", " + << default_work_size[1] << ", " << default_work_size[2]; + VLOG(4) << "global_work_size_:[2D]:" << global_work_size_[0] << " " + << global_work_size_[1] << " " << global_work_size_[2]; +#endif } void Run() override { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* x = grid_param_->x; - auto* out = grid_param_->out; auto* grid = grid_param_->grid; + auto* out = grid_param_->out; + auto out_dims = out->dims(); - auto in_dims = x->dims(); + int out_height = out_dims[2]; + int out_width = out_dims[3]; + + auto* x_img = x->data(); + auto* grid_img = x->data(); + auto* out_img = out->mutable_data(out_img_shape_[0], + out_img_shape_[1]); #ifndef LITE_SHUTDOWN_LOG + auto in_dims = x->dims(); VLOG(4) << "x->target():" << TargetToStr(x->target()); VLOG(4) << "out->target():" << TargetToStr(out->target()); VLOG(4) << "x->dims():" << in_dims; VLOG(4) << "out->dims():" << out_dims; -#endif - - auto out_image_shape = InitImageDimInfoWith(out_dims); - auto* x_img = x->data(); // VLOG(4) << "x_image: " << x_img; - - auto* grid_img = x->data(); 
// VLOG(4) << "grid_img: " << grid_img; - - auto* out_img = out->mutable_data( - out_image_shape["width"], out_image_shape["height"]); -#ifndef LITE_SHUTDOWN_LOG // VLOG(4) << "out_image" << out_img; - VLOG(4) << "out_image_shape[w,h]:" << out_image_shape["width"] << " " - << out_image_shape["height"]; + VLOG(4) << "out_img_shape_[w,h]:" << out_img_shape_[0] << " " + << out_img_shape_[1]; #endif - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - int arg_idx = 0; - int out_height = out_dims[2]; - int out_width = out_dims[3]; - auto default_work_size = - DefaultWorkSize(out_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); -#ifndef LITE_SHUTDOWN_LOG - VLOG(4) << "default_work_size: " << default_work_size[0] << ", " - << default_work_size[1] << ", " << default_work_size[2]; -#endif - cl_int status = kernel.setArg(arg_idx++, *x_img); + cl_int status; + auto kernel = kernel_; + status = kernel.setArg(0, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(arg_idx++, *grid_img); + status = kernel.setArg(1, *grid_img); CL_CHECK_FATAL(status); - status = kernel.setArg(arg_idx++, *out_img); + status = kernel.setArg(2, *out_img); CL_CHECK_FATAL(status); - status = kernel.setArg(arg_idx++, out_height); + status = kernel.setArg(3, out_height); CL_CHECK_FATAL(status); - status = kernel.setArg(arg_idx++, out_width); + status = kernel.setArg(4, out_width); CL_CHECK_FATAL(status); - auto global_work_size = - cl::NDRange{static_cast(default_work_size[0]), - static_cast(default_work_size[1]), - static_cast(default_work_size[2] / 4)}; - + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, - global_work_size, + global_work_size_, cl::NullRange, nullptr, event_.get()); CL_CHECK_FATAL(status); 
context.cl_wait_list()->emplace(out_img, event_); -#ifndef LITE_SHUTDOWN_LOG - VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " - << global_work_size[1] << " " << global_work_size[2]; -#endif } protected: param_t* grid_param_{nullptr}; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); std::string kernel_func_name_{"grid_sampler"}; + cl::Kernel kernel_; + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/image_helper.h b/lite/kernels/opencl/image_helper.h index d0d282250d1c5658bc8f684b52b4b0d140895833..81d38bc683eb355b1d85a307d35839b4e3e8ef45 100644 --- a/lite/kernels/opencl/image_helper.h +++ b/lite/kernels/opencl/image_helper.h @@ -74,6 +74,12 @@ static std::vector DefaultWorkSize(const DDim& image_dim, LOG(FATAL) << " not support this dim, need imp "; } +static const std::string GetTimeStamp() { + struct timeval time; + gettimeofday(&time, NULL); + return std::to_string(time.tv_usec); +} + } // namespace opencl } // namespace kernels } // namespace lite diff --git a/lite/kernels/opencl/instance_norm_image_compute.cc b/lite/kernels/opencl/instance_norm_image_compute.cc index 6bdec0ca6cdfd16219becf704de4d5701aad3197..c5e02ae0ed4ae9facf36747d99ee825e6eab6515 100644 --- a/lite/kernels/opencl/instance_norm_image_compute.cc +++ b/lite/kernels/opencl/instance_norm_image_compute.cc @@ -60,8 +60,10 @@ class InstanceNormImageCompute : public KernelLiteAs(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/instance_norm_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/instance_norm_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -115,7 +117,7 @@ class 
InstanceNormImageCompute : public KernelLiteGetKernel(kernel_key.str()); cl_int status = kernel.setArg(0, out_w); @@ -180,8 +182,10 @@ class InstanceNormImageCompute : public KernelLite( scale_img_size[0], scale_img_size[1], bias_img.data()); auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/instance_norm_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/instance_norm_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -234,7 +238,7 @@ class InstanceNormImageCompute : public KernelLiteGetKernel(kernel_key.str()); auto* scale_img = scale_image_.data(); auto* bias_img = bias_image_.data(); @@ -271,6 +275,7 @@ class InstanceNormImageCompute : public KernelLite event_{new cl::Event}; Tensor scale_image_; Tensor bias_image_; diff --git a/lite/kernels/opencl/lrn_image_compute.cc b/lite/kernels/opencl/lrn_image_compute.cc index edce0368ddc9cda54fdab44b472fcd0e771413ae..0e01bdc107c4fcb4a0caf943cfb1b768557dd671 100644 --- a/lite/kernels/opencl/lrn_image_compute.cc +++ b/lite/kernels/opencl/lrn_image_compute.cc @@ -48,7 +48,7 @@ class LrnImageCompute : public KernelLitebeta; norm_region_ = lrn_param_->norm_region; context.cl_context()->AddKernel( - kernel_func_name_, "image/lrn_kernel.cl", build_options_); + kernel_func_name_, "image/lrn_kernel.cl", build_options_, time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -91,7 +91,7 @@ class LrnImageCompute : public KernelLiteGetKernel(kernel_key.str()); int arg_idx = 0; @@ -152,6 +152,7 @@ class LrnImageCompute : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/mul_buffer_compute.cc b/lite/kernels/opencl/mul_buffer_compute.cc index 4c46da67da9877fb37b214b6d738b3dd3da3e5bb..e8edb359898fb47cf47919a25e521ca9f8353104 100644 --- a/lite/kernels/opencl/mul_buffer_compute.cc +++ b/lite/kernels/opencl/mul_buffer_compute.cc @@ -16,6 +16,7 @@ #include 
"lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" #include "lite/operators/op_params.h" #include "lite/utils/replace_stl/stream.h" #include "lite/utils/string.h" @@ -32,8 +33,10 @@ class MulCompute void PrepareForRun() override { auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/mat_mul_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "buffer/mat_mul_kernel.cl", + build_options_, + time_stamp_); const auto& param = *param_.get_mutable(); const auto* x_data = param.x->data(); const auto* y_data = param.y->data(); @@ -68,7 +71,7 @@ class MulCompute param.output->mutable_data(TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); cl_int status; @@ -103,6 +106,7 @@ class MulCompute int m_, n_, k_; std::string kernel_func_name_{"mat_mul"}; std::string build_options_{"-DCL_DTYPE_float"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/nearest_interp_image_compute.cc b/lite/kernels/opencl/nearest_interp_image_compute.cc index 082f21ab1ae792ae33e9e2a368073274258b8884..17637e2569556d1eeb8b6002c0073223345ac7ec 100644 --- a/lite/kernels/opencl/nearest_interp_image_compute.cc +++ b/lite/kernels/opencl/nearest_interp_image_compute.cc @@ -38,8 +38,10 @@ class NearestInterpComputeImageDefault void PrepareForRun() override { auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/nearest_interp_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/nearest_interp_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -66,7 +68,7 @@ class 
NearestInterpComputeImageDefault auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int arg_idx = 0; @@ -121,6 +123,7 @@ class NearestInterpComputeImageDefault private: std::string kernel_func_name_{"nearest_interp"}; std::string build_options_{" -DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/pad2d_image_compute.cc b/lite/kernels/opencl/pad2d_image_compute.cc index 1be4729ee1b24ac77383de4d7c111e9d37d29d6b..f16642d449d29c2afd3db7097432945c73d107e3 100644 --- a/lite/kernels/opencl/pad2d_image_compute.cc +++ b/lite/kernels/opencl/pad2d_image_compute.cc @@ -52,8 +52,10 @@ class Pad2dCompute : public KernelLiteAs(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/pad2d_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/pad2d_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; } @@ -93,7 +95,7 @@ class Pad2dCompute : public KernelLiteGetKernel(kernel_key.str()); int arg_idx = 0; @@ -159,6 +161,7 @@ class Pad2dCompute : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/pool_buffer_compute.cc b/lite/kernels/opencl/pool_buffer_compute.cc index 3f491afb86d4e4d5144522b6fb028c225c9a97e4..aeba4bcd2ea1d9b1f14ac86509ab9dbec2509ad0 100644 --- a/lite/kernels/opencl/pool_buffer_compute.cc +++ b/lite/kernels/opencl/pool_buffer_compute.cc @@ -37,8 +37,10 @@ class PoolCompute const auto& param = *param_.get_mutable(); kernel_func_name_ += param.pooling_type; auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "buffer/pool_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + 
"buffer/pool_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -69,7 +71,7 @@ class PoolCompute auto* output_buf = param.output->mutable_data(TARGET(kOpenCL)); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); cl_int status; auto numel = out_dims.production(); @@ -117,6 +119,7 @@ class PoolCompute private: std::string kernel_func_name_{"pool_"}; std::string build_options_{"-DCL_DTYPE_float"}; + std::string time_stamp_{GetTimeStamp()}; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/pool_image_compute.cc b/lite/kernels/opencl/pool_image_compute.cc index 39da325ebb10c85f153e349173aa833bbf5e1f6e..34524122c8e475df63db02eae32b7d100abfa2d9 100644 --- a/lite/kernels/opencl/pool_image_compute.cc +++ b/lite/kernels/opencl/pool_image_compute.cc @@ -47,7 +47,7 @@ class PoolComputeImage2D : public KernelLiteAs(); context.cl_context()->AddKernel( - kernel_func_name_, "image/pool_kernel.cl", build_options_); + kernel_func_name_, "image/pool_kernel.cl", build_options_, time_stamp_); } void Run() override { @@ -112,7 +112,7 @@ class PoolComputeImage2D : public KernelLiteGetKernel(kernel_key.str()); int c_block = (out_dims[1] + 3) / 4; @@ -164,6 +164,7 @@ class PoolComputeImage2D : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/reshape_image_compute.cc b/lite/kernels/opencl/reshape_image_compute.cc index 376add226216a57a0868c9c52497b784929a207e..febb1c33d9c4df2cb58580a03bda1eff93ed4da7 100644 --- a/lite/kernels/opencl/reshape_image_compute.cc +++ b/lite/kernels/opencl/reshape_image_compute.cc @@ -36,8 +36,10 @@ class ReshapeComputeFloatImage : public KernelLiteAs(); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - context.cl_context()->AddKernel( - kernel_func_name_, "image/reshape_kernel.cl", build_options_); + 
context.cl_context()->AddKernel(kernel_func_name_, + "image/reshape_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -110,7 +112,7 @@ class ReshapeComputeFloatImage : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); #ifndef LITE_SHUTDOWN_LOG @@ -166,6 +168,7 @@ class ReshapeComputeFloatImage : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/scale_image_compute.cc b/lite/kernels/opencl/scale_image_compute.cc index 5fd9a2b46b5ce3b0ad84449785f510d5f0391250..97b56e68d47fcdf1647433f5e267c264fb36c5c2 100644 --- a/lite/kernels/opencl/scale_image_compute.cc +++ b/lite/kernels/opencl/scale_image_compute.cc @@ -37,53 +37,66 @@ class ScaleComputeImage2D : public KernelLiteAs(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/scale_kernel.cl", + build_options_, + time_stamp_); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - context.cl_context()->AddKernel( - kernel_func_name_, "image/scale_kernel.cl", build_options_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + scale_param_ = param_.get_mutable(); + auto x_dims = scale_param_->x->dims(); + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + out_img_shape_ = + default_convertor.InitImageDimInfoWith(scale_param_->output->dims()); + + // compute global work size + GetGlobalWorkSize(); + } + } + + void GetGlobalWorkSize() { + global_work_size_ = + 
cl::NDRange{static_cast(out_img_shape_[0]), + static_cast(out_img_shape_[1])}; } void Run() override { - const auto& param = *param_.get_mutable(); - const auto& in_dims = param.x->dims(); - auto* x_img = param.x->data(); - const float scale = param.scale; - const float bias = param.bias; - - // LOG(INFO) << "x_image" << x_img; - auto out_image_shape = InitImageDimInfoWith(in_dims); -#ifndef LITE_SHUTDOWN_LOG - VLOG(4) << "out_image_shape = " << out_image_shape["width"] << " " - << out_image_shape["height"]; -#endif - auto* out_img = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - // LOG(INFO) << "out_image" << out_img; + auto* x_img = scale_param_->x->data(); + auto* out_img = scale_param_->output->mutable_data( + out_img_shape_[0], out_img_shape_[1]); + const float scale = scale_param_->scale; + const float bias = scale_param_->bias; auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - auto global_work_size = - cl::NDRange{static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])}; + auto kernel = kernel_; cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, *x_img); + status = kernel.setArg(0, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_img); + status = kernel.setArg(1, *out_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, scale); + status = kernel.setArg(2, scale); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, bias); + status = kernel.setArg(3, bias); CL_CHECK_FATAL(status); status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, - global_work_size, + global_work_size_, cl::NullRange, nullptr, event_.get()); @@ -94,7 +107,17 @@ class ScaleComputeImage2D : public KernelLite event_{new cl::Event}; + + param_t* 
scale_param_{nullptr}; + cl::Kernel kernel_; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; }; } // namespace opencl diff --git a/lite/kernels/opencl/slice_image_compute.cc b/lite/kernels/opencl/slice_image_compute.cc index 149ef35afe3d49ca8793769ee7ad366292462296..dd231ec8647ba88ab0f953661af47bc36c948e8b 100644 --- a/lite/kernels/opencl/slice_image_compute.cc +++ b/lite/kernels/opencl/slice_image_compute.cc @@ -38,8 +38,10 @@ class SliceComputeImage2D : public KernelLiteAs(); VLOG(1) << "kernel_func_name_:" << kernel_func_name_; - context.cl_context()->AddKernel( - kernel_func_name_, "image/slice_kernel.cl", build_options_); + context.cl_context()->AddKernel(kernel_func_name_, + "image/slice_kernel.cl", + build_options_, + time_stamp_); } void Run() override { @@ -68,7 +70,7 @@ class SliceComputeImage2D : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); cl_int status; @@ -108,6 +110,7 @@ class SliceComputeImage2D : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt index 29cb83b2b853d4953bfbe7faca8633f2789e1d50..93f3cdb445af7b75adc76294b287d9963f4e3cca 100644 --- a/lite/kernels/xpu/bridges/CMakeLists.txt +++ b/lite/kernels/xpu/bridges/CMakeLists.txt @@ -25,6 +25,7 @@ lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${x lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps}) 
lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_yolo_box_op_xpu SRCS yolo_box_op.cc DEPS ${xpu_subgraph_bridge_deps}) set(xpu_subgraph_bridges subgraph_bridge_registry @@ -48,6 +49,7 @@ set(xpu_subgraph_bridges subgraph_bridge_dropout_op_xpu subgraph_bridge_matmul_op_xpu subgraph_bridge_cast_op_xpu + subgraph_bridge_yolo_box_op_xpu CACHE INTERNAL "xpu_subgraph_bridges") message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}") diff --git a/lite/kernels/xpu/bridges/paddle_use_bridges.h b/lite/kernels/xpu/bridges/paddle_use_bridges.h index 0c7886c5b2b431db7ba97d8557fb6a49750bd468..cf896426f7a40ae17cd73547071f86dcfa738839 100644 --- a/lite/kernels/xpu/bridges/paddle_use_bridges.h +++ b/lite/kernels/xpu/bridges/paddle_use_bridges.h @@ -37,3 +37,4 @@ USE_SUBGRAPH_BRIDGE(gelu, kXPU); USE_SUBGRAPH_BRIDGE(dropout, kXPU); USE_SUBGRAPH_BRIDGE(matmul, kXPU); USE_SUBGRAPH_BRIDGE(cast, kXPU); +USE_SUBGRAPH_BRIDGE(yolo_box, kXPU); diff --git a/lite/kernels/xpu/bridges/yolo_box_op.cc b/lite/kernels/xpu/bridges/yolo_box_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1b7c014702aa8530d5b502bb6d32825e7bb13b2 --- /dev/null +++ b/lite/kernels/xpu/bridges/yolo_box_op.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/npu/bridges/registry.h" +#include "lite/kernels/xpu/bridges/graph.h" +#include "lite/kernels/xpu/bridges/utility.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace xpu { + +int YoloBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto scope = op->scope(); + VLOG(3) << "[XPU] Converting " + op_type + "..."; + + // Get input and output vars and op attributes + auto x_name = op_info->Input("X").front(); + auto x = scope->FindTensor(x_name); + + auto img_size_name = op_info->Input("ImgSize").front(); + auto img_size = scope->FindTensor(img_size_name); + + auto boxes_name = op_info->Output("Boxes").front(); + auto scores_name = op_info->Output("Scores").front(); + + auto anchors = op_info->GetAttr>("anchors"); + auto class_num = op_info->GetAttr("class_num"); + auto conf_thresh = op_info->GetAttr("conf_thresh"); + auto downsample_ratio = op_info->GetAttr("downsample_ratio"); + + // X node + std::shared_ptr x_node = nullptr; + if (graph->Has(x_name)) { + x_node = graph->Get(x_name); + } else { + x_node = graph->Add(x_name, *x); + } + + // ImgSize node + std::shared_ptr img_size_node = nullptr; + if (graph->Has(img_size_name)) { + img_size_node = graph->Get(img_size_name); + } else { + img_size_node = graph->Add(img_size_name, *img_size); + } + + // Softmax node + auto yolo_box_data = + graph->builder_.CreateYoloBox(*x_node->data(), + *img_size_node->data(), + CvtShape(anchors), + class_num, + conf_thresh, + downsample_ratio); + graph->Add(boxes_name, graph->builder_.GetField(yolo_box_data, 0)); + graph->Add(scores_name, graph->builder_.GetField(yolo_box_data, 1)); + + return SUCCESS; +} + +} // namespace xpu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(yolo_box, + kXPU, + 
paddle::lite::subgraph::xpu::YoloBoxConverter); diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 1b6d374f7396cf1e4e91bfe786603005fb0ff8dc..9c2191331c85a7f99ffb5a2e9662ed5831cb1dda 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -34,7 +34,7 @@ int SubgraphEngine::BuildDeviceProgram() { subgraph::xpu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); for (auto& inst : origin_program_) { - auto op = inst.op(); + auto op = const_cast(inst.op()); CHECK(op); op->CheckShape(); op->InferShape(); @@ -43,10 +43,8 @@ int SubgraphEngine::BuildDeviceProgram() { return subgraph::FAILED; } auto kernel = inst.kernel(); - status |= - bridges.Select(op_type, TARGET(kXPU))(reinterpret_cast(&graph), - const_cast(op), - const_cast(kernel)); + status |= bridges.Select(op_type, TARGET(kXPU))( + reinterpret_cast(&graph), op, const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { return subgraph::FAILED; } diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc index 08e6a303094dc42278bfcb24c54f16bd3819d5c1..7f938577c3b2e53257d4fb79686a0bf8c6a67ad5 100644 --- a/lite/model_parser/model_parser.cc +++ b/lite/model_parser/model_parser.cc @@ -382,7 +382,7 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) { pb_dims->Resize(static_cast(dims.size()), 0); auto dims_vec = dims.Vectorize(); std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin()); - int32_t size = desc.ByteSize(); + int32_t size = desc.ByteSizeLong(); os.write(reinterpret_cast(&size), sizeof(size)); auto out = desc.SerializeAsString(); os.write(out.data(), size); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 48e27560317c089446e8dbc5040786f34ca962c4..ae9ec3ad47fbc00c91ba06c1597bd65e510b629b 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -141,13 +141,12 @@ add_operator(lstm_op extra SRCS lstm_op.cc DEPS 
${op_DEPS}) # 4. training op add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) -if (LITE_WITH_TRAIN) - add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) - add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) - add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) - add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) - add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) -endif() + +add_operator(mean_grad_op train SRCS mean_grad_op.cc DEPS ${op_DEPS}) +add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEPS}) +add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) +add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS}) +add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS}) if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc diff --git a/lite/operators/activation_grad_ops.cc b/lite/operators/activation_grad_ops.cc index 9a37a5f0a178192ead00801632914a8f446f058f..b31163e5dce6d9b77d923ba44ed58952263610a5 100644 --- a/lite/operators/activation_grad_ops.cc +++ b/lite/operators/activation_grad_ops.cc @@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const { return true; } -bool ActivationGradOp::InferShape() const { +bool ActivationGradOp::InferShapeImpl() const { param_.X_grad->Resize(param_.Out_grad->dims()); return true; } diff --git a/lite/operators/activation_grad_ops.h b/lite/operators/activation_grad_ops.h index 5421b3247ff844e20931a6a15b85eb7da85e7f69..cf928cfe1bf9945a1dd0474408472759a499b5d7 100644 --- a/lite/operators/activation_grad_ops.h +++ b/lite/operators/activation_grad_ops.h @@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git 
a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index f7a326358bb30d747c949d7bacdebb47846562b5..13abe0c53e95363e7f54c56819eaac26ef720072 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const { return true; } -bool ActivationOp::InferShape() const { +bool ActivationOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); auto out_lod = param_.Out->mutable_lod(); *out_lod = param_.X->lod(); @@ -71,6 +71,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { } else if (opdesc.Type() == "exp") { // exp param_.active_type = lite_api::ActivationType::kExp; + } else if (opdesc.Type() == "abs") { + // abs + param_.active_type = lite_api::ActivationType::kAbs; } VLOG(4) << "opdesc.Type():" << opdesc.Type(); @@ -92,6 +95,7 @@ REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp); diff --git a/lite/operators/activation_ops.h b/lite/operators/activation_ops.h index 34099ab0fdb422f523e383dc0dd286acf24b2731..8f81b12af03052e558e7faa2e813039d4dee8988 100644 --- a/lite/operators/activation_ops.h +++ b/lite/operators/activation_ops.h @@ -26,7 +26,7 @@ class ActivationOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/affine_channel_op.cc b/lite/operators/affine_channel_op.cc index 
c4945ababd2fdf3b0f1b25d26eb0f66c8f613b21..447079deb33bdb893b99901d8559d6961489789d 100644 --- a/lite/operators/affine_channel_op.cc +++ b/lite/operators/affine_channel_op.cc @@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const { return true; } -bool AffineChannelOpLite::InferShape() const { +bool AffineChannelOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); param_.Out->Resize(x_dims); return true; diff --git a/lite/operators/affine_channel_op.h b/lite/operators/affine_channel_op.h index 85a043bdc8e1c6f41c27b2e57555d3454322f789..5a3d9d66259d477d42ac00e0e1b1a7ba1bf2e862 100644 --- a/lite/operators/affine_channel_op.h +++ b/lite/operators/affine_channel_op.h @@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/anchor_generator_op.cc b/lite/operators/anchor_generator_op.cc index 8daa54905fcf7cf52259840c26198721d6b8f0fa..e57a4b2df8c75afd28506b5e0e2f7b7aa142b838 100644 --- a/lite/operators/anchor_generator_op.cc +++ b/lite/operators/anchor_generator_op.cc @@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const { return true; } -bool AnchorGeneratorOpLite::InferShape() const { +bool AnchorGeneratorOpLite::InferShapeImpl() const { auto input_dims = param_.Input->dims(); size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size(); std::vector output_shape( diff --git a/lite/operators/anchor_generator_op.h b/lite/operators/anchor_generator_op.h index 46e5e0fac243c10b62122327ef06ea166878e54f..2ff3422824c15b54ed1fa3ca9952745d5b1706ac 100644 --- a/lite/operators/anchor_generator_op.h +++ b/lite/operators/anchor_generator_op.h @@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool 
AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/argmax_op.cc b/lite/operators/argmax_op.cc index 772cc446077e5e896b757051fae9f9b8f59df1d8..b733998ae57785483f539b56dcb47b7b50f04cf0 100644 --- a/lite/operators/argmax_op.cc +++ b/lite/operators/argmax_op.cc @@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const { return true; } -bool ArgmaxOpLite::InferShape() const { +bool ArgmaxOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); int x_rank = x_dims.size(); int axis = param_.Axis; diff --git a/lite/operators/argmax_op.h b/lite/operators/argmax_op.h index a5accc97e3b9f3bb2fbd00f45fd3a45063e5c747..e6944507cf9f6ded86ccbae7c3cec79106e8ba98 100644 --- a/lite/operators/argmax_op.h +++ b/lite/operators/argmax_op.h @@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/assign_op.cc b/lite/operators/assign_op.cc index 8510b7e8b7b8a5732e0e09d3db494ab3eb9f15a8..25e8539d2e55a07a19d707713489d86f84aa64db 100644 --- a/lite/operators/assign_op.cc +++ b/lite/operators/assign_op.cc @@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const { return true; } -bool AssignOpLite::InferShape() const { +bool AssignOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/assign_op.h b/lite/operators/assign_op.h index 555356c3659ff31c84b2630c1f5da6acab003823..9e7039bb5b0088a6bda6acbf2baf7a50444df8b2 100644 --- a/lite/operators/assign_op.h +++ b/lite/operators/assign_op.h @@ -30,7 +30,7 @@ class AssignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git 
a/lite/operators/assign_value_op.cc b/lite/operators/assign_value_op.cc index 046c5222283fc73bd3af1e53520b1fc5539bcd31..ff5b55735f7b58aa2eaa2274574336dadd8061e6 100644 --- a/lite/operators/assign_value_op.cc +++ b/lite/operators/assign_value_op.cc @@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const { return true; } -bool AssignValueOpLite::InferShape() const { +bool AssignValueOpLite::InferShapeImpl() const { std::vector shape = param_.shape; std::vector out_shape; for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); diff --git a/lite/operators/assign_value_op.h b/lite/operators/assign_value_op.h index 7bf220615935f02051ed606adb894bf9842378f3..030da048184c9862b76f59198574b394457768d5 100644 --- a/lite/operators/assign_value_op.h +++ b/lite/operators/assign_value_op.h @@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/attention_padding_mask_op.cc b/lite/operators/attention_padding_mask_op.cc index a88df0e7a902c6cac63eb77377bb0b49ee30c9b3..2f3a0cd265c56ac24548e23ff3daf09e27e1d800 100644 --- a/lite/operators/attention_padding_mask_op.cc +++ b/lite/operators/attention_padding_mask_op.cc @@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const { return true; } -bool AttentionPaddingMaskOp::InferShape() const { +bool AttentionPaddingMaskOp::InferShapeImpl() const { auto src_len = param_.X->lod()[0][1]; CHECK_EQ(src_len, param_.X->dims()[1]) << "Mismatch source length, expect: " << src_len diff --git a/lite/operators/attention_padding_mask_op.h b/lite/operators/attention_padding_mask_op.h index 894d68f6226720139aee07274d4ac5cf660749f1..6a2443fc6749d4f2066ee761fd194441e2fe46cd 100644 --- a/lite/operators/attention_padding_mask_op.h +++ b/lite/operators/attention_padding_mask_op.h @@ -29,7 +29,7 @@ class 
AttentionPaddingMaskOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/axpy_op.cc b/lite/operators/axpy_op.cc index 60f302862afa47ca75ae703e7b848bb3a0e7604c..c1c6304c3119f89bdc46400b2478a767c914d001 100644 --- a/lite/operators/axpy_op.cc +++ b/lite/operators/axpy_op.cc @@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const { return true; } -bool AxpyOpLite::InferShape() const { +bool AxpyOpLite::InferShapeImpl() const { auto dims = param_.Bias->dims(); // Set output dims diff --git a/lite/operators/axpy_op.h b/lite/operators/axpy_op.h index 1fa8540743f65db864f33633003b4ed8f6d8cb92..e9d9f44ca5f5843628af998d9140519a3f3a1c29 100644 --- a/lite/operators/axpy_op.h +++ b/lite/operators/axpy_op.h @@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index eca7fa6001dda7835213c60be1d21eedff301ae4..67e037fba349e811f1faf991c84310b11ab7a13c 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const { return true; } -bool BatchNormOp::InferShape() const { +bool BatchNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t channel_size = 0; switch (param_.data_layout) { diff --git a/lite/operators/batch_norm_op.h b/lite/operators/batch_norm_op.h index 21dbf9a28a4257acdd80ac6c49d111cdd757b65d..9598763713564192ed4ad0c99200f0fdb1d88d37 100644 --- a/lite/operators/batch_norm_op.h +++ b/lite/operators/batch_norm_op.h @@ -30,7 +30,7 @@ class BatchNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool 
InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_decode_op.cc b/lite/operators/beam_search_decode_op.cc index 52888d8a99c0f6507862f515c633f04d4fe09c39..444c9d6a11217c3134c3cb1f988c60c4b98d4566 100644 --- a/lite/operators/beam_search_decode_op.cc +++ b/lite/operators/beam_search_decode_op.cc @@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const { return true; } -bool BeamSearchDecodeOpLite::InferShape() const { return true; } +bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; } bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/beam_search_decode_op.h b/lite/operators/beam_search_decode_op.h index 9d324d2bf0974fe5b65711c4ab2dacaf0d0d65d9..38bf9929ab12ba764fcd3fe6cacc7c08f35c15ca 100644 --- a/lite/operators/beam_search_decode_op.h +++ b/lite/operators/beam_search_decode_op.h @@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/beam_search_op.cc b/lite/operators/beam_search_op.cc index c998e002ee3d6b8f3196fdfa212462dac4da0969..ea777ad53395aba1c7d6c21b07013e374b03c1f4 100644 --- a/lite/operators/beam_search_op.cc +++ b/lite/operators/beam_search_op.cc @@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const { return true; } -bool BeamSearchOp::InferShape() const { return true; } +bool BeamSearchOp::InferShapeImpl() const { return true; } bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front()); diff --git a/lite/operators/beam_search_op.h b/lite/operators/beam_search_op.h index 42a6058de112215f525b51bfff6ff16aae04391d..7e325cb55668a77cf09466e86be220218a49cbee 100644 --- 
a/lite/operators/beam_search_op.h +++ b/lite/operators/beam_search_op.h @@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_clip_op.cc b/lite/operators/box_clip_op.cc index 6bd93c6ea4e2efc93fdc7e64f1738c2ac3d40997..08ba49bd9ada076c6650249f67af15174491f634 100644 --- a/lite/operators/box_clip_op.cc +++ b/lite/operators/box_clip_op.cc @@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const { return true; } -bool BoxClipOpLite::InferShape() const { +bool BoxClipOpLite::InferShapeImpl() const { auto* input = param_.Input; auto* output = param_.Output; output->Resize(input->dims()); diff --git a/lite/operators/box_clip_op.h b/lite/operators/box_clip_op.h index c7e07b1015c52eb5711638163bda327c11152dd0..0aae2112ec8b91ba63205fadd4123bc3c5fce2fd 100644 --- a/lite/operators/box_clip_op.h +++ b/lite/operators/box_clip_op.h @@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/box_coder_op.cc b/lite/operators/box_coder_op.cc index c86f494fc4f96f688c30027f1d6aa1ee452da8f0..3133176b35ecae49ed9171ef6e8b519c6774ce5d 100644 --- a/lite/operators/box_coder_op.cc +++ b/lite/operators/box_coder_op.cc @@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const { return true; } -bool BoxCoderOpLite::InferShape() const { +bool BoxCoderOpLite::InferShapeImpl() const { auto prior_box_dims = param_.prior_box->dims(); auto target_box_dims = param_.target_box->dims(); std::string code_type = param_.code_type; diff --git a/lite/operators/box_coder_op.h b/lite/operators/box_coder_op.h index 61d54fd484ff377763e00f1d71bff1c0c6f89398..51e86423e39786426d53fe8ced861866bfeb1053 
100644 --- a/lite/operators/box_coder_op.h +++ b/lite/operators/box_coder_op.h @@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/calib_op.cc b/lite/operators/calib_op.cc index da00f01c3206c81fb89749432383ea8d99c14dc1..8da8747f8c9df038ee424395fd75a20a718f1970 100644 --- a/lite/operators/calib_op.cc +++ b/lite/operators/calib_op.cc @@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const { CHECK_OR_FALSE(param_.output); return true; } -bool CalibOpLite::InferShape() const { +bool CalibOpLite::InferShapeImpl() const { param_.output->Resize(param_.input->dims()); return true; } diff --git a/lite/operators/calib_op.h b/lite/operators/calib_op.h index d575766c10d1e6cd66bf7f8117315ffe21fe10fe..94240880f55e782f025fe5777eba19e0c96cfbee 100644 --- a/lite/operators/calib_op.h +++ b/lite/operators/calib_op.h @@ -42,7 +42,7 @@ class CalibOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope); diff --git a/lite/operators/cast_op.cc b/lite/operators/cast_op.cc index 9ece0a45a3e997e4d1663755f42f6b42efb86c5d..da12e2afded2c23565080b06409ce35b0535c4ff 100644 --- a/lite/operators/cast_op.cc +++ b/lite/operators/cast_op.cc @@ -25,7 +25,7 @@ bool CastOp::CheckShape() const { return true; } -bool CastOp::InferShape() const { +bool CastOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/cast_op.h b/lite/operators/cast_op.h index 2f5f57f12740d085bda36141299cfbe7c798c378..e045ef89f73d0ac29b0f03e148ad651c1513668f 100644 --- a/lite/operators/cast_op.h +++ b/lite/operators/cast_op.h @@ -30,7 +30,7 @@ class CastOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/collect_fpn_proposals_op.cc b/lite/operators/collect_fpn_proposals_op.cc index 4731d4bf81c241c6733b1403699874c1053d2b7f..27dd9a50b6fb0a9943b7a9d86be390cbc6d406b0 100644 --- a/lite/operators/collect_fpn_proposals_op.cc +++ b/lite/operators/collect_fpn_proposals_op.cc @@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { return true; } -bool CollectFpnProposalsOpLite::InferShape() const { +bool CollectFpnProposalsOpLite::InferShapeImpl() const { param_.fpn_rois->Resize({param_.post_nms_topN, 4}); return true; diff --git a/lite/operators/collect_fpn_proposals_op.h b/lite/operators/collect_fpn_proposals_op.h index 1ae7bb269ff53bb8add92d9afc8d462c45cb5f0b..b3104e81d5ff8d82083a7b37ffd88dd169b840c9 100644 --- a/lite/operators/collect_fpn_proposals_op.h +++ b/lite/operators/collect_fpn_proposals_op.h @@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/compare_op.cc b/lite/operators/compare_op.cc index aa500ba35c37cf8af17091d8d37d8fd8d1a08e0e..f458eae71edea6086e8947ae8881f6f218e49808 100644 --- a/lite/operators/compare_op.cc +++ b/lite/operators/compare_op.cc @@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const { return true; } -bool CompareOp::InferShape() const { +bool CompareOp::InferShapeImpl() const { 
CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/compare_op.h b/lite/operators/compare_op.h index 7ca21caaa1347f248213b2b43293ca18d514ba9a..c94cf88516af7676f8e524c091713cbaa4dd70ff 100644 --- a/lite/operators/compare_op.h +++ b/lite/operators/compare_op.h @@ -30,7 +30,7 @@ class CompareOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index b2f7438b64aa34787896839f020f0b056e6453fb..c15bf292897006b3c6d5e67bcfaea5d0e590a82d 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const { return true; } -bool ConcatOpLite::InferShape() const { +bool ConcatOpLite::InferShapeImpl() const { const std::vector &inputs = param_.x; const size_t n = inputs.size(); CHECK_GT_OR_FALSE(n, 0); diff --git a/lite/operators/concat_op.h b/lite/operators/concat_op.h index acc41de9b36cf6a808788a4f585e8a9c7f049717..2ac1572c833db217546aaa176640cb5c1022d3bf 100644 --- a/lite/operators/concat_op.h +++ b/lite/operators/concat_op.h @@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conditional_block_op.cc b/lite/operators/conditional_block_op.cc index c79c4e20a29834e858bc670104e2a09e55888c85..e3678e92c9d33be5428c82331ce963f4c6067369 100644 --- a/lite/operators/conditional_block_op.cc +++ b/lite/operators/conditional_block_op.cc @@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const { return true; } -bool ConditionalBlockOpLite::InferShape() const { return true; } +bool 
ConditionalBlockOpLite::InferShapeImpl() const { return true; } bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { diff --git a/lite/operators/conditional_block_op.h b/lite/operators/conditional_block_op.h index 5518c255c5799aa5b44557a4493275794fd598f5..1815731c8df3ac07bee80aa8e0cc658e752b5c4f 100644 --- a/lite/operators/conditional_block_op.h +++ b/lite/operators/conditional_block_op.h @@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 70ad3a32a83003e449524205a71dcc7536b9a11e..38c59a0290b03031e9cbe013a4a10c14c7ad1743 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector* paddings, } } -bool ConvOpLite::SmartInferShape() { - if (!last_input_shapes.empty()) { - if (last_input_shapes[0] == param_.x->dims() && - last_input_lods[0] == param_.x->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool ConvOpLite::InferShape() const { +bool ConvOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 
3379fb409529e261f4af38ef2ee3483f17cc8a3b..eab17fe6db0a59a9eb0eea0ab7344758a8232d15 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -34,9 +34,7 @@ class ConvOpLite : public OpLite { explicit ConvOpLite(const std::string& type) : OpLite(type) {} bool CheckShape() const override; - - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index a84b975492040ec0bdc1326f33f8b7edafdea2bb..511a5157ad58e5e2d7bda5c4d0de136c9b3f9590 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size, return output_size; } -bool ConvTransposeOpLite::InferShape() const { +bool ConvTransposeOpLite::InferShapeImpl() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); diff --git a/lite/operators/conv_transpose_op.h b/lite/operators/conv_transpose_op.h index fb25c022f974ad195bf72b19cb9b459b2d11d5f2..891ece4f052128c8c236db5650414d6015ea9565 100644 --- a/lite/operators/conv_transpose_op.h +++ b/lite/operators/conv_transpose_op.h @@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/crf_decoding_op.cc b/lite/operators/crf_decoding_op.cc index 1b0a27ab4afdfc165dedc2ccfad492658ec40399..b1af573518bc483b6eaf5e013609583b548fb300 100644 --- a/lite/operators/crf_decoding_op.cc +++ b/lite/operators/crf_decoding_op.cc @@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const { return true; } -bool CrfDecodingOpLite::InferShape() const { +bool 
CrfDecodingOpLite::InferShapeImpl() const { auto emission_dims = param_.emission->dims(); if (param_.length == nullptr) { param_.viterbi_path->Resize({emission_dims[0], 1}); diff --git a/lite/operators/crf_decoding_op.h b/lite/operators/crf_decoding_op.h index 6aaf338ec240d2caa659785f909d5eee7d249008..4bc50410ab0504b3e25585caba7f8fff823553b0 100644 --- a/lite/operators/crf_decoding_op.h +++ b/lite/operators/crf_decoding_op.h @@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/crop_op.cc b/lite/operators/crop_op.cc index 1a27cfb34d958176c8ad0a6e17d7e17e5287d2d5..4905d92e587ea10783fe7a3cb88b6ee67761c73e 100644 --- a/lite/operators/crop_op.cc +++ b/lite/operators/crop_op.cc @@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const { return true; } -bool CropOpLite::InferShape() const { +bool CropOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); lite::DDim output_shape(x_dims); diff --git a/lite/operators/crop_op.h b/lite/operators/crop_op.h index f21278e891d265093c26be1f96e416974af13b2e..bd3d0e71d8780fab16134ba347f3208249403bd7 100644 --- a/lite/operators/crop_op.h +++ b/lite/operators/crop_op.h @@ -30,7 +30,7 @@ class CropOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/decode_bboxes_op.cc b/lite/operators/decode_bboxes_op.cc index e22adf1774427e10e3fa146e388a6ce365f86021..1903267c3aa46e048787f007a5c9cede8c574c5a 100644 --- a/lite/operators/decode_bboxes_op.cc +++ b/lite/operators/decode_bboxes_op.cc @@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const { return true; } -bool DecodeBboxesOpLite::InferShape() const { +bool 
DecodeBboxesOpLite::InferShapeImpl() const { param_.bbox_data->Resize(param_.loc_data->dims()); return true; } diff --git a/lite/operators/decode_bboxes_op.h b/lite/operators/decode_bboxes_op.h index c463992c8da6b042d5df027b03e64a594ede8a02..8848a1c26cd9363595a3200fc6e2535751f72df0 100644 --- a/lite/operators/decode_bboxes_op.h +++ b/lite/operators/decode_bboxes_op.h @@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc index 86830df2f19b5615e8b9cfb4b3b57eb22000f588..5ac3eef63bb59c80bffaf3bed558b3ac5baf4d61 100644 --- a/lite/operators/density_prior_box_op.cc +++ b/lite/operators/density_prior_box_op.cc @@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const { return true; } -bool DensityPriorBoxOpLite::InferShape() const { return true; } +bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; } bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { diff --git a/lite/operators/density_prior_box_op.h b/lite/operators/density_prior_box_op.h index bad55ad3b7046da45663a2cdd41243ecd5d41cb0..d84b20557fab101ba60f0af58234ffca4e672a57 100644 --- a/lite/operators/density_prior_box_op.h +++ b/lite/operators/density_prior_box_op.h @@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/distribute_fpn_proposals_op.cc b/lite/operators/distribute_fpn_proposals_op.cc index 5d6a0fca923dd38fd456e024ec14ba7c2685163d..a23c5e1ffb50b1d22a42d5e68bd424d078e83110 100644 --- a/lite/operators/distribute_fpn_proposals_op.cc +++ 
b/lite/operators/distribute_fpn_proposals_op.cc @@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const { return true; } -bool DistributeFpnProposalsOpLite::InferShape() const { +bool DistributeFpnProposalsOpLite::InferShapeImpl() const { int num_out_rois = param_.max_level - param_.min_level + 1; for (int i = 0; i < num_out_rois; i++) { param_.multi_fpn_rois[i]->Resize({-1, 4}); diff --git a/lite/operators/distribute_fpn_proposals_op.h b/lite/operators/distribute_fpn_proposals_op.h index 2390e329329f7406f05ba69b3768556f94a02bec..22ab2006e072ea36037cb05faaca324a7d2922c9 100644 --- a/lite/operators/distribute_fpn_proposals_op.h +++ b/lite/operators/distribute_fpn_proposals_op.h @@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/dropout_op.cc b/lite/operators/dropout_op.cc index 03047de3b318ee2221809ee602d94f204568d723..858cc6d9197433985aabfb428993d2fa1333527e 100644 --- a/lite/operators/dropout_op.cc +++ b/lite/operators/dropout_op.cc @@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const { return true; } -bool DropoutOp::InferShape() const { +bool DropoutOp::InferShapeImpl() const { const auto x_dims = param_.x->dims(); param_.output->Resize(x_dims); if (param_.is_test == false) { diff --git a/lite/operators/dropout_op.h b/lite/operators/dropout_op.h index 97e17e350c6a87a82e3cf05635d9575269489d7a..bdf0e1d9046178b48f2b4917840eee6ac8572c5a 100644 --- a/lite/operators/dropout_op.h +++ b/lite/operators/dropout_op.h @@ -28,7 +28,7 @@ class DropoutOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. 
diff --git a/lite/operators/elementwise_grad_ops.cc b/lite/operators/elementwise_grad_ops.cc index 9d964bf9e36889f2bc72b2656d23bf4022cc121c..730785ba6e6553e6a306f87bdbc63ea5b1017f0a 100644 --- a/lite/operators/elementwise_grad_ops.cc +++ b/lite/operators/elementwise_grad_ops.cc @@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const { return true; } -bool ElementwiseGradOp::InferShape() const { +bool ElementwiseGradOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (param_.XGrad) { diff --git a/lite/operators/elementwise_grad_ops.h b/lite/operators/elementwise_grad_ops.h index c45d581936207f0b37ee70a0505b912d0b509e35..ca8a3241349b4cdc04e4800a0a88b215f586ba72 100644 --- a/lite/operators/elementwise_grad_ops.h +++ b/lite/operators/elementwise_grad_ops.h @@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index 044126b3c22fa853d4908c06c307f32278fa5b9b..f4debc39a0d480f38e6d37e8e60d516def7f0b55 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool ElementwiseOp::SmartInferShape() { - if (!last_input_shapes.empty()) { - if (last_input_shapes[0] == param_.X->dims() && - last_input_shapes[1] == param_.Y->dims() && - last_input_lods[0] == param_.X->lod() && - last_input_lods[1] == param_.Y->lod()) { - param_.Out->Resize(last_output_shapes[0]); - param_.Out->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.X->dims()); - 
last_input_lods.push_back(param_.X->lod()); - last_input_shapes.push_back(param_.Y->dims()); - last_input_lods.push_back(param_.Y->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.Out->dims()); - last_output_lods.push_back(param_.Out->lod()); - return true; -} -bool ElementwiseOp::InferShape() const { +bool ElementwiseOp::InferShapeImpl() const { auto x_dim = param_.X->dims(); auto y_dim = param_.Y->dims(); if (x_dim == y_dim) { @@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { // return true; //} -// bool ElementwiseGradExplicitOp::InferShape() const { +// bool ElementwiseGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/elementwise_ops.h b/lite/operators/elementwise_ops.h index 9d6e5781b9754eb22be11da0d7f77b764eb25912..0f1b682fa5f267dd802c5ee0e35aca8f6d68f39c 100644 --- a/lite/operators/elementwise_ops.h +++ b/lite/operators/elementwise_ops.h @@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/expand_op.cc b/lite/operators/expand_op.cc index 656e8babc022e3bb022b3c3b4bb066ea5e5d173c..8e40a3b236609b1e83b5224efb462a1f803764df 100644 --- a/lite/operators/expand_op.cc +++ b/lite/operators/expand_op.cc @@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const { return true; } -bool 
ExpandOpLite::InferShape() const { +bool ExpandOpLite::InferShapeImpl() const { DDim out_dims(param_.X->dims()); for (size_t i = 0; i < param_.expand_times.size(); ++i) { out_dims[i] *= param_.expand_times[i]; diff --git a/lite/operators/expand_op.h b/lite/operators/expand_op.h index ce5dcda9e80377699b168e6a4970a9bba0cf5039..1312df8e83747107e4c87e856c3b07fc2748d75b 100644 --- a/lite/operators/expand_op.h +++ b/lite/operators/expand_op.h @@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fake_channel_wise_dequantize_max_abs.h b/lite/operators/fake_channel_wise_dequantize_max_abs.h index 43afb7791fe617af0c7ac496cc62a12e6cc548d2..e26d5dda52f8b72d9202067a8782cf1dc10b983e 100644 --- a/lite/operators/fake_channel_wise_dequantize_max_abs.h +++ b/lite/operators/fake_channel_wise_dequantize_max_abs.h @@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_dequantize_max_abs.h b/lite/operators/fake_dequantize_max_abs.h index bc266327ebcb14da01201dcc1825367ff7ecd72e..c4bb19c04872078eb997afca6cd7a3cce6923fde 100644 --- a/lite/operators/fake_dequantize_max_abs.h +++ b/lite/operators/fake_dequantize_max_abs.h @@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff 
--git a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h index 8efa46c41501be79ccc69f4cc9f9646c11673d2d..be7ec60e0eab730c2910c3822c976d579b48d6b7 100644 --- a/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_moving_avg_max_abs.h b/lite/operators/fake_quantize_moving_avg_max_abs.h index adc62a480d2d2efec54b3822f55a9f66c278e21e..5726231f31eab2012d2cd594c5c26977c71141ff 100644 --- a/lite/operators/fake_quantize_moving_avg_max_abs.h +++ b/lite/operators/fake_quantize_moving_avg_max_abs.h @@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto x = op_desc.Input("X").front(); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h index f68d1e20f6e60bb5aa99a2402ea8c9f88aa18470..14f823ece2ee168ae09bc1db67f3d6a7e8c18d5d 100644 --- a/lite/operators/fake_quantize_range_abs_max.h +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { bool CheckShape() const override { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto 
x = op_desc.Input("X").front(); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 345fc0d605ccd68e3a6ef72429e20400a772568c..d58a9e5b881048dd47340082fe9c94a618a7a5fb 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const { return true; } -bool FcOpLite::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (last_input_shapes[0] == param_.input->dims() && - last_input_lods[0] == param_.input->lod()) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.input->dims()); - last_input_lods.push_back(param_.input->lod()); - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - - return true; -} -bool FcOpLite::InferShape() const { +bool FcOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); const auto& w_dims = param_.w->dims(); int in_num_col_dims = param_.in_num_col_dims; diff --git a/lite/operators/fc_op.h b/lite/operators/fc_op.h index f5dc302e27a220ee1f1e0679cbb3c2ed257747dd..2e6a3ad59a1ca6d2e31f42ceb4b2d1b381c697ee 100644 --- a/lite/operators/fc_op.h +++ b/lite/operators/fc_op.h @@ -35,8 +35,7 @@ class FcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/feed_op.cc b/lite/operators/feed_op.cc index 8a0c75f62b6bed5767a8cc4b8348b4ca5b59eea5..c429d1f5744e50ff84a0a3d76e2f3e1ba68a0821 100644 --- a/lite/operators/feed_op.cc +++ 
b/lite/operators/feed_op.cc @@ -29,7 +29,7 @@ class FeedOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/fetch_op.cc b/lite/operators/fetch_op.cc index d50c0db34084bf8a70c9451ba0f0d8960e9d18c9..9db5fb418dab4418a0d6a622f87620c5c2673ecf 100644 --- a/lite/operators/fetch_op.cc +++ b/lite/operators/fetch_op.cc @@ -29,7 +29,7 @@ class FetchOp : public OpLite { return true; } - bool InferShape() const override { return true; } + bool InferShapeImpl() const override { return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: diff --git a/lite/operators/fill_constant_batch_size_like_op.cc b/lite/operators/fill_constant_batch_size_like_op.cc index 7df3a6aa9e75ecc3fe88031a544c8e5ed3d1dd02..5b0ebb38e717afea4dabe011c0161248e2113a02 100644 --- a/lite/operators/fill_constant_batch_size_like_op.cc +++ b/lite/operators/fill_constant_batch_size_like_op.cc @@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const { return true; } -bool FillConstantBatchSizeLikeOp::InferShape() const { +bool FillConstantBatchSizeLikeOp::InferShapeImpl() const { std::vector output_dim{param_.shape.begin(), param_.shape.end()}; if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) { output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1; diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h index 33cc45779f6132fbc34b33eb2abbe9ca71418046..3c576ab28222c45aa17ba96f5e3e585624a29c02 100644 --- a/lite/operators/fill_constant_batch_size_like_op.h +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const 
override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index 698b787f469375831d937fdf16bb58af06288e71..565c4bbd16e01af340e728e28866268c1a845760 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const { return true; } -bool FillConstantOp::InferShape() const { +bool FillConstantOp::InferShapeImpl() const { std::vector out_shape; auto shape_tensor = param_.shape_tensor; auto shape_tensor_list = param_.shape_tensor_list; diff --git a/lite/operators/fill_constant_op.h b/lite/operators/fill_constant_op.h index aa2fea5a665ee9a3c50efa3ec354fe52d9643050..3c0500898bef45efc7a72bc68c82fca9036c63f4 100644 --- a/lite/operators/fill_constant_op.h +++ b/lite/operators/fill_constant_op.h @@ -31,7 +31,7 @@ class FillConstantOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc index 6deab45023876b1a5707ef5cea6ec69af3875328..b270dbf52f9a19f574e6f8967ff93e3a013e5737 100644 --- a/lite/operators/flatten_op.cc +++ b/lite/operators/flatten_op.cc @@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const { return true; } -bool FlattenOp::InferShape() const { +bool FlattenOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto out_lod = param_.output->mutable_lod(); @@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const { return true; } -bool Flatten2Op::InferShape() const { - FlattenOp::InferShape(); +bool Flatten2Op::InferShapeImpl() const { + FlattenOp::InferShapeImpl(); auto x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1, 0); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h 
index 61680fd3903b77f8826cda6f6a242739720155d7..78b803d765c8513ead9bf482bf23914ac4bf3430 100644 --- a/lite/operators/flatten_op.h +++ b/lite/operators/flatten_op.h @@ -30,7 +30,7 @@ class FlattenOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc index 244394b95aafede6956bc548430f5c14f28ae910..dfe3bda6c65a75f8b0f8a080d9dc367fb493e6f2 100644 --- a/lite/operators/fusion_elementwise_activation_ops.cc +++ b/lite/operators/fusion_elementwise_activation_ops.cc @@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { return true; } -bool FusionElementwiseActivationOp::InferShape() const { +bool FusionElementwiseActivationOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); param_.Out->Resize(param_.X->dims()); return true; @@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc, // return true; // } -// bool FusionElementwiseActivationGradExplicitOp::InferShape() const { +// bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const { // param_.X_grad->Resize(param_.Out_grad->dims()); // param_.Y_grad->Resize(param_.Y->dims()); // return true; diff --git a/lite/operators/fusion_elementwise_activation_ops.h b/lite/operators/fusion_elementwise_activation_ops.h index db521284f0fc96c542fd5e7104b045f83f837f97..738c2168225d86f4614ba8eaaa6c6354f038116c 100644 --- a/lite/operators/fusion_elementwise_activation_ops.h +++ b/lite/operators/fusion_elementwise_activation_ops.h 
@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; @@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite { // bool CheckShape() const override; -// bool InferShape() const override; +// bool InferShapeImpl() const override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; diff --git a/lite/operators/gather_op.cc b/lite/operators/gather_op.cc index 858dad8e4c4b623e8d2499019bba36c7e0373b60..670cd61c8ea5af2f29a908b5d49bedccaff93c0a 100644 --- a/lite/operators/gather_op.cc +++ b/lite/operators/gather_op.cc @@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const { return true; } -bool GatherOp::InferShape() const { +bool GatherOp::InferShapeImpl() const { auto index_dims = param_.Index->dims(); CHECK(index_dims.size() == 1 || (index_dims.size() == 2 && index_dims[1] == 1)) diff --git a/lite/operators/gather_op.h b/lite/operators/gather_op.h index 58d5a30ffbb5f563503c8934d8c9e40bb539d5df..d2072c3a6d6e6e0b100ab3bb9413da8cd4f51f6b 100644 --- a/lite/operators/gather_op.h +++ b/lite/operators/gather_op.h @@ -30,7 +30,7 @@ class GatherOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/generate_proposals_op.cc b/lite/operators/generate_proposals_op.cc index a29ef65e97ccfdaaaf20d6cbbb411fc69cee6f54..48e709c348974dcf1868a7a17425b4168f04b4f6 100644 --- a/lite/operators/generate_proposals_op.cc +++ b/lite/operators/generate_proposals_op.cc @@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const { return true; } -bool GenerateProposalsOpLite::InferShape() const { +bool GenerateProposalsOpLite::InferShapeImpl() const { 
param_.RpnRois->Resize(std::vector({-1, 4})); param_.RpnRoiProbs->Resize(std::vector({-1, 1})); return true; diff --git a/lite/operators/generate_proposals_op.h b/lite/operators/generate_proposals_op.h index 502bcca1a3276fbbcc2f05bf8b38fcf2d1bbb024..35dee1966bda7cd9e865f42113c7a92061a3782a 100644 --- a/lite/operators/generate_proposals_op.h +++ b/lite/operators/generate_proposals_op.h @@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/grid_sampler_op.cc b/lite/operators/grid_sampler_op.cc index 2b13d17da7c439f582f682a74b1590cda632cf78..97e2b36a6bcd0eb784a39ab4f2a2e0703d7a7c93 100644 --- a/lite/operators/grid_sampler_op.cc +++ b/lite/operators/grid_sampler_op.cc @@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const { return true; } -bool GridSamplerOp::InferShape() const { +bool GridSamplerOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); param_.out->Resize(x_dims); return true; diff --git a/lite/operators/grid_sampler_op.h b/lite/operators/grid_sampler_op.h index 035e1b834510affefacafad763d75d6fbf53aed9..2fba4fe69311c274765e9db4c9b27e137c78a3ee 100644 --- a/lite/operators/grid_sampler_op.h +++ b/lite/operators/grid_sampler_op.h @@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_op.cc b/lite/operators/gru_op.cc index eb97d65a1a213e31b23087d1ca5c8e963ecf9bbb..862a1ff98f699393c9aa91afab978f947cc25187 100644 --- a/lite/operators/gru_op.cc +++ b/lite/operators/gru_op.cc @@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const { return true; } -bool GRUOpLite::InferShape() const { +bool GRUOpLite::InferShapeImpl() const { 
const auto& input_dims = param_.input->dims(); const auto& weight_dims = param_.weight->dims(); int frame_size = weight_dims[0]; diff --git a/lite/operators/gru_op.h b/lite/operators/gru_op.h index c43f32f0cd41b8fa9bc8a541c48523a4f120009d..34f87fa79371fc3d798a57b4aae0945a27a692c3 100644 --- a/lite/operators/gru_op.h +++ b/lite/operators/gru_op.h @@ -30,7 +30,7 @@ class GRUOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/gru_unit_op.cc b/lite/operators/gru_unit_op.cc index ed33507fc3fa61fce1e718581309ae37992c0531..ad025fbbc19cf27f053d5cc2bda566f186a72529 100644 --- a/lite/operators/gru_unit_op.cc +++ b/lite/operators/gru_unit_op.cc @@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const { return true; } -bool GRUUnitOpLite::InferShape() const { +bool GRUUnitOpLite::InferShapeImpl() const { auto input_dims = param_.input->dims(); auto hidden_prev_dims = param_.hidden_prev->dims(); auto weight_dims = param_.weight->dims(); diff --git a/lite/operators/gru_unit_op.h b/lite/operators/gru_unit_op.h index 301a7e7323afaea16dce2adcb356a41a8b0b8cac..2785e60e95b0f36cc5bf92714af857ef658d80dc 100644 --- a/lite/operators/gru_unit_op.h +++ b/lite/operators/gru_unit_op.h @@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/im2sequence_op.cc b/lite/operators/im2sequence_op.cc index 40ab2106af85b3386f93385785b65b9293b1c7f9..ae7b1029468ddb9f723de522ce715859d9a08a09 100644 --- a/lite/operators/im2sequence_op.cc +++ b/lite/operators/im2sequence_op.cc @@ -26,7 +26,7 @@ inline int Im2SeqOutputSize( } bool Im2SequenceOp::CheckShape() const { return true; } -bool Im2SequenceOp::InferShape() const 
{ +bool Im2SequenceOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/im2sequence_op.h b/lite/operators/im2sequence_op.h index 83a347c913fd80c3a890053e1e1945b6cf2a7cd4..62525baaf071bb92b79773c248adb4fd1c798d90 100644 --- a/lite/operators/im2sequence_op.h +++ b/lite/operators/im2sequence_op.h @@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/increment_op.cc b/lite/operators/increment_op.cc index c1928ccbd4ca28ad1d1d83d2e232234ca1677aaa..9b34e4f73b8cc0e27cab06547d3fab84c7033b88 100644 --- a/lite/operators/increment_op.cc +++ b/lite/operators/increment_op.cc @@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const { return true; } -bool IncrementOp::InferShape() const { +bool IncrementOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/increment_op.h b/lite/operators/increment_op.h index f180d527c31494dcfb8cb53f005861ae639c9844..d4e6fd6b1ff1aea47df130d510bc84ab0a0b6019 100644 --- a/lite/operators/increment_op.h +++ b/lite/operators/increment_op.h @@ -30,7 +30,7 @@ class IncrementOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/instance_norm_op.cc b/lite/operators/instance_norm_op.cc index 510402ba1fb363f383b3cba8eb322a4ff7975c18..5f685ccfc59a7170a2d29d2b8e561ed933c8517c 100644 --- a/lite/operators/instance_norm_op.cc +++ b/lite/operators/instance_norm_op.cc @@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const { return true; } -bool InstanceNormOp::InferShape() const { +bool InstanceNormOp::InferShapeImpl() const { auto x_dims = param_.x->dims(); int64_t batch_size = x_dims[0]; int64_t channel_size = x_dims[1]; diff --git a/lite/operators/instance_norm_op.h b/lite/operators/instance_norm_op.h index d128345805cf77ac2a4123a8549c92051593fff0..94a1f69fa4433072a986f1d82d5f1b8401a03386 100644 --- a/lite/operators/instance_norm_op.h +++ b/lite/operators/instance_norm_op.h @@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc index 1bfb20df4e4b9762e93b6a39f0d34eb2521acfe0..0ef22e42903842ac41e9aca010f78796b5a32fcc 100644 --- a/lite/operators/interpolate_op.cc +++ b/lite/operators/interpolate_op.cc @@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const { return true; } -bool InterpolateOp::InferShape() const { +bool InterpolateOp::InferShapeImpl() const { auto X = param_.X; int n = X->dims()[0]; 
diff --git a/lite/operators/interpolate_op.h b/lite/operators/interpolate_op.h index 5fcf4ef594d52a4ac14e5545b195cc51cbf379cf..2bc938964811c57189e45d3b9d892542f9f02e8f 100644 --- a/lite/operators/interpolate_op.h +++ b/lite/operators/interpolate_op.h @@ -31,7 +31,7 @@ class InterpolateOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/io_copy_op.cc b/lite/operators/io_copy_op.cc index 7df636d7b2d877a5539a980080077be785d47505..05b2d3d800d2d2989ae23f9a1ccac57021e82ac1 100644 --- a/lite/operators/io_copy_op.cc +++ b/lite/operators/io_copy_op.cc @@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool IoCopyOp::InferShape() const { +bool IoCopyOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/io_copy_op.h b/lite/operators/io_copy_op.h index 8d6d69d63ed8b7ec289d7935ea28df2482e0cf31..d6922b667d78e3b79a005aae895b9e63dc76fa21 100644 --- a/lite/operators/io_copy_op.h +++ b/lite/operators/io_copy_op.h @@ -24,7 +24,7 @@ class IoCopyOp : public OpLite { public: explicit IoCopyOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/is_empty_op.cc b/lite/operators/is_empty_op.cc index ed4c69e64eaae8fdcb8289c5389dcff1df2ea8b5..a62470e4bb7f88d4c441dc8814bba7c4913ab3e4 100644 --- a/lite/operators/is_empty_op.cc +++ b/lite/operators/is_empty_op.cc @@ -21,7 +21,7 @@ namespace operators { bool IsEmptyOp::CheckShape() const { return true; } -bool IsEmptyOp::InferShape() const { return true; } +bool IsEmptyOp::InferShapeImpl() const { return true; } bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) 
{ param_.X = diff --git a/lite/operators/is_empty_op.h b/lite/operators/is_empty_op.h index 5bfa0905c7c57110473fde48d78d17947abbb547..14c0830c233a9ff011b00d130bc36054a7ede57a 100644 --- a/lite/operators/is_empty_op.h +++ b/lite/operators/is_empty_op.h @@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layer_norm_op.cc b/lite/operators/layer_norm_op.cc index 18ea6cbf281846600273d6e7d462ed43f2e45637..2f50d232e3781e44b8203084382c20872094a263 100644 --- a/lite/operators/layer_norm_op.cc +++ b/lite/operators/layer_norm_op.cc @@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const { return true; } -bool LayerNormOp::InferShape() const { +bool LayerNormOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); param_.Y->Resize(out_dims); auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[0]; diff --git a/lite/operators/layer_norm_op.h b/lite/operators/layer_norm_op.h index 297f6bdd402b919b4baa1915135ed909c57cfa0b..6e15d2f599beb14df024f2591b098b128c3af8dd 100644 --- a/lite/operators/layer_norm_op.h +++ b/lite/operators/layer_norm_op.h @@ -30,7 +30,7 @@ class LayerNormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/layout_op.cc b/lite/operators/layout_op.cc index 01272568045233a90e2aaaffa758e4ce1515d700..d71dab68702ddd53af1540c2a6dce14d43b27e09 100644 --- a/lite/operators/layout_op.cc +++ b/lite/operators/layout_op.cc @@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool LayoutOp::InferShape() const { +bool LayoutOp::InferShapeImpl() const { param_.y->Resize(param_.x->dims()); return true; } diff --git 
a/lite/operators/layout_op.h b/lite/operators/layout_op.h index 216d571d7c37204ec6ef6c513caba726841bcdf2..f51768863bf2e942262f364c271b902922b39cb1 100644 --- a/lite/operators/layout_op.h +++ b/lite/operators/layout_op.h @@ -24,7 +24,7 @@ class LayoutOp : public OpLite { public: explicit LayoutOp(const std::string &type) : OpLite(type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool Run() override; std::string DebugString() const override; diff --git a/lite/operators/lod_reset_op.cc b/lite/operators/lod_reset_op.cc index 1754e709ff2439462e8f40d047f5594ed740e07a..c30c78bbc6c1300660c01e6219c9e5113c39a718 100644 --- a/lite/operators/lod_reset_op.cc +++ b/lite/operators/lod_reset_op.cc @@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const { return true; } -bool LodResetOp::InferShape() const { +bool LodResetOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. param_.Out->Resize(param_.X->dims()); diff --git a/lite/operators/lod_reset_op.h b/lite/operators/lod_reset_op.h index 4e048a9a696c3e1e4a366c732bb269134c9d5d06..8ca2bc578099aabfe6c9649d58e9caeabea7870f 100644 --- a/lite/operators/lod_reset_op.h +++ b/lite/operators/lod_reset_op.h @@ -30,7 +30,7 @@ class LodResetOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/logical_op.cc b/lite/operators/logical_op.cc index 8af982ad535192f4897ea70cdb180b230d29dfd6..2dd5b798280ef80a54d557e449beee15959971b8 100644 --- a/lite/operators/logical_op.cc +++ b/lite/operators/logical_op.cc @@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const { return true; } -bool BinaryLogicalOp::InferShape() const { +bool BinaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto input_dims = param_.X->dims(); @@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const { return true; } -bool UnaryLogicalOp::InferShape() const { +bool UnaryLogicalOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/logical_op.h b/lite/operators/logical_op.h index a0fc1d68a60a0650179f66ca9fd443e96a483c34..e784d4d99b7de29593e411db9b6a888e5bd52e21 100644 --- a/lite/operators/logical_op.h +++ b/lite/operators/logical_op.h @@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_dequant_op.cc b/lite/operators/lookup_table_dequant_op.cc index b81043bfbfeed356e3d67065686057adfadcb25f..844544dfad3c535342169d08159a80484a29643d 100644 --- a/lite/operators/lookup_table_dequant_op.cc +++ b/lite/operators/lookup_table_dequant_op.cc @@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const { return true; } -bool LookupTableDequantOpLite::InferShape() const { +bool LookupTableDequantOpLite::InferShapeImpl() const { const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_dequant_op.h b/lite/operators/lookup_table_dequant_op.h index 3a9683d5ca0d87365cb240b91dccab07cf26ca71..a094cac9a49891294ec71194d39a023867f58052 100644 --- a/lite/operators/lookup_table_dequant_op.h +++ b/lite/operators/lookup_table_dequant_op.h @@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc index df066435a8758e5a75ad1bed78111396d50b44cf..9bc22080bfb6c0ebda28e620dd9b781ec515ecbb 100644 --- a/lite/operators/lookup_table_op.cc +++ b/lite/operators/lookup_table_op.cc @@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const { return true; } -bool LookupTableOpLite::InferShape() const { +bool LookupTableOpLite::InferShapeImpl() const { const auto& table_dims = param_.W->dims(); const auto& ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_op.h b/lite/operators/lookup_table_op.h index 2701af984088cfda450f98fa5bc432dad7c2bc59..91ef77cfa1852a93d3aa28aceb616eec3306af3a 100644 --- a/lite/operators/lookup_table_op.h +++ b/lite/operators/lookup_table_op.h @@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lookup_table_v2_op.cc b/lite/operators/lookup_table_v2_op.cc index df642e6191cffb748191b38eb5a6578aac163da4..8c76090df385ca5adf454ac1918c11c8838695f1 100644 --- a/lite/operators/lookup_table_v2_op.cc +++ b/lite/operators/lookup_table_v2_op.cc @@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const { return true; } -bool LookupTableV2OpLite::InferShape() const { +bool LookupTableV2OpLite::InferShapeImpl() const { auto table_dims = param_.W->dims(); auto ids_dims = param_.Ids->dims(); diff --git a/lite/operators/lookup_table_v2_op.h b/lite/operators/lookup_table_v2_op.h index dabff3f0cac75cb70cde6eb6e95df34dc36901fe..b0b8829fe6aeaf02a445109ea804266758919822 100644 --- a/lite/operators/lookup_table_v2_op.h +++ b/lite/operators/lookup_table_v2_op.h @@ -30,7 +30,7 @@ class LookupTableV2OpLite : 
public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lrn_op.cc b/lite/operators/lrn_op.cc index aff3e5af5566771411acf20736fdbec703f5def9..dcaffe1aa7cbc64c26dd2d56fcaa650e1599eb10 100644 --- a/lite/operators/lrn_op.cc +++ b/lite/operators/lrn_op.cc @@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const { return true; } -bool LrnOpLite::InferShape() const { +bool LrnOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/lrn_op.h b/lite/operators/lrn_op.h index a569a77fb40d7ea60e9e41171e73668e499684a5..13dfdefdc6f28dc289f490340faa14c166485db0 100644 --- a/lite/operators/lrn_op.h +++ b/lite/operators/lrn_op.h @@ -28,7 +28,7 @@ class LrnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/lstm_op.cc b/lite/operators/lstm_op.cc index 36a0d2f53c1f30976ad6df811ad352721e3d7ff7..d9b6ebfc321190286d27272ea7b09a2a751cd9f1 100644 --- a/lite/operators/lstm_op.cc +++ b/lite/operators/lstm_op.cc @@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const { return true; } -bool LstmOp::InferShape() const { +bool LstmOp::InferShapeImpl() const { auto in_dims = param_.Input->dims(); if (param_.H0) { CHECK(param_.C0) << "lstm must has H0 and C0 in the same time"; diff --git a/lite/operators/lstm_op.h b/lite/operators/lstm_op.h index 221bd5c37945f4ff65b21a83449937563d9e5944..38bef385da67defa4e3459cfbcb6cbf24e0f2ed9 100644 --- a/lite/operators/lstm_op.h +++ b/lite/operators/lstm_op.h @@ -30,7 +30,7 @@ class LstmOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, 
lite::Scope *scope) override; diff --git a/lite/operators/match_matrix_tensor_op.cc b/lite/operators/match_matrix_tensor_op.cc index a8095a94bf75cd5d6d9087509449c159056ebc28..1cc751109f76a96097d363b493322dde182a715d 100644 --- a/lite/operators/match_matrix_tensor_op.cc +++ b/lite/operators/match_matrix_tensor_op.cc @@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const { return true; } -bool MatchMatrixTensorOpLite::InferShape() const { +bool MatchMatrixTensorOpLite::InferShapeImpl() const { const Tensor* x = param_.x; const Tensor* y = param_.y; DDim x_dims = param_.x->dims(); diff --git a/lite/operators/match_matrix_tensor_op.h b/lite/operators/match_matrix_tensor_op.h index 404183ea5bda3c35ba8b833853bc0005d60b9f7d..f1070a81b471ded59610af1a5bb40e35ccba7aff 100644 --- a/lite/operators/match_matrix_tensor_op.h +++ b/lite/operators/match_matrix_tensor_op.h @@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc index 286ade7b2130ce662eea2b7ba4e142bf489306ca..1cdcdfa16760385db059a4894e35d04bda51a85d 100644 --- a/lite/operators/matmul_op.cc +++ b/lite/operators/matmul_op.cc @@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const { return true; } -bool MatMulOpLite::InferShape() const { +bool MatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); bool x_transpose = param_.transpose_X; diff --git a/lite/operators/matmul_op.h b/lite/operators/matmul_op.h index 0aa47c89dd2227f70e7264c39b13c019d9b00587..acb9d512f7ac50818e9521ca67e04318397dabb0 100644 --- a/lite/operators/matmul_op.h +++ b/lite/operators/matmul_op.h @@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + 
bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/mean_grad_op.cc b/lite/operators/mean_grad_op.cc index fd17cac14fca153499a52e93f6f09ea44ea9a559..55e374735ea8d861c65f1296968a40a8b5b1f096 100644 --- a/lite/operators/mean_grad_op.cc +++ b/lite/operators/mean_grad_op.cc @@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const { return true; } -bool MeanGradOp::InferShape() const { +bool MeanGradOp::InferShapeImpl() const { param_.X_grad->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/mean_grad_op.h b/lite/operators/mean_grad_op.h index 1bd604518bfc088fc45566e393fd997ae4eed06e..488581a71bb423c09540d17cbb05c170f6f06374 100644 --- a/lite/operators/mean_grad_op.h +++ b/lite/operators/mean_grad_op.h @@ -27,7 +27,7 @@ class MeanGradOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git a/lite/operators/mean_op.cc b/lite/operators/mean_op.cc index 618e9001db056b935de6aef8feff9125155d0e1a..9a66d4fbda3116ef7bd751f34f66eefd1f2e6e99 100644 --- a/lite/operators/mean_op.cc +++ b/lite/operators/mean_op.cc @@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const { return true; } -bool MeanOp::InferShape() const { +bool MeanOp::InferShapeImpl() const { param_.Out->Resize(std::vector{1}); return true; } diff --git a/lite/operators/mean_op.h b/lite/operators/mean_op.h index 8526842f93cb1d01debad9c6cb28ec28b98e43e9..c4dff93ce78aa4598bd12fb3181aa5f2bd4820b6 100644 --- a/lite/operators/mean_op.h +++ b/lite/operators/mean_op.h @@ -27,7 +27,7 @@ class MeanOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/merge_lod_tensor_op.cc 
b/lite/operators/merge_lod_tensor_op.cc index 4258715b1d1aa6bf7fac160dcd6fc8ca6dd3754d..704b5cad6fc80bee8bcb5dfd2921c5cf87182ff8 100644 --- a/lite/operators/merge_lod_tensor_op.cc +++ b/lite/operators/merge_lod_tensor_op.cc @@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const { return true; } -bool MergeLodTensorOpLite::InferShape() const { +bool MergeLodTensorOpLite::InferShapeImpl() const { auto dims = param_.in_true->dims(); param_.out->Resize(dims); return true; diff --git a/lite/operators/merge_lod_tensor_op.h b/lite/operators/merge_lod_tensor_op.h index 788a3451685cd0f42b72ee01e93e17da49507957..ec986fac1988efb5efa262c9fc340c6b450f8ddf 100644 --- a/lite/operators/merge_lod_tensor_op.h +++ b/lite/operators/merge_lod_tensor_op.h @@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/mul_grad_op.cc b/lite/operators/mul_grad_op.cc index 8215521637cbc29a4bdcc4b735b9658fc4cc4840..51e1fb310cb12d83dda9436bb73042c7b22fae11 100644 --- a/lite/operators/mul_grad_op.cc +++ b/lite/operators/mul_grad_op.cc @@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const { return true; } -bool MulGradOpLite::InferShape() const { +bool MulGradOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); if (param_.x_grad) { diff --git a/lite/operators/mul_grad_op.h b/lite/operators/mul_grad_op.h index ef61f54f9b88cd691ab98c4d8904b848dcea66b5..869aa60c6232000008cb57d110aa454396b2ff34 100644 --- a/lite/operators/mul_grad_op.h +++ b/lite/operators/mul_grad_op.h @@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git 
a/lite/operators/mul_op.cc b/lite/operators/mul_op.cc index c870abdc8989b48d8aa2f14f989ad475c027995e..8641a041e38b7a85ee7f0af8b3536f0b9224b36f 100644 --- a/lite/operators/mul_op.cc +++ b/lite/operators/mul_op.cc @@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const { return true; } -bool MulOpLite::InferShape() const { +bool MulOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); diff --git a/lite/operators/mul_op.h b/lite/operators/mul_op.h index caf7bf6ae902ac4e4f22d4a9aadfa108fa7622da..10a2e2efaa4db0e106e3c56c2f9b1cec9fb55ac4 100644 --- a/lite/operators/mul_op.h +++ b/lite/operators/mul_op.h @@ -33,7 +33,7 @@ class MulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. diff --git a/lite/operators/multiclass_nms_op.cc b/lite/operators/multiclass_nms_op.cc index 9ec79f8b57d63f20325bf686c1280522aa4fa80a..8e47a828c13f3ca73f033cd0422f8c05be857cbe 100644 --- a/lite/operators/multiclass_nms_op.cc +++ b/lite/operators/multiclass_nms_op.cc @@ -41,15 +41,9 @@ bool MulticlassNmsOpLite::CheckShape() const { return true; } -bool MulticlassNmsOpLite::InferShape() const { - auto box_dims = param_.bboxes->dims(); - auto score_dims = param_.scores->dims(); - auto score_size = score_dims.size(); - if (score_size == 3) { - param_.out->Resize({box_dims[1], box_dims[2], 3}); - } else { - param_.out->Resize({-1, box_dims[2] + 2}); - } +bool MulticlassNmsOpLite::InferShapeImpl() const { + // InferShape is useless for multiclass_nms + // out's dim is not sure before the end of calculation return true; } diff --git a/lite/operators/multiclass_nms_op.h b/lite/operators/multiclass_nms_op.h index 7be0d17d7478bdcfb4c4c6b1f22e505fb9da0846..f74479f3c9a42e6f5ec06126fedf91a2e17b6c2f 100644 --- 
a/lite/operators/multiclass_nms_op.h +++ b/lite/operators/multiclass_nms_op.h @@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/negative_op.cc b/lite/operators/negative_op.cc index 4db1dd4feede42fc4267eb3fc3553c538807f1a8..2b98f0a90af812ac9c524368e41177377f4d69e2 100644 --- a/lite/operators/negative_op.cc +++ b/lite/operators/negative_op.cc @@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const { return true; } -bool NegativeOpLite::InferShape() const { +bool NegativeOpLite::InferShapeImpl() const { lite::DDim input_dims; input_dims = param_.X->dims(); param_.Out->Resize(lite::DDim(input_dims)); diff --git a/lite/operators/negative_op.h b/lite/operators/negative_op.h index 83f1008c9630284956347b87151e58f49588b867..04ec92532559c050cc5a9e8ac6bdf9a817e0dc70 100644 --- a/lite/operators/negative_op.h +++ b/lite/operators/negative_op.h @@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/norm_op.cc b/lite/operators/norm_op.cc index dff26966d48889389e2837194c2bc5a96fc960e5..0513e5c942d73397f269f1fe7bb89572a97ae548 100644 --- a/lite/operators/norm_op.cc +++ b/lite/operators/norm_op.cc @@ -25,7 +25,7 @@ bool NormOp::CheckShape() const { return true; } -bool NormOp::InferShape() const { +bool NormOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
auto out_dims = param_.X->dims(); diff --git a/lite/operators/norm_op.h b/lite/operators/norm_op.h index ae4594ed023d47179a7125bd9183e39f505ae16b..5c69d959be81eaccddc396dadacf920493ef99f5 100644 --- a/lite/operators/norm_op.h +++ b/lite/operators/norm_op.h @@ -30,7 +30,7 @@ class NormOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 36d3b42c6b315a3858f475bd5756579137528051..3fdca389bca1ba09ebfe008365b6992b717270d8 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -24,6 +24,7 @@ #include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/desc_apis.h" #include "lite/utils/all.h" +#include "lite/utils/variant.h" /* * This file contains all the argument parameter data structure for operators. */ @@ -32,6 +33,16 @@ namespace paddle { namespace lite { namespace operators { +struct ParamBase { + public: + const std::vector* input_tensor_ptrs() const { return nullptr; } + std::vector* output_tensor_ptrs() { return nullptr; } + + protected: + std::shared_ptr> input_tensor_ptrs_cache_{nullptr}; + std::shared_ptr> output_tensor_ptrs_cache_{nullptr}; +}; + using param_t = Any; #define WITH_INT8_CONFIG \ bool enable_int8{false}; \ @@ -41,38 +52,38 @@ using param_t = Any; int bit_length{8}; /// ----------------------- Functional operators ------------------------------ -struct FeedParam { +struct FeedParam : ParamBase { std::vector* feed_list{}; lite::Tensor* out{}; int col; }; -struct FetchParam { +struct FetchParam : ParamBase { const lite::Tensor* input{}; std::vector* fetch_list{}; int col; }; // Helper op for lite framework -struct IoCopyParam { +struct IoCopyParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct LayoutParam { +struct LayoutParam : ParamBase { const 
lite::Tensor* x{}; lite::Tensor* y{}; int process_type{0}; }; -struct CalibParam { +struct CalibParam : ParamBase { const lite::Tensor* input{}; lite::Tensor* output{}; float scale; }; -struct SubgraphParam { +struct SubgraphParam : ParamBase { std::vector input_names{}; std::vector output_names{}; std::vector input_data_names{}; @@ -84,7 +95,7 @@ struct SubgraphParam { /// -------------------------- NN operators ------------------------------------ -struct FcParam { +struct FcParam : ParamBase { lite::Tensor* input{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* bias{nullptr}; @@ -95,9 +106,24 @@ struct FcParam { bool padding_weights{false}; // for int8 WITH_INT8_CONFIG -}; - -struct SearchSeqFcParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({input})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct SearchSeqFcParam : ParamBase { lite::Tensor* x{nullptr}; lite::Tensor* w{nullptr}; lite::Tensor* b{nullptr}; @@ -106,7 +132,7 @@ struct SearchSeqFcParam { }; // For Interpolate Op -struct InterpolateParam { +struct InterpolateParam : ParamBase { lite::Tensor* X{}; lite::Tensor* OutSize{}; lite::Tensor* Out{}; @@ -123,7 +149,7 @@ struct InterpolateParam { }; // For Mul Op -struct MulParam { +struct MulParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; lite::Tensor* output{}; @@ -132,9 +158,24 @@ struct MulParam { int y_num_col_dims{1}; // for int8 WITH_INT8_CONFIG -}; - -struct MulGradParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a 
vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x, y})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct MulGradParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const lite::Tensor* output_grad{}; @@ -146,7 +187,7 @@ struct MulGradParam { }; // For ReduceMean Op -struct ReduceMeanParam { +struct ReduceMeanParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; @@ -155,7 +196,7 @@ struct ReduceMeanParam { }; // For Stack Op -struct StackParam { +struct StackParam : ParamBase { std::vector X; lite::Tensor* Out{}; @@ -163,7 +204,7 @@ struct StackParam { }; // For Power Op -struct PowerParam { +struct PowerParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -172,7 +213,7 @@ struct PowerParam { float power{}; }; -struct ShuffleChannelParam { +struct ShuffleChannelParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; @@ -180,7 +221,7 @@ struct ShuffleChannelParam { }; // For Yolobox -struct YoloBoxParam { +struct YoloBoxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ImgSize{}; lite::Tensor* Boxes{}; @@ -193,24 +234,54 @@ struct YoloBoxParam { }; // For Scale Op -struct ScaleParam { +struct ScaleParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; float scale{1.}; float bias{}; bool bias_after_scale{true}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output 
tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Softmax op -struct SoftmaxParam { +struct SoftmaxParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int axis{-1}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Reshape and Reshape2 Op -struct ReshapeParam { +struct ReshapeParam : ParamBase { const lite::Tensor* x{}; std::vector shape_tensor_vct{}; const lite::Tensor* shape_tensor{}; @@ -219,18 +290,51 @@ struct ReshapeParam { lite::Tensor* xshape{}; bool inplace{false}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Concat op -struct ConcatParam { +struct ConcatParam : ParamBase { std::vector x{}; lite::Tensor* output{}; int axis{0}; lite::Tensor* axis_tensor{}; + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { 
+ std::vector vec; + for (auto in : x) { + vec.push_back(in); + } + input_tensor_ptrs_cache_.reset(new std::vector(vec)); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; /// ----------------------- activation operators ---------------------- -struct ActivationParam { +struct ActivationParam : ParamBase { const lite::Tensor* X{}; float Leaky_relu_alpha{0}; // leaky_relu param float Relu_clipped_coef{6}; // relu_clipped param @@ -245,7 +349,7 @@ struct ActivationParam { lite_api::ActivationType active_type; }; -struct ActivationGradParam { +struct ActivationGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out{}; // for backward @@ -254,7 +358,7 @@ struct ActivationGradParam { }; // For Convolution op -struct ConvParam { +struct ConvParam : ParamBase { lite::Tensor* x{}; lite::Tensor* filter{}; lite::Tensor* bias{nullptr}; @@ -294,10 +398,26 @@ struct ConvParam { std::vector output_size; // for int8 WITH_INT8_CONFIG + + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For BatchNorm op -struct BatchNormParam { +struct BatchNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* bias{}; lite::Tensor* scale{}; @@ -313,10 +433,25 @@ struct BatchNormParam { float epsilon; float momentum; DataLayoutType 
data_layout{DATALAYOUT(kNCHW)}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({y})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Pooling op -struct PoolParam { +struct PoolParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::string pooling_type{""}; @@ -337,10 +472,25 @@ struct PoolParam { std::string data_format{"AnyLayout"}; // for int8 WITH_INT8_CONFIG + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Dropout op -struct DropoutParam { +struct DropoutParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* mask{}; @@ -352,7 +502,7 @@ struct DropoutParam { }; // For Split op -struct SplitParam { +struct SplitParam : ParamBase { lite::Tensor* x{}; std::vector output{}; lite::Tensor* axis_tensor; @@ -361,10 +511,25 @@ struct SplitParam { int axis{-1}; int num{0}; std::vector sections; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if 
(UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; // For Transpose op -struct TransposeParam { +struct TransposeParam : ParamBase { const lite::Tensor* x{}; lite::Tensor* output{}; lite::Tensor* xshape{}; @@ -372,10 +537,25 @@ struct TransposeParam { std::vector axis; bool use_mkldnn{false}; std::string data_format{"AnyLayout"}; + /////////////////////////////////////////////////////////////////////////////////// + // // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({x})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; /// ----------------------- element wise operators ---------------------- -struct ElementwiseParam { +struct ElementwiseParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -384,9 +564,24 @@ struct ElementwiseParam { WITH_INT8_CONFIG float x_input_scale{1.0}; float y_input_scale{1.0}; -}; - -struct ElementwiseGradParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if 
(UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct ElementwiseGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; const lite::Tensor* OutGrad{}; @@ -404,12 +599,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam { }; /// ----------------------- mean operators ---------------------- -struct MeanParam { +struct MeanParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct MeanGradParam { +struct MeanGradParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Out_grad{}; // for backward @@ -417,7 +612,7 @@ struct MeanGradParam { }; /// ----------------------- fill_constant operators ---------------------- -struct FillConstantParam { +struct FillConstantParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; std::vector shape{}; lite::Tensor* shape_tensor{nullptr}; @@ -429,7 +624,7 @@ struct FillConstantParam { lite::Tensor* out{}; }; -struct FillConstantBatchSizeLikeParam { +struct FillConstantBatchSizeLikeParam : ParamBase { const lite::Tensor* input{nullptr}; lite::Tensor* out{nullptr}; @@ -443,7 +638,7 @@ struct FillConstantBatchSizeLikeParam { }; // -struct FakeQuantizeMovingAvgMaxAbsParam { +struct FakeQuantizeMovingAvgMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; const lite::Tensor* in_accum{}; @@ -457,14 +652,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam { float moving_rate{0.9}; }; -struct FakeDequantizeMaxAbsParam { +struct FakeDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* in_scale{}; lite::Tensor* out{}; float max_range; }; -struct FakeChannelWiseDequantizeMaxAbsParam { +struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase { const lite::Tensor* x{}; std::vector scale_tensors{}; lite::Tensor* out{}; @@ -472,7 +667,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam { }; /// 
----------------------- sgd operators ---------------------- -struct SGDParam { +struct SGDParam : ParamBase { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; const lite::Tensor* Param{}; @@ -482,7 +677,7 @@ struct SGDParam { }; /// ----------------------- uniform_random operators ---------------------- -struct UniformRandomParam { +struct UniformRandomParam : ParamBase { std::vector shape{}; float min{-1.0f}; float max{1.0f}; @@ -491,12 +686,12 @@ struct UniformRandomParam { lite::Tensor* Out{}; }; /// ----------------------- negative operators -------------- -struct NegativeParam { +struct NegativeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- pad2d operators ---------------------- -struct Pad2dParam { +struct Pad2dParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector paddings{0, 0, 0, 0}; @@ -506,7 +701,7 @@ struct Pad2dParam { }; /// ----------------------- Crop operators ---------------------- -struct CropParam { +struct CropParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector offsets; @@ -514,21 +709,21 @@ struct CropParam { }; ///----------------------- argmax operators ---------------------- -struct ArgmaxParam { +struct ArgmaxParam : ParamBase { lite::Tensor* X{}; lite::Tensor* Out{}; int Axis{0}; }; ///----------------------- axpy operators ---------------------- -struct AxpyParam { +struct AxpyParam : ParamBase { lite::Tensor* Scale{}; lite::Tensor* X{}; lite::Tensor* Bias{}; lite::Tensor* Out{}; }; /// ----------------------- GRU unit operators ----------------------f -struct GRUUnitParam { +struct GRUUnitParam : ParamBase { enum ActType { identity, sigmoid, tanh, relu }; const lite::Tensor* input{nullptr}; const lite::Tensor* hidden_prev{nullptr}; @@ -544,7 +739,7 @@ struct GRUUnitParam { }; /// ------------------------------ lrn operators ------------------------------ -struct LrnParam { +struct LrnParam : ParamBase { const lite::Tensor* X{}; 
lite::Tensor* Out{}; int n{5}; @@ -555,7 +750,7 @@ struct LrnParam { }; /// ----------------------- decode_bboxes operators ---------------------- -struct DecodeBboxesParam { +struct DecodeBboxesParam : ParamBase { const lite::Tensor* loc_data{}; const lite::Tensor* prior_data{}; lite::Tensor* bbox_data{}; @@ -571,7 +766,7 @@ struct DecodeBboxesParam { }; /// ----------------------- box_coder operators ---------------------- -struct BoxCoderParam { +struct BoxCoderParam : ParamBase { const lite::Tensor* prior_box{}; const lite::Tensor* prior_box_var{}; const lite::Tensor* target_box{}; @@ -584,7 +779,7 @@ struct BoxCoderParam { }; /// ----------------------- multiclass_nms operators ---------------------- -struct MulticlassNmsParam { +struct MulticlassNmsParam : ParamBase { const lite::Tensor* bboxes{}; const lite::Tensor* scores{}; lite::Tensor* out{}; @@ -599,7 +794,7 @@ struct MulticlassNmsParam { }; /// ----------------------- priorbox operators ---------------------- -struct PriorBoxParam { +struct PriorBoxParam : ParamBase { lite::Tensor* input{}; lite::Tensor* image{}; lite::Tensor* boxes{}; @@ -628,7 +823,7 @@ struct DensityPriorBoxParam : public PriorBoxParam { std::vector density_sizes; }; /// ----------------------- GRU operators ----------------------f -struct GRUParam { +struct GRUParam : ParamBase { const lite::Tensor* input{nullptr}; const lite::Tensor* h0{nullptr}; const lite::Tensor* weight{nullptr}; @@ -645,7 +840,7 @@ struct GRUParam { }; /// ----------------------- BeamSearchDecode operators ----------------------f -struct BeamSearchDecodeParam { +struct BeamSearchDecodeParam : ParamBase { std::vector* ids{nullptr}; std::vector* scores{nullptr}; lite::Tensor* sentence_ids{nullptr}; @@ -655,21 +850,21 @@ struct BeamSearchDecodeParam { }; /// ----------------------- LookupTable operators ----------------------f -struct LookupTableParam { +struct LookupTableParam : ParamBase { const lite::Tensor* W{nullptr}; const lite::Tensor* Ids{nullptr}; 
lite::Tensor* Out{nullptr}; int64_t padding_idx{-1}; }; -struct LookupTableDequantParam { +struct LookupTableDequantParam : ParamBase { lite::Tensor* W{nullptr}; lite::Tensor* Ids{nullptr}; lite::Tensor* Out{nullptr}; int64_t padding_idx{-1}; }; -struct Im2SequenceParam { +struct Im2SequenceParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -679,19 +874,34 @@ struct Im2SequenceParam { std::vector out_strides{1, 1}; }; -struct SequenceSoftmaxParam { +struct SequenceSoftmaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; -}; - -struct NormParam { + /////////////////////////////////////////////////////////////////////////////////// + // // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct NormParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Norm{}; int axis{1}; float epsilon{1e-10}; }; -struct LayerNormParam { +struct LayerNormParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -702,13 +912,13 @@ struct LayerNormParam { float epsilon{1e-5}; }; -struct LogicalParam { +struct LogicalParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; }; -struct CompareParam { +struct CompareParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; bool force_cpu{0}; @@ -716,7 +926,7 @@ struct CompareParam { lite::Tensor* Out{}; }; -struct WhileParam { +struct WhileParam : ParamBase { Scope* scope{}; Tensor* cond{}; cpp::BlockDesc* sub_block{}; @@ -724,32 +934,32 @@ struct WhileParam 
{ std::vector outs{}; }; -struct TopkParam { +struct TopkParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* Indices{}; int K{1}; }; -struct IncrementParam { +struct IncrementParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; float step{1}; }; -struct WriteToArrayParam { +struct WriteToArrayParam : ParamBase { const lite::Tensor* X{nullptr}; const lite::Tensor* I{nullptr}; std::vector* Out{nullptr}; }; -struct ReadFromArrayParam { +struct ReadFromArrayParam : ParamBase { const std::vector* X{nullptr}; const lite::Tensor* I{nullptr}; lite::Tensor* Out{nullptr}; }; -struct BeamSearchParam { +struct BeamSearchParam : ParamBase { const lite::Tensor* pre_ids{}; const lite::Tensor* pre_scores{}; const lite::Tensor* ids{}; @@ -763,7 +973,7 @@ struct BeamSearchParam { bool is_accumulated; }; -struct SequencePoolParam { +struct SequencePoolParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::string pool_type{"AVERAGE"}; @@ -773,7 +983,7 @@ struct SequencePoolParam { #endif }; -struct SequenceConvParam { +struct SequenceConvParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Filter{}; lite::Tensor* Out{}; @@ -782,13 +992,13 @@ struct SequenceConvParam { int contextLength; }; -struct SequencePoolConcatParam { +struct SequencePoolConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; std::vector pool_type{}; }; -struct SearchGroupPaddingParam { +struct SearchGroupPaddingParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out_emb_padding{}; lite::Tensor* out_new{}; @@ -796,36 +1006,36 @@ struct SearchGroupPaddingParam { int pad_id; }; -struct SequenceReshapeParam { +struct SequenceReshapeParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int new_dim; }; -struct SequenceExpandParam { +struct SequenceExpandParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; int ref_level{-1}; }; -struct SequenceExpandAsParam { +struct SequenceExpandAsParam : 
ParamBase { const lite::Tensor* x{nullptr}; const lite::Tensor* y{nullptr}; lite::Tensor* out{nullptr}; }; -struct SequenceReverseParam { +struct SequenceReverseParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct SequenceConcatParam { +struct SequenceConcatParam : ParamBase { std::vector X{}; lite::Tensor* Out{}; }; -struct AttentionPaddingMaskParam { +struct AttentionPaddingMaskParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int pad_id; @@ -834,21 +1044,21 @@ struct AttentionPaddingMaskParam { lite::Tensor* pad_begin{}; }; -struct SequenceArithmeticParam { +struct SequenceArithmeticParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; int op_type{1}; lite::Tensor* Out{}; }; -struct ReduceMaxParam { +struct ReduceMaxParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector dim{}; bool keep_dim{false}; }; -struct LodResetParam { +struct LodResetParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; @@ -856,12 +1066,12 @@ struct LodResetParam { bool append; }; -struct IsEmptyParam { +struct IsEmptyParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct ReduceParam { +struct ReduceParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; std::vector dim{0}; @@ -869,7 +1079,7 @@ struct ReduceParam { bool reduce_all{false}; }; -struct VarConv2DParam { +struct VarConv2DParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -888,19 +1098,19 @@ struct VarConv2DParam { }; /// ----------------------- shape operators ---------------------- -struct ShapeParam { +struct ShapeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; -struct CastParam { +struct CastParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; int out_dtype{2}; int in_dtype{2}; }; -struct SliceParam { +struct SliceParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector 
axes{}; @@ -912,9 +1122,24 @@ struct SliceParam { std::vector EndsTensorList{}; lite::Tensor* StartsTensor{nullptr}; lite::Tensor* EndsTensor{nullptr}; -}; - -struct AffineChannelParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct AffineChannelParam : ParamBase { const lite::Tensor* X{}; // X is 4D tensor const lite::Tensor* Scale{}; const lite::Tensor* Bias{}; @@ -922,7 +1147,7 @@ struct AffineChannelParam { lite::Tensor* Out{}; }; -struct AnchorGeneratorParam { +struct AnchorGeneratorParam : ParamBase { const lite::Tensor* Input{}; std::vector anchor_sizes{}; std::vector aspect_ratios{}; @@ -934,7 +1159,7 @@ struct AnchorGeneratorParam { lite::Tensor* Variances{}; }; -struct GenerateProposalsParam { +struct GenerateProposalsParam : ParamBase { // inputs const lite::Tensor* Scores{}; const lite::Tensor* BboxDeltas{}; @@ -954,53 +1179,98 @@ struct GenerateProposalsParam { lite::Tensor* RpnRoiProbs{}; }; /// ----------------------- squeeze operators ---------------------- -struct SqueezeParam { +struct SqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; std::vector axes{}; -}; - -struct UnsqueezeParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // 
get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } +}; + +struct UnsqueezeParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; lite::Tensor* XShape{}; std::vector axes{}; const lite::Tensor* axes_tensor{}; std::vector axes_tensor_vct{}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } }; /// ----------------------- expand operators ---------------------- -struct ExpandParam { +struct ExpandParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; std::vector expand_times{}; }; /// ----------------------- matmul operators ---------------------- -struct MatMulParam { +struct MatMulParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Y{}; lite::Tensor* Out{}; bool transpose_X{false}; bool transpose_Y{false}; float alpha{1.0f}; -}; - -struct GatherParam { + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() { + if (UNLIKELY(input_tensor_ptrs_cache_)) { + input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + const std::vector* output_tensor_ptrs() { + if (UNLIKELY(output_tensor_ptrs_cache_)) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return 
output_tensor_ptrs_cache_.get(); + } +}; + +struct GatherParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* Index{}; lite::Tensor* Out{}; }; /// ----------------------- assign operators ----------------------- -struct AssignParam { +struct AssignParam : ParamBase { const lite::Tensor* X{}; lite::Tensor* Out{}; }; /// ----------------------- roi_align operators ----------------------- -struct RoiAlignParam { +struct RoiAlignParam : ParamBase { lite::Tensor* X{}; lite::Tensor* ROIs{}; lite::Tensor* Out{}; @@ -1011,13 +1281,13 @@ struct RoiAlignParam { }; /// ----------------------- box_clip operators ----------------------- -struct BoxClipParam { +struct BoxClipParam : ParamBase { const lite::Tensor* Input{}; const lite::Tensor* ImInfo{}; lite::Tensor* Output{}; }; -struct RangeParam { +struct RangeParam : ParamBase { const lite::Tensor* Start; const lite::Tensor* End; const lite::Tensor* Step; @@ -1025,7 +1295,7 @@ struct RangeParam { }; /// ----------------------- assign_value operators ----------------------- -struct AssignValueParam { +struct AssignValueParam : ParamBase { std::vector shape{}; int dtype{}; std::vector fp32_values{}; @@ -1034,7 +1304,7 @@ struct AssignValueParam { }; /// --------------- sequence_topk_avg_pooling operators ------------------ -struct SequenceTopkAvgPoolingParam { +struct SequenceTopkAvgPoolingParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* ROW{}; const lite::Tensor* COLUMN{}; @@ -1045,7 +1315,7 @@ struct SequenceTopkAvgPoolingParam { }; /// --------------- search_fc operators ------------------ -struct SearchFcParam { +struct SearchFcParam : ParamBase { const lite::Tensor* X{}; const lite::Tensor* W{}; const lite::Tensor* b{}; @@ -1053,7 +1323,7 @@ struct SearchFcParam { int out_size{}; }; /// --------------------- match_matrix_tensor operators -------------------- -struct MatchMatrixTensorParam { +struct MatchMatrixTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* y{}; const 
lite::Tensor* w{}; @@ -1064,14 +1334,14 @@ struct MatchMatrixTensorParam { }; /// --------------------- search_seq_depadding operators -------------------- -struct SearchSeqDepaddingParam { +struct SearchSeqDepaddingParam : ParamBase { const lite::Tensor* pad{}; const lite::Tensor* src{}; lite::Tensor* out{}; }; /// --------------------- search_grnn operators -------------------- -struct SearchGrnnParam { +struct SearchGrnnParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* wi{}; const lite::Tensor* wh{}; @@ -1084,7 +1354,7 @@ struct SearchGrnnParam { lite::Tensor* layout_input{}; }; -struct SplitLodTensorParam { +struct SplitLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; lite::Tensor* out_true{}; @@ -1092,7 +1362,7 @@ struct SplitLodTensorParam { int level{}; }; -struct MergeLodTensorParam { +struct MergeLodTensorParam : ParamBase { const lite::Tensor* x{}; const lite::Tensor* mask{}; const lite::Tensor* in_true{}; @@ -1101,7 +1371,7 @@ struct MergeLodTensorParam { int level{}; }; -struct ConditionalBlockParam { +struct ConditionalBlockParam : ParamBase { const lite::Tensor* cond{}; std::vector x{}; std::vector outs{}; @@ -1110,14 +1380,14 @@ struct ConditionalBlockParam { bool is_scalar_condition{}; }; -struct CollectFpnProposalsParam { +struct CollectFpnProposalsParam : ParamBase { std::vector multi_level_rois{}; std::vector multi_level_scores{}; lite::Tensor* fpn_rois{}; int post_nms_topN{}; }; -struct DistributeFpnProposalsParam { +struct DistributeFpnProposalsParam : ParamBase { const lite::Tensor* fpn_rois{}; std::vector multi_fpn_rois{}; lite::Tensor* restore_index{}; @@ -1128,7 +1398,7 @@ struct DistributeFpnProposalsParam { }; /// --------------------- instance_norm operators -------------------- -struct InstanceNormParam { +struct InstanceNormParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* bias{}; @@ -1138,12 +1408,12 @@ struct InstanceNormParam { float epsilon; }; /// 
--------------------- grid sampler operators -------------------- -struct GridSamplerParam { +struct GridSamplerParam : ParamBase { lite::Tensor* x{}; lite::Tensor* out{}; lite::Tensor* grid{}; }; -struct LstmParam { +struct LstmParam : ParamBase { lite::Tensor* Input{}; lite::Tensor* Weight{}; lite::Tensor* Bias{}; @@ -1160,7 +1430,7 @@ struct LstmParam { std::string candidate_activation; }; -struct CrfDecodingParam { +struct CrfDecodingParam : ParamBase { lite::Tensor* emission{}; lite::Tensor* transition{}; lite::Tensor* label{}; diff --git a/lite/operators/pad2d_op.cc b/lite/operators/pad2d_op.cc index ff522b94b95091b6df6d4d2f71e18907c5118619..7af657c888f9b1b28a1b273a193be59e2ace895c 100644 --- a/lite/operators/pad2d_op.cc +++ b/lite/operators/pad2d_op.cc @@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const { return true; } -bool Pad2dOpLite::InferShape() const { +bool Pad2dOpLite::InferShapeImpl() const { // nchw auto x_dims = param_.X->dims(); int out_h = x_dims[2] + param_.paddings[0] + param_.paddings[1]; diff --git a/lite/operators/pad2d_op.h b/lite/operators/pad2d_op.h index c51a76a7aef5624b1480fd1b1cdf56bf23c63674..c6d2e565483655c6279af8318434f129ec92a5e5 100644 --- a/lite/operators/pad2d_op.h +++ b/lite/operators/pad2d_op.h @@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/pool_op.cc b/lite/operators/pool_op.cc index c6f6eed28f8cdb5f080b6f4367a1b88b1dbc0701..5fb990928ec1ae723bc12b695af1be5e50da5079 100644 --- a/lite/operators/pool_op.cc +++ b/lite/operators/pool_op.cc @@ -60,7 +60,7 @@ int PoolOutputSize(int input_size, return output_size; } -bool PoolOpLite::InferShape() const { +bool PoolOpLite::InferShapeImpl() const { const auto x_dims = param_.x->dims(); std::vector& ksize = param_.ksize; // dynamic update 4-pad diff --git 
a/lite/operators/pool_op.h b/lite/operators/pool_op.h index c44875ff95b554ca92cf5288597a5bdaf2cb1bf8..3fcf37e6348628d489e9a2097e2c8dac7eba3e3c 100644 --- a/lite/operators/pool_op.h +++ b/lite/operators/pool_op.h @@ -37,7 +37,7 @@ class PoolOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { diff --git a/lite/operators/power_op.cc b/lite/operators/power_op.cc index 578d95ad53ffe0481288934a7a04d0f9e4442440..83c9edfaca1505746640280633bf6d47cddc6146 100644 --- a/lite/operators/power_op.cc +++ b/lite/operators/power_op.cc @@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const { return true; } -bool PowerOp::InferShape() const { +bool PowerOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/power_op.h b/lite/operators/power_op.h index a6d43f4394a8d3a2141f32e1fb633aef8c8227f8..e89dfa7b8f682e029bfba1059fda9c17340c420b 100644 --- a/lite/operators/power_op.h +++ b/lite/operators/power_op.h @@ -31,7 +31,7 @@ class PowerOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index c4717c8185b24cfd9f6a551dcb932dc325a502d2..f1b715a46e1378f805d91312cc7804cb4097ec02 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const { return true; } -bool PriorBoxOpLite::InferShape() const { return true; } +bool PriorBoxOpLite::InferShapeImpl() const { return true; } bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { auto input = opdesc.Input("Input").front(); diff --git 
a/lite/operators/prior_box_op.h b/lite/operators/prior_box_op.h index a393e80315eab07cc8558da8c26d6acad8cc76c1..1348b7cc73f6b731453584ef455813fe0d1cf8be 100644 --- a/lite/operators/prior_box_op.h +++ b/lite/operators/prior_box_op.h @@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/range_op.cc b/lite/operators/range_op.cc index a179d8ffe7abc1665b13f7d0dfeaa8b3c18cf1d5..19f474ba43b15153a7e2cca38f5ff9b097b41342 100644 --- a/lite/operators/range_op.cc +++ b/lite/operators/range_op.cc @@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) { : std::ceil(std::abs((end - start) / step)); } -bool RangeOpLite::InferShape() const { +bool RangeOpLite::InferShapeImpl() const { int start = param_.Start->data()[0]; int end = param_.End->data()[0]; int step = param_.Step->data()[0]; diff --git a/lite/operators/range_op.h b/lite/operators/range_op.h index a1c7d4d4cc43d72001ac3519cb1c4f85ab8196ff..982ef5abf25aac816c00da918147bac8933424a9 100644 --- a/lite/operators/range_op.h +++ b/lite/operators/range_op.h @@ -29,7 +29,7 @@ class RangeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/read_from_array_op.cc b/lite/operators/read_from_array_op.cc index 930eff1ff5ff100c085a4fdb6bdf3a032d44c14b..495fd752c90da528e474b7aa726c65fd6e66c123 100644 --- a/lite/operators/read_from_array_op.cc +++ b/lite/operators/read_from_array_op.cc @@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const { return true; } -bool ReadFromArrayOp::InferShape() const { +bool ReadFromArrayOp::InferShapeImpl() const { int id = param_.I->data()[0]; auto out_dims = (*param_.X)[id].dims(); 
param_.Out->Resize(out_dims); diff --git a/lite/operators/read_from_array_op.h b/lite/operators/read_from_array_op.h index 5c7ba1468f59e27a273b368014c707676c48e36a..299a3abaedcf3618f5e28a9636d427961a97b931 100644 --- a/lite/operators/read_from_array_op.h +++ b/lite/operators/read_from_array_op.h @@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_max_op.cc b/lite/operators/reduce_max_op.cc index d7d90ee1f454556baee1a87cfd0023f8cf8c119d..ba48acd11f3517f33b020ede92e07cfadc5d497b 100644 --- a/lite/operators/reduce_max_op.cc +++ b/lite/operators/reduce_max_op.cc @@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const { return true; } -bool ReduceMaxOp::InferShape() const { +bool ReduceMaxOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool reduce_all = false; diff --git a/lite/operators/reduce_max_op.h b/lite/operators/reduce_max_op.h index 60e263f1b9b72a31c223cc60f89a7ddf81949e8c..54b136a7576fb2bb078c5bcae727b15d319bdf8e 100644 --- a/lite/operators/reduce_max_op.h +++ b/lite/operators/reduce_max_op.h @@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite { ReduceMaxOp() {} explicit ReduceMaxOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_mean_op.cc b/lite/operators/reduce_mean_op.cc index bce31c315c22e93d7758a05ecf2ace0668dd0cc1..c5baca5e87068d267ada21854b7769bf2bc19461 100644 --- a/lite/operators/reduce_mean_op.cc +++ b/lite/operators/reduce_mean_op.cc @@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const { return true; 
} -bool ReduceMeanOp::InferShape() const { +bool ReduceMeanOp::InferShapeImpl() const { auto dims = param_.dim; auto x_dims = param_.X->dims(); bool reduce_all = false; diff --git a/lite/operators/reduce_mean_op.h b/lite/operators/reduce_mean_op.h index e701a1132aa1260b5f169f89dec546a0d80fc916..43fe955690b3e4569f75c88a4d7b9ba9e961fcca 100644 --- a/lite/operators/reduce_mean_op.h +++ b/lite/operators/reduce_mean_op.h @@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite { ReduceMeanOp() {} explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc index e2cc56b416dd166e6b22a0c642907844ab964cc5..1af6daf8c73e8e41f69be8f8af8f485ac767d702 100644 --- a/lite/operators/reduce_ops.cc +++ b/lite/operators/reduce_ops.cc @@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const { return true; } -bool ReduceOp::InferShape() const { +bool ReduceOp::InferShapeImpl() const { const auto &x_dims = param_.x->dims(); auto x_rank = x_dims.size(); auto dims = param_.dim; diff --git a/lite/operators/reduce_ops.h b/lite/operators/reduce_ops.h index 0063aba1fa606c6228e7dcb1197bfb36f57aa33c..d4fdbd113586a57b0d5a1e6e5fbde6707efb7cc1 100644 --- a/lite/operators/reduce_ops.h +++ b/lite/operators/reduce_ops.h @@ -30,7 +30,7 @@ class ReduceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reduce_prod_op.cc b/lite/operators/reduce_prod_op.cc index 90da13c8643fa030c376ca25cb3a67b70f3485a4..5a6194b36b9c0b4a95fb47049999da093f979e3b 100644 --- a/lite/operators/reduce_prod_op.cc +++ 
b/lite/operators/reduce_prod_op.cc @@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const { return true; } -bool ReduceProdOpLite::InferShape() const { +bool ReduceProdOpLite::InferShapeImpl() const { auto x = param_.x; auto out = param_.output; std::vector dim = param_.dim; diff --git a/lite/operators/reduce_prod_op.h b/lite/operators/reduce_prod_op.h index 5f7a6dcdf98eb99d9145b7e3108972f4debeaeb5..d8bb1400b9aecf449499d4c6920c2ef88eb119b2 100644 --- a/lite/operators/reduce_prod_op.h +++ b/lite/operators/reduce_prod_op.h @@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/relu_op.cc b/lite/operators/relu_op.cc index 9fa3ac8f30784b8349788dfd4eaf39252db1a156..e5f51676c69bcde6b68a9e9d17f936874a5ea86f 100644 --- a/lite/operators/relu_op.cc +++ b/lite/operators/relu_op.cc @@ -20,7 +20,7 @@ namespace lite { namespace operators { bool ReluOp::CheckShape() const { return true; } -bool ReluOp::InferShape() const { +bool ReluOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. 
diff --git a/lite/operators/relu_op.h b/lite/operators/relu_op.h index 23ca7ff16b48de747069f006cddbb9504e6942e3..7577f2ffbab62298138b22970c00caf9ab01367f 100644 --- a/lite/operators/relu_op.h +++ b/lite/operators/relu_op.h @@ -30,7 +30,7 @@ class ReluOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc index 655ac58bdcbfc0f8d9cdbb0ef0078db5eb0333fa..5c55eb4aa516ae3aecf49250f42d38491c1270f1 100644 --- a/lite/operators/reshape_op.cc +++ b/lite/operators/reshape_op.cc @@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const { return true; } -bool ReshapeOp::InferShape() const { +bool ReshapeOp::InferShapeImpl() const { const auto &shape_tensor_vct = param_.shape_tensor_vct; auto *shape_tensor = param_.shape_tensor; const auto &shape_vct = param_.shape_vct; @@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const { return true; } -bool Reshape2Op::InferShape() const { - ReshapeOp::InferShape(); +bool Reshape2Op::InferShapeImpl() const { + ReshapeOp::InferShapeImpl(); const auto &x_dims = param_.x->dims(); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; diff --git a/lite/operators/reshape_op.h b/lite/operators/reshape_op.h index 1df49fb5f44c88978b78f17885a5ba4412aa9ab7..9dc302ec9706512b16cd9e7db38b944d2d1324f5 100644 --- a/lite/operators/reshape_op.h +++ b/lite/operators/reshape_op.h @@ -30,7 +30,7 @@ class ReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) 
override; diff --git a/lite/operators/roi_align_op.cc b/lite/operators/roi_align_op.cc index 2f65c0197ecf1324678c63b6bd16018f83389702..001934dcf8f77527666c1b5cc0a01afcade2af81 100644 --- a/lite/operators/roi_align_op.cc +++ b/lite/operators/roi_align_op.cc @@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const { return true; } -bool RoiAlignOpLite::InferShape() const { +bool RoiAlignOpLite::InferShapeImpl() const { auto x_dims = param_.X->dims(); auto rois_dims = param_.ROIs->dims(); diff --git a/lite/operators/roi_align_op.h b/lite/operators/roi_align_op.h index f3dd1a47f5e2d0dbb39439c9789573b9b7a33728..65cc72534a2e2b63a1e024a55c766f2c1983f5ab 100644 --- a/lite/operators/roi_align_op.h +++ b/lite/operators/roi_align_op.h @@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/scale_op.cc b/lite/operators/scale_op.cc index 1398ea481194cae545fc8f1fa803eff5f5b78a31..3236277187462dd1185e698e5cb8fe919fe20b97 100644 --- a/lite/operators/scale_op.cc +++ b/lite/operators/scale_op.cc @@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const { return true; } -bool ScaleOp::InferShape() const { +bool ScaleOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); return true; } diff --git a/lite/operators/scale_op.h b/lite/operators/scale_op.h index 684da4ed47370090c5cb690ea728fa4f9147c4bf..38970bfcfd82eebce51612e6afb531cbf3b10966 100644 --- a/lite/operators/scale_op.h +++ b/lite/operators/scale_op.h @@ -30,7 +30,7 @@ class ScaleOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_aligned_mat_mul_op.cc b/lite/operators/search_aligned_mat_mul_op.cc index 
43a276e3c7a2f7481ade2ee18c1446593f7c5f43..65ccbc2b793cb3a64c16a5b3bf7d869d8e271327 100644 --- a/lite/operators/search_aligned_mat_mul_op.cc +++ b/lite/operators/search_aligned_mat_mul_op.cc @@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const { return true; } -bool SearchAlignedMatMulOpLite::InferShape() const { +bool SearchAlignedMatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); const auto& x_lod = param_.X->lod(); diff --git a/lite/operators/search_aligned_mat_mul_op.h b/lite/operators/search_aligned_mat_mul_op.h index 7321b7e9d15331e6aad36364436a99d3d4089c8c..8242e06d0170a8a4c178f0e460c64f93b0c2bc3c 100644 --- a/lite/operators/search_aligned_mat_mul_op.h +++ b/lite/operators/search_aligned_mat_mul_op.h @@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_fc_op.cc b/lite/operators/search_fc_op.cc index 2e77e361624e681aa93e36610674df0e1f9a13af..3c64f24e48f750b367b75431333401329721a9b9 100644 --- a/lite/operators/search_fc_op.cc +++ b/lite/operators/search_fc_op.cc @@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const { return true; } -bool SearchFcOpLite::InferShape() const { +bool SearchFcOpLite::InferShapeImpl() const { auto out_size = param_.out_size; lite::DDim dims(std::vector({-1, out_size})); param_.Out->Resize(dims); diff --git a/lite/operators/search_fc_op.h b/lite/operators/search_fc_op.h index a871cadd33b4f7d4b6130a0b8ac2974a738ac0c3..235c24c57ff0e925d763fa11a78f56cfe72613cd 100644 --- a/lite/operators/search_fc_op.h +++ b/lite/operators/search_fc_op.h @@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const 
cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_grnn_op.cc b/lite/operators/search_grnn_op.cc index b56ae820bf9de4ffe6aa3f6db7a8e1385c8cc11f..1ced477c109d8cd93485f0193523887759939f17 100644 --- a/lite/operators/search_grnn_op.cc +++ b/lite/operators/search_grnn_op.cc @@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const { return true; } -bool SearchGrnnOpLite::InferShape() const { +bool SearchGrnnOpLite::InferShapeImpl() const { const auto& x_dims = param_.x->dims(); const auto& x_lod = param_.x->lod(); CHECK_OR_FALSE(!x_lod.empty()); diff --git a/lite/operators/search_grnn_op.h b/lite/operators/search_grnn_op.h index 670af8a6c9ff9eafa33018a0303ea1a36b0a1e01..de4b1d8a5c4d551970fcbb7b0c17de67214b5c9a 100644 --- a/lite/operators/search_grnn_op.h +++ b/lite/operators/search_grnn_op.h @@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_group_padding_op.cc b/lite/operators/search_group_padding_op.cc index 5ba4dde275f4b9662416bdf5190cacfafc56a40d..b97c710109ea9eb1ae3b1e50e3bdab3e1e97ac3e 100644 --- a/lite/operators/search_group_padding_op.cc +++ b/lite/operators/search_group_padding_op.cc @@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const { return true; } -bool SearchGroupPaddingOp::InferShape() const { +bool SearchGroupPaddingOp::InferShapeImpl() const { std::vector x_dims = param_.x->dims().Vectorize(); param_.out_emb_padding->Resize({-1, x_dims[1]}); diff --git a/lite/operators/search_group_padding_op.h b/lite/operators/search_group_padding_op.h index a8e96c9697b5f7de70349efa1f8b378a47c3823c..6a93c7410128aa86b034308562b8c3ccd4ca78df 100644 --- a/lite/operators/search_group_padding_op.h +++ b/lite/operators/search_group_padding_op.h @@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite 
{ SearchGroupPaddingOp() {} explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "search_group_padding"; } diff --git a/lite/operators/search_seq_depadding_op.cc b/lite/operators/search_seq_depadding_op.cc index 12d5123e05b41665550fb7e6b90a636093959263..6ad4f1ab171486468bf34b8341344410ed99f59b 100644 --- a/lite/operators/search_seq_depadding_op.cc +++ b/lite/operators/search_seq_depadding_op.cc @@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const { return true; } -bool SearchSeqDepaddingOpLite::InferShape() const { +bool SearchSeqDepaddingOpLite::InferShapeImpl() const { DDim pad_dims = param_.pad->dims(); param_.out->Resize({-1, pad_dims[1]}); return true; diff --git a/lite/operators/search_seq_depadding_op.h b/lite/operators/search_seq_depadding_op.h index 445d9e0f3bcba6204243e80023d826bf53d90c60..aa1cc22d4b048ca81445e735e09226b7dfe2fd03 100644 --- a/lite/operators/search_seq_depadding_op.h +++ b/lite/operators/search_seq_depadding_op.h @@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/search_seq_fc_op.cc b/lite/operators/search_seq_fc_op.cc index c5cca5331ab80479656b1212df02c20d463a3707..2a4525ac6e6f7e0cdd62a0a653e7188b274545af 100644 --- a/lite/operators/search_seq_fc_op.cc +++ b/lite/operators/search_seq_fc_op.cc @@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const { return true; } -bool SearchSeqFcOpLite::InferShape() const { +bool SearchSeqFcOpLite::InferShapeImpl() const { const 
auto x_dims = param_.x->dims(); const auto w_dims = param_.w->dims(); const auto& x_lod = param_.x->lod(); diff --git a/lite/operators/search_seq_fc_op.h b/lite/operators/search_seq_fc_op.h index 3c4f7d82bfa66c2f323063f0297438c81ce18397..bacafcfe6ffa2a2c518cf3b8f226fa29c9b95e95 100644 --- a/lite/operators/search_seq_fc_op.h +++ b/lite/operators/search_seq_fc_op.h @@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/search_seq_softmax_op.cc b/lite/operators/search_seq_softmax_op.cc index 973ffa04c4562334af6d379b5446902036de8c5e..9b0550341c50df9cd48fa922139fc759c5289e97 100644 --- a/lite/operators/search_seq_softmax_op.cc +++ b/lite/operators/search_seq_softmax_op.cc @@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const { return true; } -bool SearchSeqSoftmaxOp::InferShape() const { +bool SearchSeqSoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); param_.output->set_lod(param_.x->lod()); return true; diff --git a/lite/operators/search_seq_softmax_op.h b/lite/operators/search_seq_softmax_op.h index f97e8ddd3a6c446fb5c53d5e603f43bbdf1e2525..dca3619eab9013f22d962b16c577c73862ee5e64 100644 --- a/lite/operators/search_seq_softmax_op.h +++ b/lite/operators/search_seq_softmax_op.h @@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_arithmetic_op.cc b/lite/operators/sequence_arithmetic_op.cc index 29c39ebc23f54c2c3c052e322575d97570195cfc..e17a179a860e13622979e5b42b07ae3459876fc7 100644 --- a/lite/operators/sequence_arithmetic_op.cc +++ b/lite/operators/sequence_arithmetic_op.cc @@ -28,7 +28,7 @@ 
bool SequenceArithmeticOp::CheckShape() const { return true; } -bool SequenceArithmeticOp::InferShape() const { +bool SequenceArithmeticOp::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); param_.Out->set_lod(param_.X->lod()); return true; diff --git a/lite/operators/sequence_arithmetic_op.h b/lite/operators/sequence_arithmetic_op.h index 9f844dfbf429599d829bc786c66ba6d05e40d79d..cf9ef1583aeaed977c515441ca629b2e66efb3d2 100644 --- a/lite/operators/sequence_arithmetic_op.h +++ b/lite/operators/sequence_arithmetic_op.h @@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_concat_op.cc b/lite/operators/sequence_concat_op.cc index 88afe5e00fe2bfc173a8a1d1d0e63562cfb52518..91c70c0d2ff2d506d29dbeb01780de962f9a27f1 100644 --- a/lite/operators/sequence_concat_op.cc +++ b/lite/operators/sequence_concat_op.cc @@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const { return true; } -bool SequenceConcatOp::InferShape() const { return true; } +bool SequenceConcatOp::InferShapeImpl() const { return true; } bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { diff --git a/lite/operators/sequence_concat_op.h b/lite/operators/sequence_concat_op.h index 8cdc07ebca83b9c400b00a0f40556a788c5854e6..c7d61db7852fb8894c5c4ed7c3d4283480c90e48 100644 --- a/lite/operators/sequence_concat_op.h +++ b/lite/operators/sequence_concat_op.h @@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite { SequenceConcatOp() {} explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { 
kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_concat"; } diff --git a/lite/operators/sequence_conv_op.cc b/lite/operators/sequence_conv_op.cc index 89596a22c616b45d0e72cc14501e4f6c148ad86c..681e05c9b69953c4dde6c873e66bee2e93839aaf 100644 --- a/lite/operators/sequence_conv_op.cc +++ b/lite/operators/sequence_conv_op.cc @@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const { return true; } -bool SequenceConvOp::InferShape() const { +bool SequenceConvOp::InferShapeImpl() const { const auto *input = param_.X; const auto *filter = param_.Filter; auto in_dims = input->dims(); diff --git a/lite/operators/sequence_conv_op.h b/lite/operators/sequence_conv_op.h index 34d65d3cc9324aea7b50a1d939a594b817889896..3ec7ac4d3da7822335e047ca1c681809914c192b 100644 --- a/lite/operators/sequence_conv_op.h +++ b/lite/operators/sequence_conv_op.h @@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite { SequenceConvOp() {} explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_expand_as_op.cc b/lite/operators/sequence_expand_as_op.cc index 22a4743103fd4b188357d067a062ea827de7aaa0..02c787b5a51749851de1484101a6339142bc9726 100644 --- a/lite/operators/sequence_expand_as_op.cc +++ b/lite/operators/sequence_expand_as_op.cc @@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const { return true; } -bool SequenceExpandAsOpLite::InferShape() const { +bool SequenceExpandAsOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); auto y_lod = param_.y->lod(); auto out_dims = x_dims; diff --git a/lite/operators/sequence_expand_as_op.h b/lite/operators/sequence_expand_as_op.h index 
2eae8a26da31eb2937ab88f15d70bd44515e6a5f..19d6905c1a428ce4ac8b2cdb545f194bf47ee62d 100644 --- a/lite/operators/sequence_expand_as_op.h +++ b/lite/operators/sequence_expand_as_op.h @@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_expand_op.cc b/lite/operators/sequence_expand_op.cc index 0a5427a62ffca44070c9551a4f1c869ae184f0be..4bb3c66b26673a27a961729d6fe22d54ef9298fe 100644 --- a/lite/operators/sequence_expand_op.cc +++ b/lite/operators/sequence_expand_op.cc @@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const { return true; } -bool SequenceExpandOp::InferShape() const { +bool SequenceExpandOp::InferShapeImpl() const { const auto x_lod = param_.X->lod(); auto x_dims = param_.X->dims(); int ref_level = param_.ref_level; diff --git a/lite/operators/sequence_expand_op.h b/lite/operators/sequence_expand_op.h index da4b2fe71edb7f731bf53872960612e16efbef93..fffe2110d871941522e5924943be764e3ee51db5 100644 --- a/lite/operators/sequence_expand_op.h +++ b/lite/operators/sequence_expand_op.h @@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_pool_concat_op.cc b/lite/operators/sequence_pool_concat_op.cc index 9ee0d4d5967e0d36bb893b42033f2c5319c940bb..ce490e8246c621cb23b3a3eecc0e8ddc4bca28b1 100644 --- a/lite/operators/sequence_pool_concat_op.cc +++ b/lite/operators/sequence_pool_concat_op.cc @@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const { return true; } -bool SequencePoolConcatOp::InferShape() const { +bool SequencePoolConcatOp::InferShapeImpl() const { int out_dim = 0; for (int i = 0; i < 
param_.X.size(); ++i) { out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size()); diff --git a/lite/operators/sequence_pool_concat_op.h b/lite/operators/sequence_pool_concat_op.h index 7a70ceaf298ebd7d02c319b08a86f40dc36cb648..58e6fc18ba49f6885e1f4ffb86cba47ca86f9623 100644 --- a/lite/operators/sequence_pool_concat_op.h +++ b/lite/operators/sequence_pool_concat_op.h @@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite { SequencePoolConcatOp() {} explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/sequence_pool_op.cc b/lite/operators/sequence_pool_op.cc index be3726ffe7a73c50f92bec2f2a96fb1625e31a9e..6b4f7d8b789f11c815b86f7dcc990e6db7855bbd 100644 --- a/lite/operators/sequence_pool_op.cc +++ b/lite/operators/sequence_pool_op.cc @@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const { return true; } -bool SequencePoolOp::InferShape() const { +bool SequencePoolOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); out_dims[0] = input->lod()[0].size() - 1; diff --git a/lite/operators/sequence_pool_op.h b/lite/operators/sequence_pool_op.h index 215dd113a3e5d9cdb1707a9b1b70c5712a43ec5d..7b9e36bb5e6e5f47cf49b1bd0df62795b7d57b7e 100644 --- a/lite/operators/sequence_pool_op.h +++ b/lite/operators/sequence_pool_op.h @@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite { SequencePoolOp() {} explicit SequencePoolOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { 
kernel->SetParam(param_); } diff --git a/lite/operators/sequence_reshape_op.cc b/lite/operators/sequence_reshape_op.cc index c7e86af65033205bcb389cecff8db14721507142..37ebd8a2bae3919062bc0e71e3a10193850e7877 100644 --- a/lite/operators/sequence_reshape_op.cc +++ b/lite/operators/sequence_reshape_op.cc @@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const { return true; } -bool SequenceReshapeOp::InferShape() const { +bool SequenceReshapeOp::InferShapeImpl() const { int new_dim = param_.new_dim; auto x_numel = param_.x->dims().production(); std::vector out_shape{x_numel / new_dim, diff --git a/lite/operators/sequence_reshape_op.h b/lite/operators/sequence_reshape_op.h index c8378aebc44acf22017eee17f5b58d6ff4dd65bf..4ef395bdaa762d178e925f088c5c2becd357f669 100644 --- a/lite/operators/sequence_reshape_op.h +++ b/lite/operators/sequence_reshape_op.h @@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_reverse_op.cc b/lite/operators/sequence_reverse_op.cc index dd8fa2e8fd5816cc92355c9c73caf1aa76baf36c..19a47cac9da666269fc5ef2a172ff0295b71e95d 100644 --- a/lite/operators/sequence_reverse_op.cc +++ b/lite/operators/sequence_reverse_op.cc @@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const { return true; } -bool SequenceReverseOp::InferShape() const { +bool SequenceReverseOp::InferShapeImpl() const { const auto *input = param_.X; auto out_dims = input->dims(); param_.Out->Resize(out_dims); diff --git a/lite/operators/sequence_reverse_op.h b/lite/operators/sequence_reverse_op.h index 326d0f68927199e9353a5bbe8c072d342c9e3d69..68d9fdb0f16cf0b2e13b7ed7417572a7b971e785 100644 --- a/lite/operators/sequence_reverse_op.h +++ b/lite/operators/sequence_reverse_op.h @@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite { SequenceReverseOp() 
{} explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "sequence_reverse"; } diff --git a/lite/operators/sequence_softmax_op.cc b/lite/operators/sequence_softmax_op.cc index d106097ed5c2e3a712bbd87904164ccd612d1f9e..eb1821129d8b036a252fb36ab69094c8a58cce95 100644 --- a/lite/operators/sequence_softmax_op.cc +++ b/lite/operators/sequence_softmax_op.cc @@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const { CHECK_OR_FALSE(param_.Out); return true; } -bool SequenceSoftmaxOp::InferShape() const { +bool SequenceSoftmaxOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); diff --git a/lite/operators/sequence_softmax_op.h b/lite/operators/sequence_softmax_op.h index 37dfc0d444be5c608c87c2418041237d4ac4643c..5942cb0441d7af7237c7761fe4ccd5d613321c87 100644 --- a/lite/operators/sequence_softmax_op.h +++ b/lite/operators/sequence_softmax_op.h @@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sequence_topk_avg_pooling_op.cc b/lite/operators/sequence_topk_avg_pooling_op.cc index 6f5cbeeeee5816132d2ebcb7094949189931b931..cb6f12c4b33bfc04beae2574ca384fcd77ac5004 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.cc +++ b/lite/operators/sequence_topk_avg_pooling_op.cc @@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const { return true; } -bool SequenceTopkAvgPoolingOpLite::InferShape() const { +bool 
SequenceTopkAvgPoolingOpLite::InferShapeImpl() const { int channel_num = param_.channel_num; std::vector topks = param_.topks; auto row_dim = param_.ROW->dims(); diff --git a/lite/operators/sequence_topk_avg_pooling_op.h b/lite/operators/sequence_topk_avg_pooling_op.h index 1c1cfe3a9c7bc82c3e79fc372b98293183509dca..a619edc908a5e4d4a8db97a931acb2ce24e39008 100644 --- a/lite/operators/sequence_topk_avg_pooling_op.h +++ b/lite/operators/sequence_topk_avg_pooling_op.h @@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/sgd_op.cc b/lite/operators/sgd_op.cc index 621454259548d27f9dad23f01e1e392b007bcb5b..eb8cb6b72473310ca1df12e8510d74cc3d76f4aa 100644 --- a/lite/operators/sgd_op.cc +++ b/lite/operators/sgd_op.cc @@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const { return true; } -bool SGDOpLite::InferShape() const { +bool SGDOpLite::InferShapeImpl() const { param_.ParamOut->Resize(param_.Param->dims()); return true; } diff --git a/lite/operators/sgd_op.h b/lite/operators/sgd_op.h index 9159bf95a6a50b5cd7b5d0ffed15e06f8d0e11c5..6a29c8bfa61b455e2257600975e851860e8797cc 100644 --- a/lite/operators/sgd_op.h +++ b/lite/operators/sgd_op.h @@ -33,7 +33,7 @@ class SGDOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/shape_op.cc b/lite/operators/shape_op.cc index c6d5dc4d01a93dd4cc648358db0b6f462a116eb0..1661a909268eb15ea2c4b393e9a2831d438465c7 100644 --- a/lite/operators/shape_op.cc +++ b/lite/operators/shape_op.cc @@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const { return true; } -bool ShapeOpLite::InferShape() const { +bool ShapeOpLite::InferShapeImpl() 
const { std::vector shape_vec; shape_vec.push_back(static_cast(param_.X->dims().size())); param_.Out->Resize(shape_vec); diff --git a/lite/operators/shape_op.h b/lite/operators/shape_op.h index ada9961c75b1cbc6c91d94a4ed3473ca12d8dcd6..6512b8ac0213519b068a10a74fdcb9d715d73255 100644 --- a/lite/operators/shape_op.h +++ b/lite/operators/shape_op.h @@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/shuffle_channel_op.cc b/lite/operators/shuffle_channel_op.cc index 926aa932f3d278945b659b6113df6479c7515e20..d45643a3d82d9177f7719908ea572258e0029bef 100644 --- a/lite/operators/shuffle_channel_op.cc +++ b/lite/operators/shuffle_channel_op.cc @@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const { return true; } -bool ShuffleChannelOpLite::InferShape() const { +bool ShuffleChannelOpLite::InferShapeImpl() const { param_.Out->Resize(param_.X->dims()); return true; } diff --git a/lite/operators/shuffle_channel_op.h b/lite/operators/shuffle_channel_op.h index c48a47f61902087cecf874ee7ddee8313a3cf92a..768345898141dd869c6a59f69170559d68a9f498 100644 --- a/lite/operators/shuffle_channel_op.h +++ b/lite/operators/shuffle_channel_op.h @@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/slice_op.cc b/lite/operators/slice_op.cc index bbc3d1429e202dac7b9a53c00d83ee34de7ef3d1..cf7d94535cce5fa32d0f917c9d39e4746cee1c30 100644 --- a/lite/operators/slice_op.cc +++ b/lite/operators/slice_op.cc @@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const { return true; } -bool SliceOp::InferShape() const { +bool SliceOp::InferShapeImpl() const { 
CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto in_dims = param_.X->dims(); diff --git a/lite/operators/slice_op.h b/lite/operators/slice_op.h index 936a1405f46ffd9e3375da1cd57b0570b07fcbbf..ec69f23d8ded4a7435bec0a2bd1f838603c7a7be 100644 --- a/lite/operators/slice_op.h +++ b/lite/operators/slice_op.h @@ -30,7 +30,7 @@ class SliceOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index 0989c9139763a435d67deb21a2ab233e1c2f3bd9..000953007c27e37bc05d85d810880f6ccd7728ce 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const { return true; } -bool SoftmaxOp::SmartInferShape() { - if (!last_input_shapes.empty() && !last_output_shapes.empty()) { - if (param_.x->dims() == last_input_shapes[0] && - param_.x->lod() == last_input_lods[0]) { - param_.output->Resize(last_output_shapes[0]); - param_.output->set_lod(last_output_lods[0]); - return true; - } - } - - this->InferShape(); - - if (!last_input_shapes.empty()) { - last_input_shapes.clear(); - last_input_lods.clear(); - } - last_input_shapes.push_back(param_.x->dims()); - last_input_lods.push_back(param_.x->lod()); - - if (!last_output_shapes.empty()) { - last_output_shapes.clear(); - last_output_lods.clear(); - } - last_output_shapes.push_back(param_.output->dims()); - last_output_lods.push_back(param_.output->lod()); - return true; -} - -bool SoftmaxOp::InferShape() const { +bool SoftmaxOp::InferShapeImpl() const { param_.output->Resize(param_.x->dims()); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x->lod(); diff --git a/lite/operators/softmax_op.h b/lite/operators/softmax_op.h index c65d039fda02c5396eff829bede3b4ffdeac0051..20dc2f461e4f83e0b363d44e07c4204c656f2cf3 100644 --- 
a/lite/operators/softmax_op.h +++ b/lite/operators/softmax_op.h @@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; - bool SmartInferShape() override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_lod_tensor_op.cc b/lite/operators/split_lod_tensor_op.cc index 9b665b6026a44caa31b89ec7806188f90f5f1595..2900c8165dba3b8f0b83ef288c89ed0e56b4820d 100644 --- a/lite/operators/split_lod_tensor_op.cc +++ b/lite/operators/split_lod_tensor_op.cc @@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const { return true; } -bool SplitLodTensorOpLite::InferShape() const { +bool SplitLodTensorOpLite::InferShapeImpl() const { auto x_dims = param_.x->dims(); param_.out_true->Resize(x_dims); param_.out_false->Resize(x_dims); diff --git a/lite/operators/split_lod_tensor_op.h b/lite/operators/split_lod_tensor_op.h index c7feef4f85df652d0c24f830076a078e20c111f9..fb7f85de5cae69d3c0844ee0eeabe98d45acde4a 100644 --- a/lite/operators/split_lod_tensor_op.h +++ b/lite/operators/split_lod_tensor_op.h @@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 834d68a3156700605e621a1ba71faec33fb7b745..71deb5631dd3523ebb0367b7db5e4049b785be7b 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const { return true; } -bool SplitOp::InferShape() const { +bool SplitOp::InferShapeImpl() const { const auto &outs = param_.output; auto in_dims = param_.x->dims(); int axis = param_.axis; diff --git a/lite/operators/split_op.h b/lite/operators/split_op.h index 
66190742155a8268e510d5a8da47ab958a043418..3bb40a8d35e25145057d8c5790b25028ea571cd5 100644 --- a/lite/operators/split_op.h +++ b/lite/operators/split_op.h @@ -30,7 +30,7 @@ class SplitOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index 01f96c28ff6be38e426030aa3c580f28f73b3a38..633a6b4d4e45fd30bd72c8dcdfbbd96b8a8e8ebe 100644 --- a/lite/operators/squeeze_op.cc +++ b/lite/operators/squeeze_op.cc @@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const { return true; } -bool SqueezeOp::InferShape() const { +bool SqueezeOp::InferShapeImpl() const { std::vector squeeze_dims = param_.axes; DDim in_dims = param_.X->dims(); DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true); @@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const { return true; } -bool Squeeze2Op::InferShape() const { - SqueezeOp::InferShape(); +bool Squeeze2Op::InferShapeImpl() const { + SqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/squeeze_op.h b/lite/operators/squeeze_op.h index 1a550c5fbee59d43170b5ffa16caa81521c14d87..983e17acf6483da9e3e33c83b48e6e61455a4914 100644 --- a/lite/operators/squeeze_op.h +++ b/lite/operators/squeeze_op.h @@ -30,7 +30,7 @@ class SqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/stack_op.cc 
b/lite/operators/stack_op.cc index 8fdf61e8224aa06792bdbb3f41a4f1701039d8dd..0f9ba6662b16ce20acad497a4915cfc848b319cd 100644 --- a/lite/operators/stack_op.cc +++ b/lite/operators/stack_op.cc @@ -32,7 +32,7 @@ bool StackOp::CheckShape() const { return true; } -bool StackOp::InferShape() const { +bool StackOp::InferShapeImpl() const { auto input = param_.X; auto input_dims = input[0]->dims(); int axis = param_.axis; diff --git a/lite/operators/stack_op.h b/lite/operators/stack_op.h index 068d905338bde892b44630c64d3ec43771614f2a..9ce73057a313fd4b4f96914b3e962120de11ac43 100644 --- a/lite/operators/stack_op.h +++ b/lite/operators/stack_op.h @@ -31,7 +31,7 @@ class StackOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/subgraph_op.cc b/lite/operators/subgraph_op.cc index 58388669afa060d48ea4c3d674dff94c386f104a..9ac07e96334eda9f0001d33e0789f9de15c4ca67 100644 --- a/lite/operators/subgraph_op.cc +++ b/lite/operators/subgraph_op.cc @@ -22,7 +22,7 @@ namespace operators { bool SubgraphOp::CheckShape() const { return true; } -bool SubgraphOp::InferShape() const { return CheckShape(); /* enrich me */ } +bool SubgraphOp::InferShapeImpl() const { return CheckShape(); /* enrich me */ } bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { param_.input_names = op_desc.Input("Inputs"); diff --git a/lite/operators/subgraph_op.h b/lite/operators/subgraph_op.h index 7f593159c8651cc18fbea17e559f62297d5022e9..edbfb922044d60165e589d389cd8cfb3b2547796 100644 --- a/lite/operators/subgraph_op.h +++ b/lite/operators/subgraph_op.h @@ -35,7 +35,7 @@ class SubgraphOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; diff --git 
a/lite/operators/topk_op.cc b/lite/operators/topk_op.cc index fbfb825544870dfaf3e18d1595f2824970b7352b..4a68cbb4745473b21cc7b6c5f6c8fcef6e186e57 100644 --- a/lite/operators/topk_op.cc +++ b/lite/operators/topk_op.cc @@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const { return true; } -bool TopkOp::InferShape() const { +bool TopkOp::InferShapeImpl() const { auto out_dims = param_.X->dims(); out_dims[out_dims.size() - 1] = param_.K; auto out = param_.Out; diff --git a/lite/operators/topk_op.h b/lite/operators/topk_op.h index 037fa413ea5ce6fcb5eb04502cf232cea7e109e0..d5888e5f1800ba37f4bed61c146b6af75e3f91fc 100644 --- a/lite/operators/topk_op.h +++ b/lite/operators/topk_op.h @@ -30,7 +30,7 @@ class TopkOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index 71086b492b538e293a1f08ed7f492a46d6eb02f8..40780346d038c875a2eb96b11aff9d1c2a578a2f 100644 --- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const { return true; } -bool TransposeOp::InferShape() const { +bool TransposeOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); @@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const { return true; } -bool Transpose2Op::InferShape() const { +bool Transpose2Op::InferShapeImpl() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); auto x_dims = param_.x->dims(); diff --git a/lite/operators/transpose_op.h b/lite/operators/transpose_op.h index ce352a7d82f4a9dd3899f21c252c003c1924dda6..39b75b96d858bb80a51e428b8d7f402258dd9cc1 100644 --- a/lite/operators/transpose_op.h +++ b/lite/operators/transpose_op.h @@ -31,7 +31,7 @@ class TransposeOp : public OpLite { bool CheckShape() const override; - bool 
InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -50,7 +50,7 @@ class Transpose2Op : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/uniform_random_op.cc b/lite/operators/uniform_random_op.cc index 93e74e2b0172e8c3948925f3334b011f37bc097e..512648bfe4acf245286c9be21223520789134897 100644 --- a/lite/operators/uniform_random_op.cc +++ b/lite/operators/uniform_random_op.cc @@ -22,7 +22,7 @@ namespace operators { bool UniformRandomOpLite::CheckShape() const { return true; } -bool UniformRandomOpLite::InferShape() const { +bool UniformRandomOpLite::InferShapeImpl() const { param_.Out->Resize(param_.shape); return true; } diff --git a/lite/operators/uniform_random_op.h b/lite/operators/uniform_random_op.h index f7dde8882f47fc533e0d47dac99acdb431509341..a7890ea3e74afb3fd67f7ba4d1f02861a7e4ae48 100644 --- a/lite/operators/uniform_random_op.h +++ b/lite/operators/uniform_random_op.h @@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 39b275b7b55f79f2c8daf16ab0a6acd2e76e8b48..b5ae90248abb4f2496a4dbca1c12317cf3a7d325 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const { return true; } -bool UnsqueezeOp::InferShape() const { +bool UnsqueezeOp::InferShapeImpl() const { std::vector final_axes; auto axes = param_.axes; auto *axes_tensor = param_.axes_tensor; @@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const { return true; } -bool 
Unsqueeze2Op::InferShape() const { - UnsqueezeOp::InferShape(); +bool Unsqueeze2Op::InferShapeImpl() const { + UnsqueezeOp::InferShapeImpl(); auto x_dims = param_.X->dims(); std::vector xshape_dims(x_dims.size() + 1, 1); for (size_t i = 0; i < x_dims.size(); i++) { diff --git a/lite/operators/unsqueeze_op.h b/lite/operators/unsqueeze_op.h index 1e88828c6c5fdef767850909c0dae8ec65e9d1e0..5139b69c63699f041973c3cf31b38d6c7e9fa847 100644 --- a/lite/operators/unsqueeze_op.h +++ b/lite/operators/unsqueeze_op.h @@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; @@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc index 51f43c709990d7ac1e664336e252ed684479b783..8cf11f6465d73646ec9bf846cbe6347bdc4b9f5b 100644 --- a/lite/operators/var_conv_2d_op.cc +++ b/lite/operators/var_conv_2d_op.cc @@ -21,7 +21,7 @@ namespace operators { bool VarConv2dOp::CheckShape() const { return true; } -bool VarConv2dOp::InferShape() const { return true; } +bool VarConv2dOp::InferShapeImpl() const { return true; } bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.X = const_cast( diff --git a/lite/operators/var_conv_2d_op.h b/lite/operators/var_conv_2d_op.h index ce6309419cc582c2f93250dd6e8e59c04a951f91..5fa492d28ec858426bea7d3d06598813d94dbbb8 100644 --- a/lite/operators/var_conv_2d_op.h +++ b/lite/operators/var_conv_2d_op.h @@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite { VarConv2dOp() {} explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {} bool CheckShape() const override; - bool InferShape() 
const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "var_conv_2d"; } diff --git a/lite/operators/while_op.cc b/lite/operators/while_op.cc index dba266af770183698680a49cb7ba4fe5dda2f5b2..1dcf9553f331ee6646ad6d93de048728a0886116 100644 --- a/lite/operators/while_op.cc +++ b/lite/operators/while_op.cc @@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const { return true; } -bool WhileOpLite::InferShape() const { return true; } +bool WhileOpLite::InferShapeImpl() const { return true; } bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto inputs = op_desc.Input("X"); diff --git a/lite/operators/while_op.h b/lite/operators/while_op.h index fcba722dbc182d0de617c3bf397a0266dc3d9cb2..94aec15a6d3eb60036bf9c2168fdbd855b84a396 100644 --- a/lite/operators/while_op.h +++ b/lite/operators/while_op.h @@ -30,7 +30,7 @@ class WhileOpLite : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/write_to_array_op.cc b/lite/operators/write_to_array_op.cc index bf2d9bc4b755c5800497e895f597aff22147e34f..d2cf7b4f94513d1058c3b4f4de1ec70c8c244b7e 100644 --- a/lite/operators/write_to_array_op.cc +++ b/lite/operators/write_to_array_op.cc @@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const { return true; } -bool WriteToArrayOp::InferShape() const { +bool WriteToArrayOp::InferShapeImpl() const { int id = param_.I->data()[0]; if (param_.Out->size() < id + 1) { param_.Out->resize(id + 1); diff --git a/lite/operators/write_to_array_op.h b/lite/operators/write_to_array_op.h index 8c987a24509d915d2ec59b90808993abe779623e..9460b7e364047750991d03468956462497fc4cc1 100644 --- 
a/lite/operators/write_to_array_op.h +++ b/lite/operators/write_to_array_op.h @@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/operators/yolo_box_op.cc b/lite/operators/yolo_box_op.cc index c8186d3f3182e21856919c46b83fe96a6e2bef93..0a5481a8fb01b5401734beacbc18a0bafcc48457 100644 --- a/lite/operators/yolo_box_op.cc +++ b/lite/operators/yolo_box_op.cc @@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const { return true; } -bool YoloBoxOp::InferShape() const { +bool YoloBoxOp::InferShapeImpl() const { auto* X = param_.X; auto anchors = param_.anchors; int anchor_num = anchors.size() / 2; diff --git a/lite/operators/yolo_box_op.h b/lite/operators/yolo_box_op.h index 2e2ea6d63408ca7d1a1cd7db48b82bf1ced294de..85448000f34bb1f0b768f78bb5929d1a26462043 100644 --- a/lite/operators/yolo_box_op.h +++ b/lite/operators/yolo_box_op.h @@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite { bool CheckShape() const override; - bool InferShape() const override; + bool InferShapeImpl() const override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index ffed48cdc612bd7d5c7e701b0e198390976b7bef..e108e35af76c6b5f2c5719b650b06d849a2f3887 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -32,6 +32,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_ lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} 
${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/activation_compute_test.cc b/lite/tests/kernels/activation_compute_test.cc index afbf194976c6e524c05e95f9273748ed70b96277..5a0b033b1b8c4d8f28aa05c3f2fcac40f2569bf4 100644 --- a/lite/tests/kernels/activation_compute_test.cc +++ b/lite/tests/kernels/activation_compute_test.cc @@ -425,19 +425,24 @@ TEST(Activation_swish, precision) { TEST(Activation_relu6, precision) { LOG(INFO) << "test relu6 op..."; -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); + Place place; + float abs_error = 2e-5; +#if defined(LITE_WITH_NPU) + place = TARGET(kNPU); + abs_error = 1e-2; // Using fp16 in NPU +#elif defined(LITE_WITH_ARM) + place = TARGET(kARM); +#else + return; +#endif for (auto dims : std::vector>{ {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { - for (auto slope : {0.01, 0.1}) { - std::unique_ptr tester(new ActivationComputeTester( - place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6)); - arena::Arena arena(std::move(tester), 
place, 2e-5); - arena.TestPrecision(); - } + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } -#endif } TEST(Activation_log, precision) { diff --git a/lite/tests/kernels/multiclass_nms_compute_test.cc b/lite/tests/kernels/multiclass_nms_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1190197bffdf505fec77c6b22b7871316a2d125 --- /dev/null +++ b/lite/tests/kernels/multiclass_nms_compute_test.cc @@ -0,0 +1,491 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +static void GetMaxScoreIndex(const std::vector& scores, + const T threshold, + int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), + sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { + sorted_indices->resize(top_k); + } +} + +template +static T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + T norm = normalized ? static_cast(0.) 
: static_cast(1.); + T inter_w = inter_xmax - inter_xmin + norm; + T inter_h = inter_ymax - inter_ymin + norm; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +void SliceOneClass(const Tensor& items, + const int class_id, + Tensor* one_class_item) { + T* item_data = one_class_item->mutable_data(); + const T* items_data = items.data(); + const int64_t num_item = items.dims()[0]; + const int64_t class_num = items.dims()[1]; + if (items.dims().size() == 3) { + int64_t item_size = items.dims()[2]; + for (int i = 0; i < num_item; ++i) { + std::memcpy(item_data + i * item_size, + items_data + i * class_num * item_size + class_id * item_size, + sizeof(T) * item_size); + } + } else { + for (int i = 0; i < num_item; ++i) { + item_data[i] = items_data[i * class_num + class_id]; + } + } +} + +template +void NMSFast(const Tensor& bbox, + const Tensor& scores, + const T score_threshold, + const T nms_threshold, + const T eta, + const int64_t top_k, + std::vector* selected_indices, + const bool normalized) { + // The total boxes for each instance. + int64_t num_boxes = bbox.dims()[0]; + // 4: [xmin ymin xmax ymax] + // 8: [x1 y1 x2 y2 x3 y3 x4 y4] + // 16, 24, or 32: [x1 y1 x2 y2 ... 
xn yn], n = 8, 12 or 16 + int64_t box_size = bbox.dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores.data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); + + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + const T* bbox_data = bbox.data(); + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (size_t k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = T(0.); + // 4: [xmin ymin xmax ymax] + if (box_size == 4) { + overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, + normalized); + } else { + LOG(FATAL) << "not support"; + } + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } +} + +template +void MultiClassNMS(const Tensor& scores, + const Tensor& bboxes, + const int scores_size, + std::map>* indices, + int* num_nmsed_out, + int64_t background_label, + int64_t nms_top_k, + int64_t keep_top_k, + bool normalized, + T nms_threshold, + T nms_eta, + T score_threshold) { + int num_det = 0; + + int64_t class_num = scores_size == 3 ? 
scores.dims()[0] : scores.dims()[1]; + Tensor bbox_slice, score_slice; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + if (scores_size == 3) { + score_slice = scores.Slice(c, c + 1); + bbox_slice = bboxes; + } else { + score_slice.Resize({scores.dims()[0], 1}); + bbox_slice.Resize({scores.dims()[0], 4}); + SliceOneClass(scores, c, &score_slice); + SliceOneClass(bboxes, c, &bbox_slice); + } + NMSFast(bbox_slice, + score_slice, + score_threshold, + nms_threshold, + nms_eta, + nms_top_k, + &((*indices)[c]), + normalized); + if (scores_size == 2) { + std::stable_sort((*indices)[c].begin(), (*indices)[c].end()); + } + num_det += (*indices)[c].size(); + } + + *num_nmsed_out = num_det; + const T* scores_data = scores.data(); + if (keep_top_k > -1 && num_det > keep_top_k) { + const T* sdata; + std::vector>> score_index_pairs; + for (const auto& it : *indices) { + int label = it.first; + if (scores_size == 3) { + sdata = scores_data + label * scores.dims()[1]; + } else { + score_slice.Resize({scores.dims()[0], 1}); + SliceOneClass(scores, label, &score_slice); + sdata = score_slice.data(); + } + const std::vector& label_indices = it.second; + for (size_t j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + score_index_pairs.push_back( + std::make_pair(sdata[idx], std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::stable_sort(score_index_pairs.begin(), + score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + // Store the new indices. 
+ std::map> new_indices; + for (size_t j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + if (scores_size == 2) { + for (const auto& it : new_indices) { + int label = it.first; + std::stable_sort(new_indices[label].begin(), new_indices[label].end()); + } + } + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; + } +} + +template +void MultiClassOutput(const Tensor& scores, + const Tensor& bboxes, + const std::map>& selected_indices, + const int scores_size, + Tensor* outs, + int* oindices = nullptr, + const int offset = 0) { + int64_t class_num = scores.dims()[1]; + int64_t predict_dim = scores.dims()[1]; + int64_t box_size = bboxes.dims()[1]; + if (scores_size == 2) { + box_size = bboxes.dims()[2]; + } + int64_t out_dim = box_size + 2; + auto* scores_data = scores.data(); + auto* bboxes_data = bboxes.data(); + auto* odata = outs->mutable_data(); + const T* sdata; + Tensor bbox; + bbox.Resize({scores.dims()[0], box_size}); + int count = 0; + for (const auto& it : selected_indices) { + int label = it.first; + const std::vector& indices = it.second; + if (scores_size == 2) { + SliceOneClass(bboxes, label, &bbox); + } else { + sdata = scores_data + label * predict_dim; + } + for (size_t j = 0; j < indices.size(); ++j) { + int idx = indices[j]; + odata[count * out_dim] = label; // label + const T* bdata; + if (scores_size == 3) { + bdata = bboxes_data + idx * box_size; + odata[count * out_dim + 1] = sdata[idx]; // score + if (oindices != nullptr) { + oindices[count] = offset + idx; + } + } else { + bdata = bbox.data() + idx * box_size; + odata[count * out_dim + 1] = *(scores_data + idx * class_num + label); + if (oindices != nullptr) { + oindices[count] = offset + idx * class_num + label; + } + } + // xmin, ymin, xmax, ymax or multi-points coordinates + std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T)); + 
count++; + } + } +} + +class MulticlassNmsComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string type_ = "multiclass_nms"; + std::string bboxes_ = "bboxes"; + std::string scores_ = "scores"; + std::string out_ = "out"; + DDim bboxes_dims_{}; + DDim scores_dims_{}; + int keep_top_k_{2}; + float nms_threshold_{0.45f}; + float nms_eta_{1.f}; + int nms_top_k_{1}; + int background_label_{-1}; + float score_threshold_{0.01f}; + bool normalized_{false}; + + public: + MulticlassNmsComputeTester(const Place& place, + const std::string& alias, + DDim bboxes_dims, + DDim scores_dims, + int keep_top_k = 2, + float nms_threshold = 0.45f, + float nms_eta = 1.f, + int nms_top_k = 1, + int background_label = 1, + float score_threshold = 0.01f, + bool normalized = false) + : TestCase(place, alias), + bboxes_dims_(bboxes_dims), + scores_dims_(scores_dims), + keep_top_k_(keep_top_k), + nms_threshold_(nms_threshold), + nms_eta_(nms_eta), + nms_top_k_(nms_top_k), + background_label_(background_label), + score_threshold_(score_threshold), + normalized_(normalized) {} + + void RunBaseline(Scope* scope) override { + auto* boxes = scope->FindTensor(bboxes_); + auto* scores = scope->FindTensor(scores_); + auto* outs = scope->NewTensor(out_); + CHECK(outs); + outs->set_precision(PRECISION(kFloat)); + + auto score_size = scores_dims_.size(); + std::vector>> all_indices; + std::vector batch_starts = {0}; + int64_t batch_size = scores_dims_[0]; + int64_t box_dim = bboxes_dims_[2]; + int64_t out_dim = box_dim + 2; + int num_nmsed_out = 0; + Tensor boxes_slice, scores_slice; + int n = score_size == 3 ? 
batch_size : boxes->lod().back().size() - 1; + for (int i = 0; i < n; ++i) { + if (score_size == 3) { + scores_slice = scores->Slice(i, i + 1); + scores_slice.Resize({scores_dims_[1], scores_dims_[2]}); + boxes_slice = boxes->Slice(i, i + 1); + boxes_slice.Resize({scores_dims_[2], box_dim}); + } else { + auto boxes_lod = boxes->lod().back(); + scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]); + boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]); + } + std::map> indices; + MultiClassNMS(scores_slice, + boxes_slice, + score_size, + &indices, + &num_nmsed_out, + background_label_, + nms_top_k_, + keep_top_k_, + normalized_, + nms_threshold_, + nms_eta_, + score_threshold_); + all_indices.push_back(indices); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + uint64_t num_kept = batch_starts.back(); + if (num_kept == 0) { + outs->Resize({1, 1}); + float* od = outs->mutable_data(); + od[0] = -1; + batch_starts = {0, 1}; + } else { + outs->Resize({static_cast(num_kept), out_dim}); + outs->mutable_data(); + int offset = 0; + int* oindices = nullptr; + for (int i = 0; i < n; ++i) { + if (score_size == 3) { + scores_slice = scores->Slice(i, i + 1); + boxes_slice = boxes->Slice(i, i + 1); + scores_slice.Resize({scores_dims_[1], scores_dims_[2]}); + boxes_slice.Resize({scores_dims_[2], box_dim}); + } else { + auto boxes_lod = boxes->lod().back(); + scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]); + boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]); + } + int64_t s = static_cast(batch_starts[i]); + int64_t e = static_cast(batch_starts[i + 1]); + if (e > s) { + Tensor out = outs->Slice(s, e); + MultiClassOutput(scores_slice, + boxes_slice, + all_indices[i], + scores_dims_.size(), + &out, + oindices, + offset); + } + } + } + + LoD lod; + lod.emplace_back(batch_starts); + outs->set_lod(lod); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType(type_); + op_desc->SetInput("BBoxes", {bboxes_}); + 
op_desc->SetInput("Scores", {scores_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("keep_top_k", keep_top_k_); + op_desc->SetAttr("nms_threshold", nms_threshold_); + op_desc->SetAttr("nms_eta", nms_eta_); + op_desc->SetAttr("nms_top_k", nms_top_k_); + op_desc->SetAttr("background_label", background_label_); + op_desc->SetAttr("score_threshold", score_threshold_); + op_desc->SetAttr("normalized", normalized_); + } + + void PrepareData() override { + std::vector bboxes(bboxes_dims_.production()); + for (int i = 0; i < bboxes_dims_.production(); ++i) { + bboxes[i] = i * 1. / bboxes_dims_.production(); + } + SetCommonTensor(bboxes_, bboxes_dims_, bboxes.data()); + + std::vector scores(scores_dims_.production()); + for (int i = 0; i < scores_dims_.production(); ++i) { + scores[i] = i * 1. / scores_dims_.production(); + } + SetCommonTensor(scores_, scores_dims_, scores.data()); + } +}; + +void TestMulticlassNms(Place place, float abs_error) { + int N = 3; + int M = 2500; + for (int class_num : {2, 4, 10}) { + std::vector bbox_shape{N, M, 4}; + std::vector score_shape{N, class_num, M}; + std::unique_ptr tester(new MulticlassNmsComputeTester( + place, "def", DDim(bbox_shape), DDim(score_shape))); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } +} + +TEST(multiclass_nms, precision) { + float abs_error = 2e-5; + Place place; +#if defined(LITE_WITH_ARM) + place = TARGET(kHost); +#elif defined(LITE_WITH_XPU) + place = TARGET(kXPU); +#else + return; +#endif + + TestMulticlassNms(place, abs_error); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/pool_compute_test.cc b/lite/tests/kernels/pool_compute_test.cc index 988c99bf7c0adb246ea7b7408054485aaf59dce8..04894188b0bf1557000479ae18b0369997909f89 100644 --- a/lite/tests/kernels/pool_compute_test.cc +++ b/lite/tests/kernels/pool_compute_test.cc @@ -276,9 +276,24 @@ void TestPoolHelper(Place place, std::string pooling_type, std::vector strides, 
std::vector paddings, - std::vector ksize) { - std::unique_ptr tester(new PoolComputeTest( - place, "def", DDim(dims), pooling_type, false, strides, paddings, ksize)); + std::vector ksize, + bool exclusive = true, + bool ceil_mode = false, + bool adaptive = false, + std::string padding_algorithm = "") { + std::unique_ptr tester( + new PoolComputeTest(place, + "def", + DDim(dims), + pooling_type, + false, + strides, + paddings, + ksize, + exclusive, + ceil_mode, + adaptive, + padding_algorithm)); arena::Arena arena(std::move(tester), place, abs_error); arena.TestPrecision(); } @@ -345,6 +360,20 @@ void TestPoolKsize(Place place, float abs_error = 2e-5) { } } +void TestPoolCeilMode(Place place, float abs_error = 2e-5) { + for (auto pooling_type : {"max", "avg"}) { + TestPoolHelper(place, + abs_error, + {2, 3, 6, 6}, + pooling_type, + {2, 2}, + {0, 0, 0, 0}, + {3, 3}, + true, + true); + } +} + TEST(Pool, precision) { LOG(INFO) << "test pool op"; float abs_error = 2e-5; @@ -363,6 +392,7 @@ TEST(Pool, precision) { TestPoolStrides(place, abs_error); TestPoolPaddings(place, abs_error); TestPoolKsize(place, abs_error); + TestPoolCeilMode(place, abs_error); } } // namespace lite diff --git a/lite/tests/kernels/yolo_box_compute_test.cc b/lite/tests/kernels/yolo_box_compute_test.cc index 2e98ce96cef479d55e77acebbe464d9a56f92934..c41c89608fd7496c5b01b1a813581f7f461ff0ee 100644 --- a/lite/tests/kernels/yolo_box_compute_test.cc +++ b/lite/tests/kernels/yolo_box_compute_test.cc @@ -228,14 +228,14 @@ class YoloBoxComputeTester : public arena::TestCase { } }; -void test_yolobox(Place place) { - for (int class_num : {1, 2, 3, 4}) { - for (float conf_thresh : {0.01, 0.2, 0.7}) { +void TestYoloBox(Place place, float abs_error) { + for (int class_num : {1, 4}) { + for (float conf_thresh : {0.01, 0.2}) { for (int downsample_ratio : {16, 32}) { - std::vector anchor({10, 13, 16, 30}); + std::vector anchor{10, 13, 16, 30, 33, 30}; std::unique_ptr tester(new YoloBoxComputeTester( place, 
"def", anchor, class_num, conf_thresh, downsample_ratio)); - arena::Arena arena(std::move(tester), place, 2e-5); + arena::Arena arena(std::move(tester), place, abs_error); arena.TestPrecision(); } } @@ -243,13 +243,17 @@ void test_yolobox(Place place) { } TEST(YoloBox, precision) { -// #ifdef LITE_WITH_X86 -// Place place(TARGET(kX86)); -// #endif -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); - test_yolobox(place); + float abs_error = 2e-5; + Place place; +#if defined(LITE_WITH_ARM) + place = TARGET(kARM); +#elif defined(LITE_WITH_XPU) + place = TARGET(kXPU); +#else + return; #endif + + TestYoloBox(place, abs_error); } } // namespace lite diff --git a/lite/tools/build.sh b/lite/tools/build.sh index c67be25954e14e3c627863a28b8b80c61f5fab87..7ea15f23d961c11a9ee3969d4fbc0866ad76b1e3 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -14,6 +14,7 @@ readonly NUM_PROC=${LITE_BUILD_THREADS:-4} # global variables BUILD_EXTRA=OFF +BUILD_TRAIN=OFF BUILD_JAVA=ON BUILD_PYTHON=OFF BUILD_DIR=$(pwd) @@ -226,6 +227,7 @@ function make_full_publish_so { -DNPU_DDK_ROOT=$NPU_DDK_ROOT \ -DLITE_WITH_XPU=$BUILD_XPU \ -DXPU_SDK_ROOT=$XPU_SDK_ROOT \ + -DLITE_WITH_TRAIN=$BUILD_TRAIN \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} make publish_inference -j$NUM_PROC @@ -389,6 +391,7 @@ function print_usage { echo -e "optional argument:" echo -e "--shutdown_log: (OFF|ON); controls whether to shutdown log, default is ON" echo -e "--build_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP)" + echo -e "--build_train: (OFF|ON); controls whether to publish training operators and kernels, build_train is only for full_publish library now" echo -e "--build_python: (OFF|ON); controls whether to publish python api lib (ANDROID and IOS is not supported)" echo -e "--build_java: (OFF|ON); controls whether to publish java api lib (Only ANDROID is supported)" echo -e "--build_dir: directory for 
building" @@ -437,6 +440,10 @@ function main { BUILD_EXTRA="${i#*=}" shift ;; + --build_train=*) + BUILD_TRAIN="${i#*=}" + shift + ;; --build_cv=*) BUILD_CV="${i#*=}" shift diff --git a/lite/tools/build_bm.sh b/lite/tools/build_bm.sh index 2c3a8406f7e7c52ecb0268d581e043a3070ba028..5bc8544e9f81d07ed33ece4b33233b08b9227740 100755 --- a/lite/tools/build_bm.sh +++ b/lite/tools/build_bm.sh @@ -5,7 +5,7 @@ set -ex BM_SDK_ROOT="$(pwd)/third-party/bmlibs/bm_sc3_libs" # BM SDK TARGET_NAME="BM1682" # default target BUILD_EXTRA=OFF # ON(with sequence ops)/OFF -WITH_TESTING=ON # ON/OFF +WITH_TESTING=OFF # ON/OFF function print_usage { echo -e "\nUSAGE:" diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index 703da69fa59f3aa99bad9fb04c0decb591486058..a5dc2b741d2d3d5fdd2f08d13b7dc483a3065b0e 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -192,6 +192,7 @@ function build_opencl { cmake_opencl ${os} ${abi} ${lang} make opencl_clhpp -j$NUM_CORES_FOR_COMPILE + make publish_inference -j$NUM_CORES_FOR_COMPILE build $TESTS_FILE } diff --git a/lite/utils/cv/bgr_rotate.cc b/lite/utils/cv/bgr_rotate.cc index 82d977d491d06147b3fd04d490002eb6bedcf16a..333bf8575515fe4f5e063f8e55610c111c377571 100644 --- a/lite/utils/cv/bgr_rotate.cc +++ b/lite/utils/cv/bgr_rotate.cc @@ -1133,7 +1133,7 @@ bgr3 bgr2 bgr1 #ifdef __aarch64__ void rotate180_hwc(const uint8_t* src, uint8_t* dst, int w, int h_in) { int w_in = w * 3; - uint8_t zerobuff[30000]; // [w_in]; + uint8_t* zerobuff = new uint8_t[w_in]; memset(zerobuff, 0, w_in * sizeof(uint8_t)); int64_t stride_w = 24; for (int i = 0; i < h_in; i += 4) { @@ -1331,7 +1331,7 @@ void rotate180_hwc(const uint8_t* src, uint8_t* dst, int w, int h_in) { #else void rotate180_hwc(const uint8_t* src, uint8_t* dst, int w, int h_in) { int w_in = w * 3; - uint8_t zerobuff[30000]; // w_in + uint8_t* zerobuff = new uint8_t[w_in]; memset(zerobuff, 0, w_in * sizeof(uint8_t)); int stride_w = 24; // 4*8