diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake index 11a803ff031706a10f282f21024915be68444546..45be0e4d3a7e1b7daff3fd226b53b06ad96fb73d 100644 --- a/cmake/cross_compiling/android.cmake +++ b/cmake/cross_compiling/android.cmake @@ -18,6 +18,7 @@ endif() set(ANDROID TRUE) add_definitions(-DLITE_WITH_LINUX) +add_definitions(-DLITE_WITH_ANDROID) if(NOT DEFINED ANDROID_NDK) set(ANDROID_NDK $ENV{NDK_ROOT}) diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc index a494be563bc4eac133f95b1a389f3155c491bc18..896f6c8d33a8665c4c94786dd08af1a097942608 100644 --- a/lite/core/device_info.cc +++ b/lite/core/device_info.cc @@ -35,6 +35,9 @@ #include #include #endif +#ifdef LITE_WITH_ANDROID +#include +#endif #if __APPLE__ #include "TargetConditionals.h" #if LITE_WITH_IPHONE @@ -218,6 +221,7 @@ void get_cpu_arch(std::vector* archs, const int cpu_num) { #ifdef LITE_WITH_LINUX std::string get_cpu_name() { + std::string cpu_name; FILE* fp = fopen("/proc/cpuinfo", "rb"); if (!fp) { return ""; @@ -229,12 +233,23 @@ std::string get_cpu_name() { break; } if (strstr(line, "Hardware") != NULL) { - fclose(fp); - return std::string(line); + cpu_name = std::string(line); } } +#ifdef LITE_WITH_ANDROID + // cpu name concat board name, platform name and chip name + char board_name[128]; + char platform_name[128]; + char chip_name[128]; + __system_property_get("ro.product.board", board_name); + __system_property_get("ro.board.platform", platform_name); + __system_property_get("ro.chipname", chip_name); + cpu_name = + cpu_name + "_" + board_name + "_" + platform_name + "_" + chip_name; +#endif + std::transform(cpu_name.begin(), cpu_name.end(), cpu_name.begin(), ::toupper); fclose(fp); - return ""; + return cpu_name; } int get_min_freq_khz(int cpuid) { @@ -780,7 +795,9 @@ bool DeviceInfo::SetCPUInfoByName() { cluster_ids_ = {0, 0, 0, 0}; SetArchInfo(1, kA53); return true; - } else if (dev_name_.find("KIRIN980") != std::string::npos) { // Kirin 980 + } else if (dev_name_.find("KIRIN980") != std::string::npos || + dev_name_.find("KIRIN990") != + std::string::npos) { // Kirin 980, Kirin 990 core_num_ = 8; core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; big_core_ids_ = {4, 5, 6, 7}; diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 98007db0d188b8a77477a5148224be71f5b00dd5..ebb96e21d5e856325b7abdb8342df2aea3d5b5c3 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -39,6 +39,13 @@ void ConvCompute::PrepareForRun() { int pad = param.paddings[0]; int stride = param.strides[0]; + int chin = param.x->dims()[1]; + int hin = param.x->dims()[2]; + int win = param.x->dims()[3]; + int chout = param.output->dims()[1]; + int hout = param.output->dims()[2]; + int wout = param.output->dims()[3]; + bool kps_equal = (param.paddings[0] == param.paddings[1]) && (param.strides[0] == param.strides[1]) && (kw == kh); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); @@ -54,7 +61,7 @@ void ConvCompute::PrepareForRun() { VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation) { - if (ic >= 32 && oc >= 32) { + if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) { /// winograd conv impl impl_ = new WinogradConv; VLOG(3) << "invoking winograd conv"; @@ -63,8 +70,8 @@ void ConvCompute::PrepareForRun() { impl_ = new DirectConv; VLOG(3) << "invoking direct conv"; } - } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal && - no_dilation) { + } else if (param.groups == 1 && kw == 3 && stride == 2 && + chin * chout < 4 * hin * win && kps_equal && no_dilation) { /// direct conv impl impl_ = new DirectConv; VLOG(3) << "invoking direct conv"; diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 348a55db117245582a8f13c5abf9161a8c880940..dd5676e6430069297cdd3527900bce69c59f3dfb 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -16,6 +16,8 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute. add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps}) add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose) add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) @@ -26,5 +28,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda) +nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda) +nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda) #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) diff --git a/lite/kernels/cuda/pool_compute.cu b/lite/kernels/cuda/pool_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..a2483a2c759e8acc5f5944fd316c83bb49530d36 --- /dev/null +++ b/lite/kernels/cuda/pool_compute.cu @@ -0,0 +1,375 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/pool_compute.h" +#include "lite/utils/macros.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +#define MAX_VAL(a, b) (((a) > (b)) ? (a) : (b)) +#define MIN_VAL(a, b) (((a) < (b)) ? (a) : (b)) + +__global__ void max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + max_val = max_val == -FLT_MAX ? 0.f : max_val; + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void adaptive_max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + bool exclusive, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = exclusive ? h_loop * w_loop : win_w * win_h; + pool_size = pool_size == 0 ? 1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void adaptive_avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = h_loop * w_loop; + pool_size = pool_size == 0 ? 1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void global_max_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float max_val = -0.f; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + max_val = MAX_VAL(max_val, *(in_p + 0)); + max_val = MAX_VAL(max_val, *(in_p + 1)); + max_val = MAX_VAL(max_val, *(in_p + 2)); + max_val = MAX_VAL(max_val, *(in_p + 3)); + max_val = MAX_VAL(max_val, *(in_p + 4)); + max_val = MAX_VAL(max_val, *(in_p + 5)); + max_val = MAX_VAL(max_val, *(in_p + 6)); + max_val = MAX_VAL(max_val, *(in_p + 7)); + in_p += 8; + } + for (; i < spatial_in; i++) { + max_val = MAX_VAL(max_val, *in_p); + in_p++; + } + output[gid] = max_val; + } +} + +__global__ void global_avg_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float sum_val = 0.f; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + } + for (; i < spatial_in; i++) { + sum_val += *in_p++; + } + output[gid] = sum_val / spatial_in; + } +} + +void PoolCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + auto x_dims = param.x->dims(); + auto out_dims = param.output->dims(); + const int in_h = x_dims[2]; + const int in_w = x_dims[3]; + const int out_h = out_dims[2]; + const int out_w = out_dims[3]; + const int spatial_in = in_h * in_w; + const int spatial_out = out_h * out_w; + const int win_h = param.ksize[0]; + const int win_w = param.ksize[1]; + const int stride_h = param.strides[0]; + const int stride_w = param.strides[1]; + const int pad_h = param.paddings[0]; + const int pad_w = param.paddings[1]; + const int total_threads = out_dims.production(); + const int threads = 512; + const int blocks = (total_threads + threads - 1) / threads; + auto input_data = param.x->data(); + auto output_data = param.output->mutable_data(TARGET(kCUDA)); + if (param.global_pooling) { + if (param.pooling_type == "max") { + global_max_pool_kernel<<>>( + input_data, output_data, in_h, in_w, total_threads); + } else { + global_avg_pool_kernel<<>>( + input_data, output_data, in_h, in_w, total_threads); + } + } else { + if (!adaptive) { + if (param.pooling_type == "max") { + max_pool_kernel<<>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } else { + avg_pool_kernel<<>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + exclusive, + total_threads); + } + } else { + if (param.pooling_type == "max") { + adaptive_max_pool_kernel<<>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } else { + adaptive_avg_pool_kernel<<>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } + } + } + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + pool2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::PoolCompute, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/pool_compute.h b/lite/kernels/cuda/pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..55b346bfaf4ac139c8d22bff2ac64f0e78bc6023 --- /dev/null +++ b/lite/kernels/cuda/pool_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class PoolCompute + : public KernelLite { + public: + using param_t = operators::PoolParam; + + void Run() override; + virtual ~PoolCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/pool_compute_test.cc b/lite/kernels/cuda/pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe6ff92c0ce943cad36fbdd4f1408e344d9fd5fd --- /dev/null +++ b/lite/kernels/cuda/pool_compute_test.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/pool_compute.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +static int PoolOutputSize( + int input_size, int filter_size, int padding, int stride, bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + return output_size; +} + +static std::vector compute_output_shape(operators::PoolParam* param_) { + const auto x_dims = param_->x->dims(); + std::vector& ksize = param_->ksize; + if (param_->global_pooling) { + ksize.resize(static_cast(x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) { + param_->paddings[i] = 0; + ksize[i] = static_cast(x_dims[i + 2]); + } + } + + std::vector output_shape({x_dims[0], x_dims[1]}); + if (param_->adaptive) { + output_shape.insert( + output_shape.end(), param_->ksize.begin(), param_->ksize.end()); + } else { + for (size_t i = 0; i < param_->ksize.size(); ++i) { + output_shape.push_back(PoolOutputSize(x_dims[i + 2], + param_->ksize[i], + param_->paddings[i], + param_->strides[i], + param_->ceil_mode)); + } + } + return output_shape; +} + +static void pool_compute_ref(const operators::PoolParam& param) { + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* src_ptr = param.x->data(); + float* dst_ptr = param.output->mutable_data(); + + std::vector ksize = param.ksize; + std::vector strides = param.strides; + std::vector paddings = param.paddings; + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + std::string data_format = param.data_format; + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + const float* src = src_ptr + n * size_in_n + c * size_in_c; + float res = src[0]; + if (pooling_type == "max") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res = cur_val > res ? cur_val : res; + } + } else if (pooling_type == "avg") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res += cur_val; + } + res /= size_in_c; + } + dst_ptr[n * size_out_n + c] = res; + } + } + } else { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < out_h; ++h) { + int sh = h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + for (int w = 0; w < out_w; ++w) { + int sw = w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + int pooling_size = (ew - sw) * (eh - sh); + if (pooling_size == 0) { + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = 0.f; + continue; + } + float res = 0.f; + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw; + if (kh == sh && kw == sw) { + res = src_ptr[src_idx]; + } else { + if (pooling_type == "max") { + res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx]; + } + if (pooling_type == "avg") { + res += src_ptr[src_idx]; + } + } + } + } + if (pooling_type == "avg") { + if (exclusive) { + res /= pooling_size; + } else { + res /= window_h * window_w; + } + } + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res; + } + } + } + } + } +} + +TEST(pool_cuda, compute) { + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + PoolCompute pool; + operators::PoolParam param; + pool.SetContext(std::move(ctx)); + + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + for (auto pooling_type : {"max", "avg"}) { + for (auto ceil_mode : {true, false}) { + for (auto global_pooling : {true, false}) { + for (auto exclusive : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto pad : {0, 1}) { + for (auto n : {1, 2}) { + for (auto c : {1, 3}) { + for (auto h : {2, 3, 4, 11}) { + for (auto w : {2, 3, 4, 11}) { + VLOG(3) << "n:" << n << " c:" << c << " h:" << h + << " w:" << w << " ksize:" << ksize + << " stride:" << stride << " pad:" << pad + << " exclusive:" << exclusive + << " global_pooling:" << global_pooling + << " ceil_mode: " << ceil_mode + << " pooling_type:" << pooling_type; + + // init x, output + x.Resize(DDim(std::vector({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_cpu_data = x_cpu.mutable_data(); + for (int i = 0; i < x_cpu.dims().production(); ++i) { + float sign = i % 3 == 0 ? -0.03 : 0.05f; + x_cpu_data[i] = sign * (i % 128); + } + x.Assign(x_cpu_data, + x_cpu.dims()); + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + if (global_pooling) { + param.ksize = {h, w}; + } else { + param.ksize = {ksize, ksize}; + } + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = exclusive; + param.ceil_mode = ceil_mode; + param.adaptive = false; + param.use_quantizer = false; + + const std::vector& output_shape = + compute_output_shape(¶m); + if (output_shape[2] * output_shape[3] == 0) continue; + output.Resize(DDim(output_shape)); + output_ref.Resize(DDim(output_shape)); + output_cpu.Resize(DDim(output_shape)); + auto* output_data = + output.mutable_data(TARGET(kCUDA)); + auto* output_ref_data = + output_ref.mutable_data(); + auto* output_cpu_data = + output_cpu.mutable_data(); + + // compute + pool.SetParam(param); + pool.Launch(); + + // compute ref + param.x = &x_cpu; + param.output = &output_ref; + pool_compute_ref(param); + + cudaDeviceSynchronize(); + CopySync(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + // compare + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR( + output_cpu_data[i], output_ref_data[i], 1e-4); + } + VLOG(3) << "compare pass"; + } + } + } + } + } + } + } + } + } + } + } +} +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8d2987524cd2e8f9c38aba4da3ff61a80bf53ce --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.cu @@ -0,0 +1,246 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/softmax_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; + +extern __shared__ char tile[]; +template +__global__ void sharemem_softmax_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + int inner_num, + int outer_num, + int axis_size) { + dtype* data = reinterpret_cast(tile) + threadIdx.x; + //! compute thread index and real data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int blocksize = blockDim.x; + int real_index = idx_outer * inner_num + idx_inner; + int loop_idx = real_index; +//! read all data to sharemem in softmax channel +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + data[i * blocksize] = in_data[loop_idx]; + loop_idx += inner_num; + } + //! get maximum value in softmax channel + dtype max_data = data[0]; +#pragma unroll + for (int i = 1; i < axis_size; ++i) { + dtype dt = data[i * blocksize]; + if (max_data < dt) { + max_data = dt; + } + } + //! subtract then summarize + dtype sum = 0; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + dtype* dt = data + i * blocksize; + *dt = expf(*dt - max_data); + sum += *dt; + } + //! write back result + loop_idx = real_index; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + out_data[loop_idx] = data[i * blocksize] / sum; + loop_idx += inner_num; + } + } +} + +//! general kernel for softmax +template +__global__ void softmax_max_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + dtype min_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + //! get maximum data across softmax axis + dtype max_data = min_data; + for (int i = 0; i < axis_size; ++i) { + max_data = + in_data[real_index] > max_data ? in_data[real_index] : max_data; + real_index += inner_num; + } + out_data[idx] = max_data; + } +} + +template +__global__ void softmax_sub_exp_sum_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + const dtype* max_data, + dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + + dtype max_data_cur = max_data[idx]; + dtype sum_data_cur = 0; + int real_index = idx_outer * inner_num + idx_inner; + //! compute exp and summarize across the softmax axis + for (int i = 0; i < axis_size; ++i) { + dtype sub_data = in_data[real_index] - max_data_cur; + sub_data = expf(sub_data); + sum_data_cur += sub_data; + out_data[real_index] = sub_data; + real_index += inner_num; + } + sum_data[idx] = sum_data_cur; + } +} + +template +__global__ void softmax_divid_output_kernel(int total_size, + dtype* io_data, + const dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + dtype sum_data_cur = 1.f / sum_data[idx]; + int real_index = idx_outer * inner_num + idx_inner; + //! compute final result + for (int i = 0; i < axis_size; ++i) { + io_data[real_index] = io_data[real_index] * sum_data_cur; + real_index += inner_num; + } + } +} + +void SoftmaxCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int total_threads = inner_num * outer_num; + int axis_size = x_dims[axis]; + + int device_id; + const int threads = 512; + const int blocks = (total_threads + threads - 1) / threads; + cudaGetDevice(&device_id); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, device_id); + size_t sharedmem_size = deviceProp.sharedMemPerBlock; + int max_dimsize = sharedmem_size / sizeof(float) / threads; + + auto input_data = param.x->data(); + auto output_data = param.output->mutable_data(TARGET(kCUDA)); + if (axis_size <= max_dimsize) { + int use_sharemem_size = axis_size * threads * sizeof(float); + sharemem_softmax_kernel<<>>( + total_threads, + input_data, + output_data, + inner_num, + outer_num, + axis_size); + } else { + //! re_alloc device memory + Tensor tmax_data; + Tensor tsum_data; + tmax_data.Resize({1, 1, 1, outer_num * inner_num}); + tsum_data.Resize({1, 1, 1, outer_num * inner_num}); + auto max_data = tmax_data.mutable_data(TARGET(kCUDA)); + auto sum_data = tsum_data.mutable_data(TARGET(kCUDA)); + //! firstly, get maximum data + float min_data = std::numeric_limits::min(); + softmax_max_kernel<<>>(total_threads, + input_data, + max_data, + min_data, + inner_num, + outer_num, + axis_size); + //! then, compute exp and sum data + softmax_sub_exp_sum_kernel<<>>( + total_threads, + input_data, + output_data, + max_data, + sum_data, + inner_num, + outer_num, + axis_size); + //! last, compute divided output + softmax_divid_output_kernel<<>>( + total_threads, output_data, sum_data, inner_num, outer_num, axis_size); + } + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(softmax, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SoftmaxCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("axis", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..4acde4ab072390dd139c3e4e715f9ad288dc4ef8 --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SoftmaxCompute + : public KernelLite { + public: + using param_t = operators::SoftmaxParam; + + void Run() override; + virtual ~SoftmaxCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/softmax_compute_test.cc b/lite/kernels/cuda/softmax_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4d53520911a4868c73d7806fcc1bb5bf8bf33df --- /dev/null +++ b/lite/kernels/cuda/softmax_compute_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/softmax_compute.h" +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +template +static void softmax_compute_ref(const operators::SoftmaxParam& param) { + const dtype* x_data = param.x->mutable_data(); + dtype* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + ASSERT_EQ(x_dims.data(), param.output->dims().data()); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int compute_size = outer_num * inner_num; + for (int i = 0; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int start = idx_outer * inner_num + idx_inner; + int offset; + + offset = start; + dtype max_data = std::numeric_limits::lowest(); + for (int j = 0; j < axis_size; j++) { + max_data = x_data[offset] > max_data ? x_data[offset] : max_data; + offset += inner_num; + } + + offset = start; + dtype sum_data = (dtype)0; + for (int j = 0; j < axis_size; j++) { + output_data[offset] = exp(x_data[offset] - max_data); + sum_data += output_data[offset]; + offset += inner_num; + } + + offset = start; + for (int j = 0; j < axis_size; j++) { + output_data[offset] /= sum_data; + offset += inner_num; + } + } +} + +TEST(softmax_cuda, compute) { + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + SoftmaxCompute softmax; + operators::SoftmaxParam param; + softmax.SetContext(std::move(ctx)); + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + for (auto n : {1, 3}) { + for (auto c : {1, 4}) { + for (auto h : {5, 1, 112}) { + for (auto w : {1, 6, 112}) { + for (auto axis : {-2, -1, 0, 1, 2}) { + x.Resize({n, c, h, w}); + x_cpu.Resize({n, c, h, w}); + output.Resize({n, c, h, w}); + output_cpu.Resize({n, c, h, w}); + output_ref.Resize({n, c, h, w}); + auto* x_cpu_data = x_cpu.mutable_data(); + auto* output_data = output.mutable_data(TARGET(kCUDA)); + auto* output_cpu_data = output_ref.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_cpu_data[i] = i; + } + x.Assign(x_cpu_data, + x_cpu.dims()); + param.x = &x; + param.axis = axis; + param.output = &output; + softmax.SetParam(param); + softmax.Launch(); + param.x = &x_cpu; + param.output = &output_ref; + softmax_compute_ref(param); + cudaDeviceSynchronize(); + CopySync(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } +} +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 10dff5371a0f6840e092287d97eff98722e3b7f7..668419cf7ceae4a2e10cd447d57824f826cabd3a 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/operators/conv_op.h" +#include #include #include "lite/core/op_registry.h" @@ -51,10 +52,41 @@ inline int ConvOutputSize( return output_size; } +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilations, + const std::vector& strides, + const std::string padding_algorithm, + const lite::DDim data_dims, + const lite::DDim& ksize) { + // when padding_desc is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (size_t i = 0; i < strides.size(); ++i) { + int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i]; + int pad_sum = + std::max((out_size - 1) * strides[i] + ksize[i] - data_dims[i + 2], + (int64_t)0); + // pad + *(paddings->begin() + i) = pad_sum / 2; + // dilation + *(dilations->begin() + i) = 1; + } + } else if (padding_algorithm == "VALID") { + for (auto& it : *paddings) { + it = 0; + } + } +} + bool ConvOpLite::InferShape() const { const auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims(); + UpdatePaddingAndDilation(¶m_.paddings, + ¶m_.dilations, + param_.strides, + padding_algorithm_, + in_dims, + filter_dims); std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < param_.strides.size(); ++i) { output_shape.push_back(ConvOutputSize(in_dims[i + 2], diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index ac0006c8e6f495d36991cf712c3c80dfcf7a46c9..1d6e1c93490a394723d34de76fc3ff8040d31e81 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -93,6 +93,10 @@ class ConvOpLite : public OpLite { << "The fused conv only supports fuse with relu and leaky relu"; } } + + if (op_desc.HasAttr("padding_algorithm")) { + padding_algorithm_ = op_desc.GetAttr("padding_algorithm"); + } // For Int8 if (op_desc.HasAttr("enable_int8")) { param_.enable_int8 = op_desc.GetAttr("enable_int8"); @@ -114,6 +118,7 @@ class ConvOpLite : public OpLite { private: mutable ConvParam param_; + std::string padding_algorithm_{""}; }; } // namespace operators