Commit 6554854a authored by Liu Yiqun

Merge branch 'develop' into step_rnn/opt_ddim_lite

test=develop
...@@ -120,6 +120,7 @@
#
## Lite settings
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto")
if (ARM_TARGET_OS STREQUAL "ios")
set(PLATFORM "OS")
elseif(ARM_TARGET_OS STREQUAL "ios64")
......
...@@ -305,6 +305,26 @@ if(NOT IOS)
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
lite_cc_binary(multithread_test SRCS lite_multithread_test.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels}
ARM_DEPS ${arm_kernels}
CV_DEPS paddle_cv_arm
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
endif()
#lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <string>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/device_info.h"
#include "lite/core/profile/timer.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
#include <thread> // NOLINT
using paddle::lite::profile::Timer;
DEFINE_string(input_shape,
"1,3,224,224",
"input shapes, separated by colon and comma");
DEFINE_string(model_dir_0, "", "model_dir_0");
DEFINE_string(input_shape_0,
"1,3,224,224",
"input shapes another, separated by colon and comma");
DEFINE_bool(use_optimize_nb,
false,
"optimized & naive buffer model for mobile devices");
DEFINE_int32(test_type, 0, "multithread test type");
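// Grouping of the multithread test cases driven by this flag (see main()):
// test_type 0 exercises --model_dir alone (RunTestType_00 and RunTestType_10),
// test_type 1 additionally runs --model_dir_0 in parallel with it
// (RunTestType_01 and RunTestType_11).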
namespace paddle {
namespace lite_api {
void OutputOptModel(const std::string& load_model_dir,
const std::string& save_optimized_model_dir,
const std::vector<std::vector<int64_t>>& input_shapes) {
lite_api::CxxConfig config;
config.set_model_dir(load_model_dir);
config.set_valid_places({
Place{TARGET(kARM), PRECISION(kFloat)},
});
auto predictor = lite_api::CreatePaddlePredictor(config);
// delete old optimized model
int ret = system(
paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str())
.c_str());
if (ret == 0) {
LOG(INFO) << "delete old optimized model " << save_optimized_model_dir;
}
predictor->SaveOptimizedModel(save_optimized_model_dir,
LiteModelType::kNaiveBuffer);
LOG(INFO) << "Load model from " << load_model_dir;
LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void Run(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const PowerMode power_mode,
const int thread_num,
const int repeat,
int tid,
const int warmup_times = 5) {
lite_api::MobileConfig config;
config.set_model_dir(model_dir);
config.set_power_mode(power_mode);
config.set_threads(thread_num);
auto predictor = lite_api::CreatePaddlePredictor(config);
for (int j = 0; j < input_shapes.size(); ++j) {
auto input_tensor = predictor->GetInput(j);
input_tensor->Resize(input_shapes[j]);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (int i = 0; i < input_shapes[j].size(); ++i) {
input_num *= input_shapes[j][i];
}
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
}
}
for (int i = 0; i < warmup_times; ++i) {
predictor->Run();
}
Timer ti;
for (int j = 0; j < repeat; ++j) {
ti.Start();
predictor->Run();
float t = ti.Stop();
auto output = predictor->GetOutput(0);
auto out = output->data<float>();
LOG(INFO) << "[thread " << tid << "] Model: " << model_dir
<< " output[0]:" << out[0] << "; output[1]:" << out[1];
}
LOG(INFO) << "[thread " << tid << "] Model: " << model_dir
<< ", power_mode: " << static_cast<int>(power_mode)
<< ", threads num " << thread_num
<< ", avg time: " << ti.LapTimes().Avg() << "ms"
<< ", min time: " << ti.LapTimes().Min() << " ms"
<< ", max time: " << ti.LapTimes().Max() << " ms.";
}
void RunTestType_00(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const PowerMode power_mode,
const int thread_num,
const int repeat,
const int warmup_times = 5) {
std::thread run_th0(Run,
input_shapes,
model_dir,
power_mode,
thread_num,
repeat,
0,
warmup_times);
Run(input_shapes, model_dir, power_mode, thread_num, repeat, 1, warmup_times);
run_th0.join();
}
void RunTestType_01(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const std::vector<std::vector<int64_t>>& input_shapes_0,
const std::string& model_dir_0,
const PowerMode power_mode,
const int thread_num,
const int repeat,
const int warmup_times = 5) {
std::thread run_th0(Run,
input_shapes,
model_dir,
power_mode,
thread_num,
repeat,
0,
warmup_times);
Run(input_shapes_0,
model_dir_0,
power_mode,
thread_num,
repeat,
1,
warmup_times);
run_th0.join();
}
void run_with_predictor(std::shared_ptr<PaddlePredictor> predictor,
const std::vector<std::vector<int64_t>>& input_shapes,
int index,
const std::string& name) {
for (int j = 0; j < input_shapes.size(); ++j) {
auto input_tensor = predictor->GetInput(j);
input_tensor->Resize(input_shapes[j]);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (int i = 0; i < input_shapes[j].size(); ++i) {
input_num *= input_shapes[j][i];
}
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
}
}
Timer ti;
ti.Start();
predictor->Run();
float t = ti.Stop();
auto output = predictor->GetOutput(0);
auto out = output->data<float>();
LOG(INFO) << "[thread " << index << "] name: " << name
<< ",run time: " << ti.LapTimes().Avg() << "ms"
<< " output[0]:" << out[0] << "; output[1]:" << out[1];
}
void RunTestType_10(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const PowerMode power_mode,
const int thread_num,
const int repeat,
int warmup = 5) {
lite_api::MobileConfig config;
config.set_model_dir(model_dir);
config.set_power_mode(power_mode);
config.set_threads(thread_num);
auto predictor = lite_api::CreatePaddlePredictor(config);
for (int i = 0; i < repeat; ++i) {
std::thread pre_th0(
run_with_predictor, predictor, input_shapes, i, model_dir);
pre_th0.join();
}
}
void RunTestType_11(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const std::vector<std::vector<int64_t>>& input_shapes_0,
const std::string& model_dir_0,
const PowerMode power_mode,
const int thread_num,
const int repeat,
int warmup = 5) {
lite_api::MobileConfig config;
config.set_model_dir(model_dir);
config.set_power_mode(power_mode);
config.set_threads(thread_num);
auto predictor = lite_api::CreatePaddlePredictor(config);
config.set_model_dir(model_dir_0);
auto predictor_0 = lite_api::CreatePaddlePredictor(config);
for (int i = 0; i < 2 * repeat; i += 2) {
std::thread pre_th0(
run_with_predictor, predictor, input_shapes, i, model_dir);
std::thread pre_th1(
run_with_predictor, predictor_0, input_shapes_0, i + 1, model_dir_0);
pre_th0.join();
pre_th1.join();
}
}
#endif
} // namespace lite_api
} // namespace paddle
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir == "") {
LOG(INFO) << "usage: "
<< "--model_dir /path/to/your/model";
exit(0);
}
std::string save_optimized_model_dir = "";
std::string save_optimized_model_dir_0 = "";
if (FLAGS_use_optimize_nb) {
save_optimized_model_dir = FLAGS_model_dir;
save_optimized_model_dir_0 = FLAGS_model_dir_0;
} else {
save_optimized_model_dir = FLAGS_model_dir + "opt2";
save_optimized_model_dir_0 = FLAGS_model_dir_0 + "opt2";
}
auto split_string =
[](const std::string& str_in) -> std::vector<std::string> {
std::vector<std::string> str_out;
std::string tmp_str = str_in;
while (!tmp_str.empty()) {
size_t next_offset = tmp_str.find(":");
str_out.push_back(tmp_str.substr(0, next_offset));
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return str_out;
};
auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
std::vector<int64_t> shape;
std::string tmp_str = str_shape;
while (!tmp_str.empty()) {
int dim = atoi(tmp_str.data());
shape.push_back(dim);
size_t next_offset = tmp_str.find(",");
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return shape;
};
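// Illustrative example of the two parsers above: --input_shape="1,3,224,224:1,10"
// is first split on ':' into {"1,3,224,224", "1,10"}, and each piece is then
// parsed on ',' into the shapes {1, 3, 224, 224} and {1, 10}.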
std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
std::vector<std::vector<int64_t>> input_shapes;
for (int i = 0; i < str_input_shapes.size(); ++i) {
input_shapes.push_back(get_shape(str_input_shapes[i]));
}
std::vector<std::string> str_input_shapes_0 =
split_string(FLAGS_input_shape_0);
std::vector<std::vector<int64_t>> input_shapes_0;
for (int i = 0; i < str_input_shapes_0.size(); ++i) {
input_shapes_0.push_back(get_shape(str_input_shapes_0[i]));
}
if (!FLAGS_use_optimize_nb) {
// Output optimized model
paddle::lite_api::OutputOptModel(
FLAGS_model_dir, save_optimized_model_dir, input_shapes);
paddle::lite_api::OutputOptModel(
FLAGS_model_dir_0, save_optimized_model_dir_0, input_shapes_0);
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Run inference using optimized model
if (FLAGS_test_type == 0) {
paddle::lite_api::RunTestType_00(
input_shapes,
save_optimized_model_dir,
static_cast<paddle::lite_api::PowerMode>(0),
FLAGS_threads,
FLAGS_repeats,
5);
LOG(INFO) << "=========above is case 0, below is case "
"1============================";
paddle::lite_api::RunTestType_10(
input_shapes,
save_optimized_model_dir,
static_cast<paddle::lite_api::PowerMode>(0),
FLAGS_threads,
FLAGS_repeats);
}
if (FLAGS_test_type == 1) {
paddle::lite_api::RunTestType_01(
input_shapes,
save_optimized_model_dir,
input_shapes_0,
save_optimized_model_dir_0,
static_cast<paddle::lite_api::PowerMode>(0),
FLAGS_threads,
FLAGS_repeats,
5);
LOG(INFO) << "=========above is case 0, below is case "
"1============================";
paddle::lite_api::RunTestType_11(
input_shapes,
save_optimized_model_dir,
input_shapes_0,
save_optimized_model_dir_0,
static_cast<paddle::lite_api::PowerMode>(0),
FLAGS_threads,
FLAGS_repeats);
}
#endif
return 0;
}
...@@ -32,26 +32,37 @@
#include <gflags/gflags.h>
#include <algorithm>
#include "lite/utils/env.h"
// DEFINE_double(fraction_of_cpu_memory_to_use,
//               1,
//               "Default use 100% of CPU memory for PaddlePaddle,"
//               "reserve the rest for page tables, etc");
double fraction_of_cpu_memory_to_use =
    paddle::lite::GetDoubleFromEnv("fraction_of_cpu_memory_to_use", 1);
// DEFINE_uint64(initial_cpu_memory_in_mb,
//               500ul,
//               "Initial CPU memory for PaddlePaddle, in MD unit.");
uint64_t initial_cpu_memory_in_mb =
    paddle::lite::GetUInt64FromEnv("initial_cpu_memory_in_mb", 500ul);
// DEFINE_double(
//     fraction_of_cuda_pinned_memory_to_use,
//     0.5,
//     "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
//     "reserve the rest for page tables, etc");
double fraction_of_cuda_pinned_memory_to_use = paddle::lite::GetDoubleFromEnv(
    "fraction_of_cuda_pinned_memory_to_use", 0.5);
// If use_pinned_memory is true, CPUAllocator calls mlock, which
// returns pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
// DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory.");
bool use_pinned_memory =
    paddle::lite::GetBoolFromEnv("use_pinned_memory", true);
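// The gflags above were replaced by plain globals initialized from same-named
// environment variables via lite::Get*FromEnv (the original DEFINE_* lines are
// kept as comments); presumably exporting, e.g., use_pinned_memory=0 before
// running overrides the built-in default.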
namespace paddle {
namespace lite {
...@@ -81,7 +92,7 @@ size_t CpuTotalPhysicalMemory() {
size_t CpuMaxAllocSize() {
  // For distributed systems, it requires configuring and limiting
  // the fraction of memory to use.
  return fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory();
}
size_t CpuMinChunkSize() {
...@@ -92,15 +103,14 @@ size_t CpuMinChunkSize() {
size_t CpuMaxChunkSize() {
  // Allow to allocate the maximum chunk size is roughly 3% of CPU memory,
  // or the initial_cpu_memory_in_mb.
  return std::min(static_cast<size_t>(CpuMaxAllocSize() / 32),
                  static_cast<size_t>(initial_cpu_memory_in_mb * 1 << 20));
}
size_t CUDAPinnedMaxAllocSize() {
  // For distributed systems, it requires configuring and limiting
  // the fraction of memory to use.
  return fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory();
}
size_t CUDAPinnedMinChunkSize() {
......
...@@ -22,36 +22,46 @@ limitations under the License. */
#include "lite/backends/x86/cupti_lib_path.h"
#include "lite/backends/x86/port.h"
#include "lite/backends/x86/warpctc_lib_path.h"
#include "lite/utils/env.h"
#include "lite/utils/paddle_enforce.h"
// DEFINE_string(cudnn_dir,
//               "",
//               "Specify path for loading libcudnn.so. For instance, "
//               "/usr/local/cudnn/lib. If empty [default], dlopen "
//               "will search cudnn from LD_LIBRARY_PATH");
std::string cudnn_dir = paddle::lite::GetStringFromEnv("cudnn_dir");  // NOLINT
// DEFINE_string(cuda_dir,
//               "",
//               "Specify path for loading cuda library, such as libcublas, "
//               "libcurand. For instance, /usr/local/cuda/lib64. If default, "
//               "dlopen will search cuda from LD_LIBRARY_PATH");
std::string cuda_dir = paddle::lite::GetStringFromEnv("cuda_dir");  // NOLINT
// DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
std::string f_warpctc_dir =  // NOLINT
    paddle::lite::GetStringFromEnv("warpctc_dir");  // NOLINT
// DEFINE_string(nccl_dir,
//               "",
//               "Specify path for loading nccl library, such as libcublas, "
//               "libcurand. For instance, /usr/local/cuda/lib64. If default, "
//               "dlopen will search cuda from LD_LIBRARY_PATH");
std::string nccl_dir = paddle::lite::GetStringFromEnv("nccl_dir");  // NOLINT
// DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so.");
std::string cupti_dir = paddle::lite::GetStringFromEnv("cupti_dir");  // NOLINT
// DEFINE_string(
//     tensorrt_dir,
//     "",
//     "Specify path for loading tensorrt library, such as libnvinfer.so.");
std::string tensorrt_dir =  // NOLINT
    paddle::lite::GetStringFromEnv("tensorrt_dir");  // NOLINT
// DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
std::string mklml_dir = paddle::lite::GetStringFromEnv("mklml_dir");  // NOLINT
namespace paddle {
namespace lite {
...@@ -180,28 +190,28 @@ auto error_msg =
void* GetCublasDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(cuda_dir, "libcublas.dylib");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
  return GetDsoHandleFromSearchPath(cuda_dir, win_cublas_lib);
#else
  return GetDsoHandleFromSearchPath(cuda_dir, "libcublas.so");
#endif
}
void* GetCUDNNDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(cudnn_dir, "libcudnn.dylib", false);
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
  return GetDsoHandleFromSearchPath(cudnn_dir, win_cudnn_lib);
#else
  return GetDsoHandleFromSearchPath(cudnn_dir, "libcudnn.so", false);
#endif
}
void* GetCUPTIDsoHandle() {
  std::string cupti_path = cupti_lib_path;
  if (!cupti_dir.empty()) {
    cupti_path = cupti_dir;
  }
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(cupti_path, "libcupti.dylib", false);
...@@ -212,18 +222,18 @@ void* GetCUPTIDsoHandle() {
void* GetCurandDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(cuda_dir, "libcurand.dylib");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
  return GetDsoHandleFromSearchPath(cuda_dir, win_curand_lib);
#else
  return GetDsoHandleFromSearchPath(cuda_dir, "libcurand.so");
#endif
}
void* GetWarpCTCDsoHandle() {
  std::string warpctc_dir = warpctc_lib_path;
  if (!f_warpctc_dir.empty()) {
    warpctc_dir = f_warpctc_dir;
  }
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.dylib");
...@@ -236,27 +246,27 @@ void* GetWarpCTCDsoHandle() {
void* GetNCCLDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(nccl_dir, "libnccl.dylib");
#else
  return GetDsoHandleFromSearchPath(nccl_dir, "libnccl.so");
#endif
}
void* GetTensorRtDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(tensorrt_dir, "libnvinfer.dylib");
#else
  return GetDsoHandleFromSearchPath(tensorrt_dir, "libnvinfer.so");
#endif
}
void* GetMKLMLDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(mklml_dir, "libmklml_intel.dylib");
#elif defined(_WIN32)
  return GetDsoHandleFromSearchPath(mklml_dir, "mklml.dll");
#else
  return GetDsoHandleFromSearchPath(mklml_dir, "libmklml_intel.so");
#endif
}
......
...@@ -21,13 +21,15 @@
// posix_memalign
#include "lite/backends/x86/cpu_info.h"
#include "lite/backends/x86/jit/macro.h"
#include "lite/utils/env.h"
#include "lite/utils/paddle_enforce.h"
#ifndef _WIN32
#define posix_memalign_free free
#endif
// DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
bool dump_jitcode = paddle::lite::GetBoolFromEnv("dump_jitcode");
namespace paddle {
namespace lite {
......
...@@ -20,7 +20,8 @@
#include <vector>
#include "lite/backends/x86/jit/kernel_base.h"
// DECLARE_bool(dump_jitcode);
extern bool dump_jitcode;
namespace paddle {
namespace lite {
...@@ -36,7 +37,7 @@ class GenBase : public Kernel {
  template <typename Func>
  Func getCode() const {
    const unsigned char* code = this->getCodeInternal();
    if (dump_jitcode) {
      this->dumpCode(code);
    }
    // Note: failed to cast with reinterpret_cast<const Func> on Mac clang,
......
...@@ -86,7 +86,8 @@ class BeamSearchFunctor<TARGET(kX86), T> {
    // selected_ids->mutable_data<int64_t>(dims, platform::CPUPlace());
    // auto *selected_scores_data =
    //     selected_scores->mutable_data<float>(dims, platform::CPUPlace());
    parent_idx->Resize(
        std::vector<int64_t>({static_cast<int64_t>(num_instances)}));
    auto *parent_idx_data =
        parent_idx ? parent_idx->mutable_data<int>(TARGET(kX86)) : nullptr;
    // auto *parent_idx_data =
......
...@@ -83,14 +83,11 @@ class KernelBase {
#if defined(LITE_WITH_CUDA)
    WorkSpace::Global_CUDA().AllocReset();
#endif
#ifdef LITE_WITH_PROFILE
    profiler_->StopTiming(profile::Type::kCreate, profile_id_, ctx_.get());
    profiler_->StartTiming(profile::Type::kDispatch, profile_id_, ctx_.get());
    Run();
    profiler_->StopTiming(profile::Type::kDispatch, profile_id_, ctx_.get());
#else
    Run();
#endif
......
...@@ -120,6 +120,7 @@ class Buffer {
    if (space_ > 0) {
      TargetFree(target_, data_);
    }
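    // Reset data_ after TargetFree so the stale pointer can never be freed or
    // dereferenced again.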
    data_ = nullptr;
    target_ = TargetType::kHost;
    space_ = 0;
  }
......
...@@ -28,36 +28,55 @@ auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) {
};
}
std::map<Type, std::string> TypeStr{
    {Type::kUnk, "Unknown"},
    {Type::kCreate, "Create"},
    {Type::kDispatch, "Dispatch"},
};
StatisUnit::StatisUnit(const OpCharacter& ch) : character(ch) {
  create_t.reset(new DeviceTimer<TargetType::kHost>());
  if (ch.target == TargetType::kCUDA) {
#ifdef LITE_WITH_CUDA
    dispatch_t.reset(new DeviceTimer<TargetType::kCUDA>());
#else
    LOG(ERROR) << "The timer type specified as cuda is uninitialized, so the "
                  "default x86 timer is used instead.";
#endif
  } else {
    dispatch_t.reset(new DeviceTimer<TargetType::kHost>());
  }
}
lite::profile::Timer* StatisUnit::Timer(Type type) {
  if (type == Type::kCreate) {
    return create_t.get();
  } else if (type == Type::kDispatch) {
    return dispatch_t.get();
  }
  LOG(FATAL) << "Timer cannot be returned for unknown platforms.";
  return nullptr;
}
int Profiler::NewTimer(const OpCharacter& ch) {
  StatisUnit unit(ch);
  units_.push_back(std::move(unit));
  return units_.size() - 1;
}
void Profiler::StartTiming(Type type, const int index, KernelContext* ctx) {
  CHECK_LT(index, units_.size())
      << "The timer index in the profiler is out of range.";
  units_[index].Timer(type)->Start(ctx);
}
float Profiler::StopTiming(Type type, const int index, KernelContext* ctx) {
  CHECK_LT(index, units_.size())
      << "The timer index in the profiler is out of range.";
  return units_[index].Timer(type)->Stop(ctx);
}
std::string Profiler::Summary(Type type, bool concise, size_t w) {
  using std::setw;
  using std::left;
  using std::fixed;
...@@ -65,12 +84,14 @@ std::string Profiler::Summary(bool concise, size_t w) {
  std::string title;
  // Title.
  if (concise) {
    ss << "Timing cycle = " << units_.front().Timer(type)->LapTimes().Size()
       << std::endl;
    ss << "===== Concise " << TypeStr.find(type)->second
       << " Profiler Summary: " << name_ << ", Exclude " << w
       << " warm-ups =====" << std::endl;
  } else {
    ss << "===== Detailed " << TypeStr.find(type)->second
       << " Profiler Summary: " << name_ << ", Exclude " << w
       << " warm-ups =====" << std::endl;
  }
  ss << setw(25) << left << "Operator Type"
...@@ -84,16 +105,16 @@ std::string Profiler::Summary(bool concise, size_t w) {
  if (concise) {
    std::map<OpCharacter, TimeInfo, decltype(op_comp)> summary(op_comp);
    for (auto& unit : units_) {
      auto ch = summary.find(unit.Character());
      if (ch != summary.end()) {
        ch->second.avg += unit.Timer(type)->LapTimes().Avg(w);
        ch->second.min += unit.Timer(type)->LapTimes().Min(w);
        ch->second.max += unit.Timer(type)->LapTimes().Max(w);
      } else {
        TimeInfo info({unit.Timer(type)->LapTimes().Avg(w),
                       unit.Timer(type)->LapTimes().Min(w),
                       unit.Timer(type)->LapTimes().Max(w)});
        summary.insert({unit.Character(), info});
      }
    }
    for (const auto& item : summary) {
...@@ -109,14 +130,15 @@ std::string Profiler::Summary(bool concise, size_t w) {
    }
  } else {
    for (auto& unit : units_) {
      const auto& times = unit.Timer(type)->LapTimes();
      // clang-format off
      ss << setw(25) << left << fixed << unit.Character().op_type \
         << " " << setw(40) << left << fixed << unit.Character().kernel_name \
         << " " << setw(12) << left << fixed << unit.Character().remark \
         << " " << setw(12) << left << fixed << times.Avg(w) \
         << " " << setw(12) << left << fixed << times.Min(w) \
         << " " << setw(12) << left << fixed << times.Max(w) \
         << " " << setw(12) << left << fixed << times.Last(w) \
         << std::endl;
      // clang-format on
    }
......
...@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
...@@ -22,6 +23,14 @@ namespace paddle {
namespace lite {
namespace profile {
enum class Type {
  kUnk = 0,
  kCreate,
  kDispatch,
};
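// kCreate times the work done before a kernel is dispatched (started in
// Instruction::Run, stopped in KernelBase::Launch); kDispatch times the
// kernel's own Run(), as wired up elsewhere in this commit.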
extern std::map<Type, std::string> TypeStr;
struct TimeInfo {
  float avg;
  float min;
...@@ -35,8 +44,15 @@ struct OpCharacter {
  std::string remark{std::string("N/A")};
};
class StatisUnit final {
 public:
  explicit StatisUnit(const OpCharacter& ch);
  lite::profile::Timer* Timer(Type type);
  const OpCharacter& Character() const { return character; }
 protected:
  std::unique_ptr<lite::profile::Timer> create_t;
  std::unique_ptr<lite::profile::Timer> dispatch_t;
  OpCharacter character;
};
...@@ -45,9 +61,9 @@ class Profiler final {
  Profiler() = default;
  explicit Profiler(const std::string& name) : name_(name) {}
  int NewTimer(const OpCharacter& ch);
  void StartTiming(Type type, const int index, KernelContext* ctx);
  float StopTiming(Type type, const int index, KernelContext* ctx);
  std::string Summary(Type type, bool concise = true, size_t warm_up = 10);
 private:
  std::string name_{std::string("N/A")};
......
...@@ -69,10 +69,10 @@ TEST(profiler, real_latency) {
  ch.op_type = "operator/1";
  ch.kernel_name = "kernel/1";
  int idx = profiler.NewTimer(ch);
  profiler.StartTiming(Type::kDispatch, idx, &ctx);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  profiler.StopTiming(Type::kDispatch, idx, &ctx);
  std::cout << profiler.Summary(Type::kDispatch);
}
#endif
......
...@@ -147,7 +147,7 @@ void RuntimeProgram::Run() {
#endif // LITE_WITH_PROFILE
  }
#ifdef LITE_WITH_PROFILE
  LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 0);
#endif // LITE_WITH_PROFILE
}
...@@ -252,8 +252,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
}
void Instruction::Run() {
#ifdef LITE_WITH_PROFILE
  CHECK(profiler_) << "Profiler pointer of kernel can not be nullptr. "
                      "When LITE_WITH_PROFILE is defined, please set a "
                      "Profiler for Instruction.";
  profiler_->StartTiming(
      profile::Type::kCreate, profile_id_, kernel_->mutable_context());
#endif
  CHECK(op_) << "op null";
  CHECK(kernel_) << "kernel null";
  if (first_epoch_) {
    first_epoch_ = false;
    CHECK(op_->CheckShape());
...@@ -263,10 +271,7 @@ void Instruction::Run() {
    return;
  }
  op_->InferShape();
  kernel_->Launch();
  has_run_ = true;
}
......
...@@ -143,7 +143,8 @@ class LITE_API RuntimeProgram {
  }
  ~RuntimeProgram() {
#ifdef LITE_WITH_PROFILE
    LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
    LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
#endif // LITE_WITH_PROFILE
  }
......
...@@ -233,6 +233,10 @@ class TensorLite {
        (static_cast<char *>(buffer_->data()) + offset_));
  }
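  // Releases the underlying buffer and resets the offset; kernels whose
  // outputs may be skipped in a given run (e.g. conditional_block,
  // split_lod_tensor) call this to drop previously produced data.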
  void clear() {
    buffer_->Free();
    offset_ = 0;
  }
  size_t data_size() const { return this->dims().production(); }
  size_t memory_size() const { return memory_size_; }
......
...@@ -34,6 +34,9 @@ void ConditionalBlockCompute::PrepareForRun() {
}
void ConditionalBlockCompute::Run() {
  auto& param = Param<operators::ConditionalBlockParam>();
  for (auto& out : param.outs) {
    out->clear();
  }
  bool need_run = true;
  if (param.is_scalar_condition) {
    auto* cond = param.cond;
......
...@@ -82,6 +82,10 @@ void SplitLodTensorCompute::Run() {
        ranges.begin(), ranges.end(), 0UL, [](size_t a, const CopyRange &b) {
          return a + b.end - b.begin;
        });
    if (height == 0) {
      out->clear();
      continue;
    }
    auto x_dim = x->dims();
    x_dim[0] = static_cast<int64_t>(height);
    out->Resize(x_dim);
......
...@@ -54,12 +54,12 @@ REGISTER_LITE_KERNEL(unsqueeze,
                     kNCHW,
                     paddle::lite::kernels::host::UnsqueezeCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .BindInput("AxesTensor",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("AxesTensorList",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .Finalize();
REGISTER_LITE_KERNEL(unsqueeze2,
...@@ -68,11 +68,11 @@ REGISTER_LITE_KERNEL(unsqueeze2,
                     kNCHW,
                     paddle::lite::kernels::host::Unsqueeze2Compute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .BindInput("AxesTensor",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("AxesTensorList",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
...@@ -54,7 +54,8 @@ REGISTER_LITE_KERNEL(yolo_box,
                     paddle::lite::kernels::arm::YoloBoxCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("ImgSize",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
...@@ -156,8 +156,8 @@ void SoftmaxCompute::PrepareForRun() {
  cudaGetDevice(&device_id);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, device_id);
  sharedmem_size_ = deviceProp.sharedMemPerBlock;
  max_dimsize_ = sharedmem_size_ / sizeof(float) / CUDA_NUM_THREADS;
}
void SoftmaxCompute::Run() {
...@@ -174,29 +174,27 @@ void SoftmaxCompute::Run() {
  int outer_num = x_dims.Slice(0, axis).production();
  int inner_num = x_dims.Slice(axis + 1, x_rank).production();
  int total_threads = inner_num * outer_num;
  axis_size_ = x_dims[axis];
  const int threads = CUDA_NUM_THREADS;
  const int blocks = (total_threads + threads - 1) / threads;
  auto input_data = param.x->data<float>();
  auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
  if (axis_size_ <= max_dimsize_) {
    int use_sharemem_size = axis_size_ * threads * sizeof(float);
    sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
        total_threads,
        input_data,
        output_data,
        inner_num,
        outer_num,
        axis_size_);
  } else {
    //! re_alloc device memory
    tmax_data_.Resize({1, 1, 1, outer_num * inner_num});
    tsum_data_.Resize({1, 1, 1, outer_num * inner_num});
    auto max_data = tmax_data_.mutable_data<float>(TARGET(kCUDA));
    auto sum_data = tsum_data_.mutable_data<float>(TARGET(kCUDA));
    //! firstly, get maximum data
    float min_data = std::numeric_limits<float>::lowest();
    softmax_max_kernel<float><<<blocks, threads, 0, stream>>>(total_threads,
...@@ -205,7 +203,7 @@ void SoftmaxCompute::Run() {
                                                              min_data,
                                                              inner_num,
                                                              outer_num,
                                                              axis_size_);
    //! then, compute exp and sum data
    softmax_sub_exp_sum_kernel<float><<<blocks, threads, 0, stream>>>(
        total_threads,
...@@ -215,10 +213,10 @@ void SoftmaxCompute::Run() {
        sum_data,
        inner_num,
        outer_num,
        axis_size_);
    //! last, compute divided output
    softmax_divid_output_kernel<float><<<blocks, threads, 0, stream>>>(
        total_threads, output_data, sum_data, inner_num, outer_num, axis_size_);
  }
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
......
...@@ -30,9 +30,11 @@ class SoftmaxCompute
  virtual ~SoftmaxCompute() = default;
 private:
  lite::Tensor tmax_data_;
  lite::Tensor tsum_data_;
  size_t sharedmem_size_;
  int max_dimsize_;
  int axis_size_;
};
}  // namespace cuda
......
...@@ -28,12 +28,14 @@ namespace subgraph {
class Engine {
 public:
  Engine(KernelContext *ctx,
         int block_idx,
         cpp::BlockDesc *block_desc,
         const std::vector<std::string> &input_names,
         const std::vector<std::string> &output_names,
         lite::Scope *scope)
      : ctx_(ctx),
        block_idx_(block_idx),
        block_desc_(block_desc),
        input_names_(input_names),
        output_names_(output_names),
...@@ -55,6 +57,7 @@ class Engine {
  virtual bool InputShapeChanged();
  KernelContext *ctx_{nullptr};
  int block_idx_;
  cpp::BlockDesc *block_desc_;
  std::vector<std::string> input_names_;
......
...@@ -207,7 +207,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() {
  auto& param = this->Param<param_t>();
  engine_.reset(new SubgraphEngine(ctx_.get(),
                                   param.sub_block_idx,
                                   param.sub_block_desc,
                                   param.input_data_names,
                                   param.output_data_names,
......
...@@ -29,13 +29,14 @@ namespace npu {
class SubgraphEngine : public subgraph::Engine {
 public:
  SubgraphEngine(KernelContext *ctx,
                 int block_idx,
                 cpp::BlockDesc *block_desc,
                 const std::vector<std::string> &input_names,
                 const std::vector<std::string> &output_names,
                 Scope *scope)
      : subgraph::Engine(
            ctx, block_idx, block_desc, input_names, output_names, scope) {}
 protected:
  int BuildDeviceProgram() override;
......
...@@ -13,10 +13,13 @@
// limitations under the License.
#include "lite/kernels/x86/gru_compute.h"
#include "lite/utils/env.h"
// DEFINE_int32(paddle_num_threads,
//              1,
//              "Number of threads for each paddle instance.");
int32_t paddle_num_threads =
    paddle::lite::GetIntFromEnv("paddle_num_threads", 1);
REGISTER_LITE_KERNEL(gru,
                     kX86,
......
...@@ -26,7 +26,8 @@
#include "lite/core/types.h"
#include "lite/fluid/eigen.h"
// DECLARE_int32(paddle_num_threads);
extern int32_t paddle_num_threads;
namespace paddle {
namespace lite {
...@@ -109,7 +110,7 @@ class GRUCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
#ifdef PADDLE_WITH_MKLML
    // use MKL packed to speedup GEMM
    if (paddle_num_threads >= 4) {
      auto blas = lite::x86::math::GetBlas<TARGET(kX86), T>(context);
      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix,
                                       1 /*height of C*/,
......
...@@ -49,9 +49,10 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
auto transpose_x = op_info->GetAttr<bool>("transpose_X");
auto transpose_y = op_info->GetAttr<bool>("transpose_Y");
auto alpha = op_info->GetAttr<float>("alpha");
...@@ -71,11 +72,68 @@ int MatmulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
y_node = graph->AddNode(y_name, y_dims);
}
// Matmul node
if (x_dims.size() > 2 && y_dims.size() >= 2) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
// Reshape and transposed X node
if (x_dims.size() != 3) {
auto m = static_cast<int>(x_dims[x_dims.size() - 2]);
auto k = static_cast<int>(x_dims[x_dims.size() - 1]);
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(*x_node, {-1, m, k}));
if (transpose_x) {
x_node =
graph->AddNode(x_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*x_node, {0, 2, 1}));
}
}
// Reshape and transposed Y node
if (y_dims.size() != 3) {
auto k = static_cast<int>(y_dims[y_dims.size() - 2]);
auto n = static_cast<int>(y_dims[y_dims.size() - 1]);
y_node =
graph->AddNode(y_name + "/reshape",
graph->builder_.CreateReshape(*y_node, {-1, k, n}));
if (!transpose_y) {
y_node =
graph->AddNode(y_name + "/reshape/transpose",
graph->builder_.CreateTranspose(*y_node, {0, 2, 1}));
}
}
// Matmul node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateBatchMatmul(*x_node, *y_node));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
}
if (out_dims.size() != 3) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
*matmul_node, CvtShape<xtcl::Integer>(out_dims)));
}
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
if (transpose_x) {
x_node = graph->AddNode(x_name + "/transpose",
graph->builder_.CreateTranspose(*x_node, {1, 0}));
}
auto matmul_node = graph->AddNode(
out_name,
graph->builder_.CreateMatmul2D(*x_node, *y_node, transpose_y));
if (fabs(alpha - 1) > 1e-6f) {
matmul_node = graph->AddNode(
out_name, graph->builder_.CreateScale(*matmul_node, alpha));
}
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1]
// x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
LOG(FATAL) << "[XPU] Not supported.";
return FAILED;
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace xpu
......
...@@ -67,15 +67,27 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node =
graph->AddNode(x_name + "/reshape",
graph->builder_.CreateReshape(
*x_node, {-1, static_cast<int>(x_matrix_dims[1])}));
}
// Y node
std::shared_ptr<xtcl::xExpr> y_node = nullptr;
if (graph->HasNode(y_name)) {
y_node = graph->GetNode(y_name);
} else {
y_node = graph->AddNode(y_name, y_dims);
}
// Flatten Y node
if (y_dims.size() != 2) {
y_node =
graph->AddNode(y_name + "/reshape",
graph->builder_.CreateReshape(
*y_node, {static_cast<int>(y_matrix_dims[0]), -1}));
}
// Reshape the matmul node with the inferred shape as the output node
auto matmul_node = graph->AddNode(
out_name, graph->builder_.CreateMatmul2D(*x_node, *y_node, false));
if (out_dims.size() != 2) {
graph->AddNode(out_name,
graph->builder_.CreateReshape(
......
...@@ -197,7 +197,8 @@ int SubgraphEngine::LaunchDeviceProgram() {
void SubgraphCompute::PrepareForRun() {
  auto& param = this->Param<param_t>();
  engine_.reset(new SubgraphEngine(ctx_.get(),
                                   param.sub_block_idx,
                                   param.sub_block_desc,
                                   param.input_data_names,
                                   param.output_data_names,
......
...@@ -29,13 +29,14 @@ namespace xpu {
class SubgraphEngine : public subgraph::Engine {
 public:
  SubgraphEngine(KernelContext *ctx,
                 int block_idx,
                 cpp::BlockDesc *block_desc,
                 const std::vector<std::string> &input_names,
                 const std::vector<std::string> &output_names,
                 Scope *scope)
      : subgraph::Engine(
            ctx, block_idx, block_desc, input_names, output_names, scope) {}
 protected:
  int BuildDeviceProgram() override;
......
...@@ -50,6 +50,7 @@ add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS})
add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS})
add_operator(subgraph_op basic SRCS subgraph_op.cc DEPS ${op_DEPS})
add_operator(grid_sampler_op basic SRCS grid_sampler_op.cc DEPS ${op_DEPS})
add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS})
# 2.basic ops not used in basic models
add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS})
...@@ -78,11 +79,9 @@ add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEP
add_operator(generate_proposals_op extra SRCS generate_proposals_op.cc DEPS ${op_DEPS})
add_operator(roi_align_op extra SRCS roi_align_op.cc DEPS ${op_DEPS})
add_operator(box_clip_op extra SRCS box_clip_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_range_abs_max_op extra SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(split_lod_tensor_op_lite extra SRCS split_lod_tensor_op.cc DEPS ${op_DEPS})
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/operators/attention_padding_mask_op.h" #include "lite/operators/attention_padding_mask_op.h"
#include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/scope.h" #include "lite/core/scope.h"
...@@ -39,7 +40,8 @@ bool AttentionPaddingMaskOp::InferShape() const { ...@@ -39,7 +40,8 @@ bool AttentionPaddingMaskOp::InferShape() const {
<< "Mismatch batch size, bottom0: " << att_batch << "Mismatch batch size, bottom0: " << att_batch
<< ", bottom1: " << src_batch; << ", bottom1: " << src_batch;
param_.pad_begin->Resize({static_cast<int64_t>(src_batch)}); param_.pad_begin->Resize(
std::vector<int64_t>({static_cast<int64_t>(src_batch)}));
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
param_.Out->set_lod(param_.X->lod()); param_.Out->set_lod(param_.X->lod());
......
...@@ -46,8 +46,9 @@ bool InstanceNormOp::InferShape() const { ...@@ -46,8 +46,9 @@ bool InstanceNormOp::InferShape() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0]; int64_t batch_size = x_dims[0];
int64_t channel_size = x_dims[1]; int64_t channel_size = x_dims[1];
param_.saved_mean->Resize({batch_size * channel_size}); param_.saved_mean->Resize(std::vector<int64_t>({batch_size * channel_size}));
param_.saved_variance->Resize({batch_size * channel_size}); param_.saved_variance->Resize(
std::vector<int64_t>({batch_size * channel_size}));
param_.out->Resize(x_dims); param_.out->Resize(x_dims);
return true; return true;
} }
......
...@@ -50,7 +50,7 @@ bool ReduceProdOpLite::InferShape() const { ...@@ -50,7 +50,7 @@ bool ReduceProdOpLite::InferShape() const {
if (keep_dim) { if (keep_dim) {
out->Resize({static_cast<int64_t>(x_rank), 1}); out->Resize({static_cast<int64_t>(x_rank), 1});
} else { } else {
out->Resize({1}); out->Resize(std::vector<int64_t>({1L}));
} }
} else { } else {
auto dims_vector = x_dims.Vectorize(); auto dims_vector = x_dims.Vectorize();
......
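Note: several hunks above wrap the braced Resize argument in an explicit std::vector<int64_t>. Presumably this pins the call to the vector overload once DDim also accepts a braced integer list; a minimal sketch of that kind of overload ambiguity with stand-in types (FakeDDim / FakeTensor are hypothetical, not the real lite classes):

#include <cstdint>
#include <initializer_list>
#include <vector>

// Stand-in types for illustration only.
struct FakeDDim {
  FakeDDim(std::initializer_list<int64_t> d) : dims(d) {}
  std::vector<int64_t> dims;
};

struct FakeTensor {
  void Resize(const FakeDDim& d) { shape = d.dims; }
  void Resize(const std::vector<int64_t>& d) { shape = d; }
  std::vector<int64_t> shape;
};

int main() {
  FakeTensor t;
  // t.Resize({1});                      // ambiguous: both overloads accept a braced list
  t.Resize(std::vector<int64_t>({1L}));  // explicit vector: always unambiguous
  return 0;
}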
...@@ -30,6 +30,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH ...@@ -30,6 +30,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class MulComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string type_ = "mul";
std::string x_ = "x";
std::string y_ = "y";
std::string out_ = "out";
DDim x_dims_{{1, 2}};
DDim y_dims_{{2, 1}};
int x_num_col_dims_{1};
int y_num_col_dims_{1};
public:
MulComputeTester(const Place& place,
const std::string& alias,
DDim x_dims,
DDim y_dims,
int x_num_col_dims,
int y_num_col_dims)
: TestCase(place, alias),
x_dims_(x_dims),
y_dims_(y_dims),
x_num_col_dims_(x_num_col_dims),
y_num_col_dims_(y_num_col_dims) {}
void RunBaseline(Scope* scope) override {
auto* x = scope->FindTensor(x_);
auto* y = scope->FindTensor(y_);
auto x_mat_dims = x_dims_.Flatten2D(x_num_col_dims_);
auto y_mat_dims = y_dims_.Flatten2D(y_num_col_dims_);
CHECK_EQ(x_mat_dims[1], y_mat_dims[0]);
auto* out = scope->NewTensor(out_);
CHECK(out);
std::vector<int64_t> out_shape;
for (int i = 0; i < x_num_col_dims_; i++) {
out_shape.push_back(x_dims_[i]);
}
for (int i = y_num_col_dims_; i < y_dims_.size(); i++) {
out_shape.push_back(y_dims_[i]);
}
out->Resize(DDim(out_shape));
auto x_data = x->data<float>();
auto y_data = y->data<float>();
auto* out_data = out->mutable_data<float>();
const int M = x_mat_dims[0];
const int K = x_mat_dims[1];
const int N = y_mat_dims[1];
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
out_data[m * N + n] = 0;
for (int k = 0; k < K; ++k) {
out_data[m * N + n] += x_data[m * K + k] * y_data[k * N + n];
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(type_);
op_desc->SetInput("X", {x_});
op_desc->SetInput("Y", {y_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("x_num_col_dims", x_num_col_dims_);
op_desc->SetAttr("y_num_col_dims", y_num_col_dims_);
}
void PrepareData() override {
std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
std::vector<float> y(y_dims_.production());
fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
SetCommonTensor(y_, y_dims_, y.data());
}
};
void TestMul(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
int x_num_col_dims,
int y_num_col_dims,
const Place& place,
float abs_error) {
std::unique_ptr<arena::TestCase> tester(new MulComputeTester(place,
"def",
DDim(x_dims),
DDim(y_dims),
x_num_col_dims,
y_num_col_dims));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
TEST(Mul, precision) {
LOG(INFO) << "test mul op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
TestMul({4, 5}, {5, 4}, 1, 1, place, abs_error);
TestMul({4, 5}, {5, 4, 3, 2}, 1, 1, place, abs_error);
TestMul({4, 20}, {5, 4, 3, 2}, 1, 2, place, abs_error);
TestMul({4, 60}, {5, 4, 3, 2}, 1, 3, place, abs_error);
TestMul({2, 3, 4, 5}, {60, 4}, 1, 1, place, abs_error);
TestMul({2, 3, 4, 5}, {20, 4}, 2, 1, place, abs_error);
TestMul({2, 3, 4, 5}, {5, 4}, 3, 1, place, abs_error);
TestMul({2, 3, 4, 5}, {60, 3, 4, 5}, 1, 1, place, abs_error);
TestMul({2, 3, 4, 5}, {4, 5, 6, 2}, 2, 2, place, abs_error);
TestMul({2, 3, 4, 5}, {5, 1, 4, 2}, 3, 2, place, abs_error);
}
} // namespace lite
} // namespace paddle
...@@ -107,6 +107,7 @@ class UnsqueezeComputeTester : public arena::TestCase { ...@@ -107,6 +107,7 @@ class UnsqueezeComputeTester : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
SetPrecisionType(out_, PRECISION(kFloat));
std::vector<float> in_data(dims_.production()); std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) { for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i; in_data[i] = i;
...@@ -213,6 +214,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase { ...@@ -213,6 +214,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
SetPrecisionType(out_, PRECISION(kFloat));
std::vector<float> in_data(dims_.production()); std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) { for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i; in_data[i] = i;
......
...@@ -1042,23 +1042,6 @@ function main { ...@@ -1042,23 +1042,6 @@ function main {
build_test_arm_subtask_armlinux build_test_arm_subtask_armlinux
shift shift
;; ;;
build_test_arm_model_mobilenetv1)
build_test_arm_subtask_model test_mobilenetv1 mobilenet_v1
build_test_arm_subtask_model test_mobilenetv1_int8 MobileNetV1_quant
shift
;;
build_test_arm_model_mobilenetv2)
build_test_arm_subtask_model test_mobilenetv2 mobilenet_v2_relu
shift
;;
build_test_arm_model_resnet50)
build_test_arm_subtask_model test_resnet50 resnet50
shift
;;
build_test_arm_model_inceptionv4)
build_test_arm_subtask_model test_inceptionv4 inception_v4_simple
shift
;;
check_style) check_style)
check_style check_style
shift shift
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
namespace paddle {
namespace lite {
static std::string GetStringFromEnv(const std::string& str,
const std::string& def = "") {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
return std::string(variable);
}
static bool GetBoolFromEnv(const std::string& str, bool def = false) {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
if (strcmp(variable, "false") == 0 || strcmp(variable, "0") == 0) {
return false;
} else {
return true;
}
}
static int GetIntFromEnv(const std::string& str, int def = 0) {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
return atoi(variable);
}
static double GetDoubleFromEnv(const std::string& str, double def = 0.0) {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
return atof(variable);
}
static uint64_t GetUInt64FromEnv(const std::string& str, uint64_t def = 0ul) {
char* variable = std::getenv(str.c_str());
if (!variable) {
return def;
}
return static_cast<uint64_t>(atol(variable));
}
} // namespace lite
} // namespace paddle
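Note: a minimal usage sketch of the env helpers above; the include path and the environment variable names are placeholders, not taken from the patch:

#include <iostream>
#include <string>
// #include "lite/utils/env.h"  // assumed location of the helpers above

void ConfigureFromEnv() {
  // The environment variable names below are examples only.
  bool verbose    = paddle::lite::GetBoolFromEnv("EXAMPLE_VERBOSE", false);
  int repeats     = paddle::lite::GetIntFromEnv("EXAMPLE_REPEATS", 10);
  std::string tag = paddle::lite::GetStringFromEnv("EXAMPLE_TAG", "default");
  std::cout << "verbose=" << verbose << " repeats=" << repeats
            << " tag=" << tag << std::endl;
}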
...@@ -18,6 +18,37 @@ limitations under the License. */ ...@@ -18,6 +18,37 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
void CLImage::PrintTensor(const CLImage &cl_image) const {
size_t width = cl_image.ImageDims()[0];
size_t height = cl_image.ImageDims()[1];
half_t *image_data = new half_t[height * width * 4];
cl_int err;
cl_mem image = cl_image.GetCLImage();
size_t origin[3] = {0, 0, 0};
size_t region[3] = {width, height, 1};
err = clEnqueueReadImage(cl_image.CommandQueue(), image, CL_TRUE, origin,
region, 0, 0, image_data, 0, NULL, NULL);
CL_CHECK_ERRORS(err);
PADDLE_MOBILE_ENFORCE(cl_image.numel() != 0,
"cl_image numel should not be 0 ");
float *tensor_data = new float[cl_image.numel()];
auto converter = cl_image.Converter();
converter->ImageToNCHW(image_data, tensor_data, cl_image.ImageDims(),
cl_image.dims());
  int stride = cl_image.numel() / 20;
  stride = stride > 0 ? stride : 1;
  // sample roughly 20 values instead of dumping the whole tensor
  for (int i = 0; i < cl_image.numel(); i += stride) {
    printf("%f \n", tensor_data[i]);
  }
delete[](tensor_data);
delete[](image_data);
}
void CLImageToTensor(CLImage *cl_image, Tensor *tensor, cl_context context, void CLImageToTensor(CLImage *cl_image, Tensor *tensor, cl_context context,
cl_command_queue commandQueue, cl_kernel kernel) { cl_command_queue commandQueue, cl_kernel kernel) {
tensor->mutable_data<float>(); tensor->mutable_data<float>();
......
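Note: PrintTensor reads the image back through clEnqueueReadImage and converts it to NCHW before printing, so it is only meant for debugging. A minimal usage sketch, assuming an initialized CLImage:

#include "framework/cl/cl_image.h"

// Debug helper: dump a sampled view of the tensor behind an OpenCL image.
// Assumes `image` has been initialized with data on the device.
void DumpCLImage(const paddle_mobile::framework::CLImage& image) {
  image.PrintTensor(image);
}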
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -285,6 +286,7 @@ class CLImage { ...@@ -285,6 +286,7 @@ class CLImage {
cl_event GetClEvent() const { return cl_event_.get(); } cl_event GetClEvent() const { return cl_event_.get(); }
CLImageConverterBase *Converter() const { return image_converter_; } CLImageConverterBase *Converter() const { return image_converter_; }
void PrintTensor(const CLImage &cl_image) const;
private: private:
void InitCLImage(cl_context context, size_t width, size_t height, void InitCLImage(cl_context context, size_t width, size_t height,
......
...@@ -21,13 +21,14 @@ namespace framework { ...@@ -21,13 +21,14 @@ namespace framework {
const char* opencl_error_to_str(cl_int error); const char* opencl_error_to_str(cl_int error);
#define CL_CHECK_ERRORS(ERR) \ #define CL_CHECK_ERRORS(ERR) \
if (ERR != CL_SUCCESS) { \ if (ERR != CL_SUCCESS) { \
printf( \ printf( \
"OpenCL error with code %s happened in file %s at line %d. " \ "\033[1;31;40mOpenCL error with code %s happened in file %s at line " \
"Exiting.\n", \ "%d. " \
paddle_mobile::framework::opencl_error_to_str(ERR), __FILE__, \ "Exiting.\033[0m\n", \
__LINE__); \ paddle_mobile::framework::opencl_error_to_str(ERR), __FILE__, \
__LINE__); \
} }
} // namespace framework } // namespace framework
......
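Note: a typical call-site pattern for the (now colorized) macro, sketched in C++; the kernel and argument are placeholders, and the header that defines CL_CHECK_ERRORS / opencl_error_to_str is assumed to be included:

#include <CL/cl.h>

// Check every OpenCL status code right after the call that produced it.
void SetFirstKernelArg(cl_kernel kernel, int value) {
  cl_int status = clSetKernelArg(kernel, 0, sizeof(int), &value);
  CL_CHECK_ERRORS(status);  // on failure, prints the error, file and line (now highlighted in red)
}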
...@@ -363,7 +363,10 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) { ...@@ -363,7 +363,10 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) {
DLOG << "InitNoPersistableMemory var " << var_desc->Name(); DLOG << "InitNoPersistableMemory var " << var_desc->Name();
auto tensor = var->template GetMutable<LoDTensor>(); auto tensor = var->template GetMutable<LoDTensor>();
if (tensor->IsInitialized() && tensor->dims().size() == 4) { if (tensor->IsInitialized() && tensor->dims().size() == 4) {
DLOG << "var's tensor is Initialized or dims size != 4"; // don't change user's input and avoid memory leaks
if (feed_indices_.find(var_desc->Name()) != feed_indices_.end()) {
break;
}
DDim tensor_dim = tensor->dims(); DDim tensor_dim = tensor->dims();
DDim new_dim = DDim new_dim =
make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2], make_ddim({tensor_dim[0], tensor_dim[1], input_tensor.dims()[2],
......
...@@ -241,7 +241,9 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, ...@@ -241,7 +241,9 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
cl_int status; cl_int status;
int index = 0; int index = 0;
if (param.Filter()->dims()[2] == 1 && param.Filter()->dims()[3] == 1) { const int filter_height = param.Filter()->dims()[2];
const int filter_width = param.Filter()->dims()[3];
if (filter_height == 1 && filter_width == 1) {
status = clSetKernelArg(kernel, index++, sizeof(int), &c_block); status = clSetKernelArg(kernel, index++, sizeof(int), &c_block);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
...@@ -404,7 +406,7 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, ...@@ -404,7 +406,7 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
status = clSetKernelArg(kernel, index++, sizeof(int), &output_height); status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) { if (filter_height == 3 && filter_width == 3) {
// normal conv // normal conv
if (param.Filter()->dims()[0] == param.Output()->dims()[1] && if (param.Filter()->dims()[0] == param.Output()->dims()[1] &&
param.Filter()->dims()[1] == param.Input()->dims()[1]) { param.Filter()->dims()[1] == param.Input()->dims()[1]) {
...@@ -425,6 +427,17 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, ...@@ -425,6 +427,17 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
status = clSetKernelArg(kernel, index++, sizeof(int), &group); status = clSetKernelArg(kernel, index++, sizeof(int), &group);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
} }
} else if (filter_height != 3 && filter_width != 3) {
// not 3x3
if (param.Filter()->dims()[1] == 1 &&
param.Input()->dims()[1] == param.Output()->dims()[1]) {
// depthwise basic path for non-3x3 filters
status = clSetKernelArg(kernel, index++, sizeof(int), &filter_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, index++, sizeof(int), &filter_height);
CL_CHECK_ERRORS(status);
}
} }
status = clEnqueueNDRangeKernel( status = clEnqueueNDRangeKernel(
......
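Note: the extra filter_width / filter_height kernel arguments above are only set on the generic depthwise path. A sketch that restates the routing condition used here and in the Conv*Kernel::Init hunks further down, over plain shape vectors; the helper is hypothetical:

#include <cstdint>
#include <vector>

// Hypothetical restatement of the check that picks the "depthwise basic" path:
// per-channel filter (filter dims[1] == 1), channel count preserved, and a
// filter size other than 3x3 (3x3 depthwise keeps its specialized kernels).
static bool UsesDepthwiseBasicPath(const std::vector<int64_t>& filter_dims,
                                   const std::vector<int64_t>& input_dims,
                                   const std::vector<int64_t>& output_dims) {
  return filter_dims[1] == 1 &&
         input_dims[1] == output_dims[1] &&
         filter_dims[2] != 3;
}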
(This file's diff has been collapsed.)
...@@ -13,33 +13,101 @@ See the License for the specific language governing permissions and ...@@ -13,33 +13,101 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void elementwise_mul(__global image2d_t input, __global image2d_t bias,__write_only image2d_t outputImage) { __kernel void elementwise_mul(__global image2d_t input, __global image2d_t bias,
int x = get_global_id(0); __write_only image2d_t outputImage) {
int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 coords;
coords.x = x;
coords.y = y;
half4 in = read_imageh(input, sampler, coords);
half4 biase = read_imageh(bias, sampler, coords);
half4 output = in * biase;
write_imageh(outputImage,coords,output);
}
__kernel void channel_mul(__global image2d_t input, __global image2d_t bias,__write_only
image2d_t outputImage, int w) {
int x = get_global_id(0); int x = get_global_id(0);
int y = get_global_id(1); int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 coords;
coords.x = x;
coords.y = y;
half4 in = read_imageh(input, sampler, coords);
half4 biase = read_imageh(bias, sampler, coords);
half4 output = in * biase;
write_imageh(outputImage, coords, output);
}
__kernel void channel_mul(__global image2d_t input, __global image2d_t bias,
__write_only image2d_t outputImage, int w) {
int x = get_global_id(0);
int y = get_global_id(1);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 coords; int2 coords;
coords.x = x; coords.x = x;
coords.y = y; coords.y = y;
int2 coords_bias; int2 coords_bias;
coords_bias.x = x/w; coords_bias.x = x / w;
coords_bias.y = 0; coords_bias.y = 0;
half4 in = read_imageh(input, sampler, coords); half4 in = read_imageh(input, sampler, coords);
half4 biase = read_imageh(bias, sampler, coords_bias); half4 biase = read_imageh(bias, sampler, coords_bias);
half4 output = in * biase; half4 output = in * biase;
write_imageh(outputImage,coords,output); write_imageh(outputImage, coords, output);
} }
// e.g. bias shape: 1 1 1 72
// at run time Y holds [value, 0, 0, 0] per pixel, 72 pixels wide
__kernel void channel_mul_d2(__global image2d_t input, __global image2d_t bias,
__write_only image2d_t outputImage, int w) {
int x = get_global_id(0);
int y = get_global_id(1);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 coords;
coords.x = x;
coords.y = y;
int2 coords_bias0;
int2 coords_bias1;
int2 coords_bias2;
int2 coords_bias3;
/* if (x == 0 && y == 0) {
half4 b = (half4){0, 0, 0, 0};
#define PPI(j, k) \
b = read_imageh(bias, sampler, (int2){j, k}); \
printf("bias(%d,%d)={ %f , %f , %f , %f }\n ", j, k, convert_float(b.x), \
convert_float(b.y), convert_float(b.z), convert_float(b.w));
for (int i = 0; i < 73; ++i) {
PPI(i, 0);
}
#undef PPI
}*/
coords_bias0.x = x / w * 4;
coords_bias0.y = 0;
coords_bias1.x = x / w * 4 + 1;
coords_bias1.y = 0;
coords_bias2.x = x / w * 4 + 2;
coords_bias2.y = 0;
coords_bias3.x = x / w * 4 + 3;
coords_bias3.y = 0;
half4 biase0 = read_imageh(bias, sampler, coords_bias0);
half4 biase1 = read_imageh(bias, sampler, coords_bias1);
half4 biase2 = read_imageh(bias, sampler, coords_bias2);
half4 biase3 = read_imageh(bias, sampler, coords_bias3);
/* if (x == 0 && y == 0) {
printf("bias0={ %f , %f , %f , %f }\n ",
convert_float(biase0.x), convert_float(biase0.y),
convert_float(biase0.z), convert_float(biase0.w));
printf("bias1={ %f , %f , %f , %f }\n ",
convert_float(biase1.x), convert_float(biase1.y),
convert_float(biase1.z), convert_float(biase1.w));
printf("bias2={ %f , %f , %f , %f }\n ",
convert_float(biase2.x), convert_float(biase2.y),
convert_float(biase2.z), convert_float(biase2.w));
printf("bias3={ %f , %f , %f , %f }\n ",
convert_float(biase3.x), convert_float(biase3.y),
convert_float(biase3.z), convert_float(biase3.w));
}*/
half4 biase = {biase0.x, biase1.x, biase2.x, biase3.x};
half4 in = read_imageh(input, sampler, coords);
half4 output = mad(in, biase, 0);
write_imageh(outputImage, coords, output);
}
\ No newline at end of file
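Note: channel_mul_d2 above assumes the bias is a single row of C values, one value per pixel in the .x lane, while the input/output image packs four channels per pixel; so output column x maps to bias columns x/w*4 .. x/w*4+3. A host-side sketch of that index math, assuming the usual C/4-packed image layout; the helper is hypothetical:

#include <array>
#include <cstdio>

// Hypothetical helper: which four bias-image columns feed the packed pixel at
// output column x, where w is the tensor width (so x / w is the channel block).
std::array<int, 4> BiasColumnsForOutputColumn(int x, int w) {
  int base = x / w * 4;
  return {base, base + 1, base + 2, base + 3};
}

int main() {
  // e.g. w = 28, x = 57 -> channel block 2 -> bias columns 8..11
  std::array<int, 4> cols = BiasColumnsForOutputColumn(57, 28);
  std::printf("%d %d %d %d\n", cols[0], cols[1], cols[2], cols[3]);
  return 0;
}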
...@@ -174,6 +174,16 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init( ...@@ -174,6 +174,16 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
build_options); build_options);
} }
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// other depthwise convolutions whose filter is not 3x3
DLOG << "depth_conv basic ";
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file, build_options);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -214,6 +224,7 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute( ...@@ -214,6 +224,7 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute(
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, true, param.Bias(), ConvAddBnRelu(&this->cl_helper_, param, true, param.Bias(),
param.NewScale(), param.NewBias()); param.NewScale(), param.NewBias());
break; break;
......
...@@ -71,6 +71,14 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) { ...@@ -71,6 +71,14 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
build_options); build_options);
} }
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file, build_options);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -124,6 +132,7 @@ void ConvAddKernel<GPU_CL, float>::Compute( ...@@ -124,6 +132,7 @@ void ConvAddKernel<GPU_CL, float>::Compute(
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW5x5_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW5x5_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, false, param.Bias()); ConvAddBnRelu(&this->cl_helper_, param, false, param.Bias());
break; break;
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT:
......
...@@ -72,6 +72,14 @@ bool ConvAddReluKernel<GPU_CL, float>::Init( ...@@ -72,6 +72,14 @@ bool ConvAddReluKernel<GPU_CL, float>::Init(
build_options); build_options);
} }
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
DLOG << "init depwise conv basic";
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file, build_options);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -130,6 +138,7 @@ void ConvAddReluKernel<GPU_CL, float>::Compute( ...@@ -130,6 +138,7 @@ void ConvAddReluKernel<GPU_CL, float>::Compute(
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW5x5_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW5x5_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, true, param.Bias()); ConvAddBnRelu(&this->cl_helper_, param, true, param.Bias());
break; break;
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT:
......
...@@ -129,6 +129,14 @@ bool ConvBNReluKernel<GPU_CL, float>::Init( ...@@ -129,6 +129,14 @@ bool ConvBNReluKernel<GPU_CL, float>::Init(
build_options); build_options);
} }
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file, build_options);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -168,6 +176,7 @@ void ConvBNReluKernel<GPU_CL, float>::Compute( ...@@ -168,6 +176,7 @@ void ConvBNReluKernel<GPU_CL, float>::Compute(
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, true, nullptr, param.NewScale(), ConvAddBnRelu(&this->cl_helper_, param, true, nullptr, param.NewScale(),
param.NewBias()); param.NewBias());
break; break;
......
...@@ -66,6 +66,14 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) { ...@@ -66,6 +66,14 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
} }
DLOG << "depth_conv 3x3"; DLOG << "depth_conv 3x3";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -115,6 +123,7 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) { ...@@ -115,6 +123,7 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW7x7_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param); ConvAddBnRelu(&this->cl_helper_, param);
break; break;
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT:
......
...@@ -72,6 +72,14 @@ bool ConvReluKernel<GPU_CL, float>::Init(FusionConvReluParam<GPU_CL> *param) { ...@@ -72,6 +72,14 @@ bool ConvReluKernel<GPU_CL, float>::Init(FusionConvReluParam<GPU_CL> *param) {
DLOG << "depth_conv 3x3"; DLOG << "depth_conv 3x3";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] != 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT;
this->cl_helper_.AddKernel("depth_conv", conv_kernel_file, build_options);
} else if (param->Filter()->dims()[2] == 3 && } else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) { param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] && // if (param->Strides()[0] == param->Strides()[1] &&
...@@ -120,6 +128,7 @@ void ConvReluKernel<GPU_CL, float>::Compute( ...@@ -120,6 +128,7 @@ void ConvReluKernel<GPU_CL, float>::Compute(
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISEBASIC_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, true); ConvAddBnRelu(&this->cl_helper_, param, true);
break; break;
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT: case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3S1_FLOAT:
......
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#ifdef ELEMENTWISEMUL_OP #ifdef ELEMENTWISEMUL_OP
#include "operators/kernel/elementwise_mul_kernel.h" #include "operators/kernel/elementwise_mul_kernel.h"
#include <framework/cl/cl_half.h>
#include <iostream>
#include "framework/cl/cl_image.h" #include "framework/cl/cl_image.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -23,19 +25,24 @@ namespace operators { ...@@ -23,19 +25,24 @@ namespace operators {
template <> template <>
bool ElementwiseMulKernel<GPU_CL, float>::Init( bool ElementwiseMulKernel<GPU_CL, float>::Init(
ElementwiseMulParam<GPU_CL> *param) { ElementwiseMulParam<GPU_CL> *param) {
DLOG << "-----init add-----";
framework::CLImage *bias = reinterpret_cast<framework::CLImage *>( framework::CLImage *bias = reinterpret_cast<framework::CLImage *>(
const_cast<framework::CLImage *>(param->InputY())); const_cast<framework::CLImage *>(param->InputY()));
if (bias->dims() == param->InputX()->dims()) { if (bias->dims() == param->InputX()->dims()) {
DLOG << "init element wise mul";
this->cl_helper_.AddKernel("elementwise_mul", "elementwise_mul_kernel.cl"); this->cl_helper_.AddKernel("elementwise_mul", "elementwise_mul_kernel.cl");
} else if (bias->dims().size() == 4) { } else if (bias->dims().size() == 1) {
DLOG << "init channel_mul";
this->cl_helper_.AddKernel("channel_mul", "elementwise_mul_kernel.cl"); this->cl_helper_.AddKernel("channel_mul", "elementwise_mul_kernel.cl");
} else if (bias->dims().size() == 2) {
// e.g. input shape 1 x 72 x 28 x 28
//      bias (Y) shape 1 x 72
DLOG << "init channel_mul_d2";
this->cl_helper_.AddKernel("channel_mul_d2", "elementwise_mul_kernel.cl");
} else { } else {
DLOG << "error:bias dims is error"; PADDLE_MOBILE_ENFORCE(false, "element mul not supported yet");
} }
return true; return true;
} }
template <> template <>
void ElementwiseMulKernel<GPU_CL, float>::Compute( void ElementwiseMulKernel<GPU_CL, float>::Compute(
const ElementwiseMulParam<GPU_CL> &param) { const ElementwiseMulParam<GPU_CL> &param) {
...@@ -64,8 +71,8 @@ void ElementwiseMulKernel<GPU_CL, float>::Compute( ...@@ -64,8 +71,8 @@ void ElementwiseMulKernel<GPU_CL, float>::Compute(
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL); NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
} else if (bias->dims().size() == 4) { } else if (bias->dims().size() == 1) {
DLOG << "zp7 444"; DLOG << "channel mul";
cl_mem input_image = input->GetCLImage(); cl_mem input_image = input->GetCLImage();
cl_mem bias_image = bias->GetCLImage(); cl_mem bias_image = bias->GetCLImage();
cl_mem output_image = output->GetCLImage(); cl_mem output_image = output->GetCLImage();
...@@ -84,14 +91,48 @@ void ElementwiseMulKernel<GPU_CL, float>::Compute( ...@@ -84,14 +91,48 @@ void ElementwiseMulKernel<GPU_CL, float>::Compute(
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
auto width = input->ImageWidth(); auto width = input->ImageWidth();
auto height = input->ImageHeight(); auto height = input->ImageHeight();
DLOG << "dede:" << width << "," << height;
size_t global_work_size[2] = {width, height}; size_t global_work_size[2] = {width, height};
status = status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL); NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
} else if (bias->dims().size() == 2) {
DLOG << "channel mul d2";
// e.g. input shape 1 x 72 x 28 x 28
//      bias (Y) shape 1 x 72, laid out as 1 x 1 x 1 x 72
DLOG << "input->ImageDims(): " << input->ImageDims();
DLOG << "bias->ImageDims(): " << bias->ImageDims();
DLOG << "out->ImageDims(): " << output->ImageDims();
DLOG << "channel mul d2";
cl_mem input_image = input->GetCLImage();
cl_mem bias_image = bias->GetCLImage();
cl_mem output_image = output->GetCLImage();
int tensor_w = input->dims()[input->dims().size() - 1];
status = clSetKernelArg(kernel, 0, sizeof(cl_mem),
reinterpret_cast<void *>(&input_image));
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem),
reinterpret_cast<void *>(&bias_image));
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_mem),
reinterpret_cast<void *>(&output_image));
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int),
reinterpret_cast<void *>(&tensor_w));
CL_CHECK_ERRORS(status);
auto width = input->ImageWidth();
auto height = input->ImageHeight();
size_t global_work_size[2] = {width, height};
status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
// bias->PrintTensor(*bias);
} else { } else {
DLOG << "error:bias dims is error"; PADDLE_MOBILE_ENFORCE(false, "element mul not support this situation yet")
} }
} }
......
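Note: a compact restatement of the kernel-selection rule that Init and Compute above now share, over plain shape vectors; the helper is hypothetical, and the strings are the kernel names in elementwise_mul_kernel.cl:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical restatement of ElementwiseMulKernel<GPU_CL>::Init's dispatch.
std::string PickElementwiseMulKernel(const std::vector<int64_t>& x_dims,
                                     const std::vector<int64_t>& bias_dims) {
  if (bias_dims == x_dims) return "elementwise_mul";   // same shape: per-element multiply
  if (bias_dims.size() == 1) return "channel_mul";     // e.g. bias {72}
  if (bias_dims.size() == 2) return "channel_mul_d2";  // e.g. bias {1, 72}
  throw std::runtime_error("elementwise mul: unsupported bias shape");
}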
...@@ -489,6 +489,7 @@ class ConvParam : public OpParam { ...@@ -489,6 +489,7 @@ class ConvParam : public OpParam {
EXEC_SLIDINGWINDOW5x5_FLOAT, EXEC_SLIDINGWINDOW5x5_FLOAT,
EXEC_SLIDINGWINDOW7x7_FLOAT, EXEC_SLIDINGWINDOW7x7_FLOAT,
EXEC_GEMM1x1s1_FLOAT, EXEC_GEMM1x1s1_FLOAT,
EXEC_DEPTHWISEBASIC_FLOAT,
}; };
ExecMode &ExecMode() const { return exec_mode_; } ExecMode &ExecMode() const { return exec_mode_; }
......
...@@ -216,4 +216,6 @@ void test(int argc, char *argv[]) { ...@@ -216,4 +216,6 @@ void test(int argc, char *argv[]) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
#else
int main() {}
#endif #endif