提交 305130fc 编写于 作者: Y yiicy 提交者: Xiaoyang LI

add cuda op(pool & softmax), support conv with padding_algorithm

* cuda add softmax and pool op

* * fix armlinux can find sys/system_properties.h
* conv add padding_algorithm
test=develop

* delete padding_algorithm in op param, test=develop

* fix bugs, test=develop
上级 91059871
...@@ -18,6 +18,7 @@ endif() ...@@ -18,6 +18,7 @@ endif()
set(ANDROID TRUE) set(ANDROID TRUE)
add_definitions(-DLITE_WITH_LINUX) add_definitions(-DLITE_WITH_LINUX)
add_definitions(-DLITE_WITH_ANDROID)
if(NOT DEFINED ANDROID_NDK) if(NOT DEFINED ANDROID_NDK)
set(ANDROID_NDK $ENV{NDK_ROOT}) set(ANDROID_NDK $ENV{NDK_ROOT})
......
...@@ -35,6 +35,9 @@ ...@@ -35,6 +35,9 @@
#include <sys/syscall.h> #include <sys/syscall.h>
#include <unistd.h> #include <unistd.h>
#endif #endif
#ifdef LITE_WITH_ANDROID
#include <sys/system_properties.h>
#endif
#if __APPLE__ #if __APPLE__
#include "TargetConditionals.h" #include "TargetConditionals.h"
#if LITE_WITH_IPHONE #if LITE_WITH_IPHONE
...@@ -218,6 +221,7 @@ void get_cpu_arch(std::vector<ARMArch>* archs, const int cpu_num) { ...@@ -218,6 +221,7 @@ void get_cpu_arch(std::vector<ARMArch>* archs, const int cpu_num) {
#ifdef LITE_WITH_LINUX #ifdef LITE_WITH_LINUX
std::string get_cpu_name() { std::string get_cpu_name() {
std::string cpu_name;
FILE* fp = fopen("/proc/cpuinfo", "rb"); FILE* fp = fopen("/proc/cpuinfo", "rb");
if (!fp) { if (!fp) {
return ""; return "";
...@@ -229,12 +233,23 @@ std::string get_cpu_name() { ...@@ -229,12 +233,23 @@ std::string get_cpu_name() {
break; break;
} }
if (strstr(line, "Hardware") != NULL) { if (strstr(line, "Hardware") != NULL) {
fclose(fp); cpu_name = std::string(line);
return std::string(line);
} }
} }
#ifdef LITE_WITH_ANDROID
// cpu name concat board name, platform name and chip name
char board_name[128];
char platform_name[128];
char chip_name[128];
__system_property_get("ro.product.board", board_name);
__system_property_get("ro.board.platform", platform_name);
__system_property_get("ro.chipname", chip_name);
cpu_name =
cpu_name + "_" + board_name + "_" + platform_name + "_" + chip_name;
#endif
std::transform(cpu_name.begin(), cpu_name.end(), cpu_name.begin(), ::toupper);
fclose(fp); fclose(fp);
return ""; return cpu_name;
} }
int get_min_freq_khz(int cpuid) { int get_min_freq_khz(int cpuid) {
...@@ -780,7 +795,9 @@ bool DeviceInfo::SetCPUInfoByName() { ...@@ -780,7 +795,9 @@ bool DeviceInfo::SetCPUInfoByName() {
cluster_ids_ = {0, 0, 0, 0}; cluster_ids_ = {0, 0, 0, 0};
SetArchInfo(1, kA53); SetArchInfo(1, kA53);
return true; return true;
} else if (dev_name_.find("KIRIN980") != std::string::npos) { // Kirin 980 } else if (dev_name_.find("KIRIN980") != std::string::npos ||
dev_name_.find("KIRIN990") !=
std::string::npos) { // Kirin 980, Kirin 990
core_num_ = 8; core_num_ = 8;
core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7};
big_core_ids_ = {4, 5, 6, 7}; big_core_ids_ = {4, 5, 6, 7};
......
...@@ -39,6 +39,13 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -39,6 +39,13 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int pad = param.paddings[0]; int pad = param.paddings[0];
int stride = param.strides[0]; int stride = param.strides[0];
int chin = param.x->dims()[1];
int hin = param.x->dims()[2];
int win = param.x->dims()[3];
int chout = param.output->dims()[1];
int hout = param.output->dims()[2];
int wout = param.output->dims()[3];
bool kps_equal = (param.paddings[0] == param.paddings[1]) && bool kps_equal = (param.paddings[0] == param.paddings[1]) &&
(param.strides[0] == param.strides[1]) && (kw == kh); (param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
...@@ -54,7 +61,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -54,7 +61,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
VLOG(3) << "invoking dw conv"; VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) { no_dilation) {
if (ic >= 32 && oc >= 32) { if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) {
/// winograd conv impl /// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv"; VLOG(3) << "invoking winograd conv";
...@@ -63,8 +70,8 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -63,8 +70,8 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv"; VLOG(3) << "invoking direct conv";
} }
} else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 2 &&
no_dilation) { chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl /// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv"; VLOG(3) << "invoking direct conv";
......
...@@ -16,6 +16,8 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute. ...@@ -16,6 +16,8 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps}) add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose) add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
...@@ -26,5 +28,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c ...@@ -26,5 +28,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda) nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda) #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/pool_compute.h"
#include "lite/utils/macros.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
#define MAX_VAL(a, b) (((a) > (b)) ? (a) : (b))
#define MIN_VAL(a, b) (((a) < (b)) ? (a) : (b))
__global__ void max_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int w_s = w_id * stride_w - pad_w;
const int iw_s = MAX_VAL(w_s, 0);
const int iw_e = MIN_VAL(w_s + win_w, in_w);
const int w_loop = iw_e - iw_s;
const int h_s = h_id * stride_h - pad_h;
const int ih_s = MAX_VAL(h_s, 0);
const int ih_e = MIN_VAL(h_s + win_h, in_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float max_val = -FLT_MAX;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
max_val = MAX_VAL(max_val, *(in_p + j));
}
in_p += in_w;
}
max_val = max_val == -FLT_MAX ? 0.f : max_val;
output[nc_id * spatial_out + h_id * out_w + w_id] = max_val;
}
}
__global__ void adaptive_max_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int iw_s = floor(static_cast<double>(w_id * in_w) / out_w);
const int iw_e = ceil(static_cast<double>((w_id + 1) * in_w) / out_w);
const int w_loop = iw_e - iw_s;
const int ih_s = floor(static_cast<double>(h_id * in_h) / out_h);
const int ih_e = ceil(static_cast<double>((h_id + 1) * in_h) / out_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float max_val = -FLT_MAX;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
max_val = MAX_VAL(max_val, *(in_p + j));
}
in_p += in_w;
}
output[nc_id * spatial_out + h_id * out_w + w_id] = max_val;
}
}
__global__ void avg_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
bool exclusive,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int w_s = w_id * stride_w - pad_w;
const int iw_s = MAX_VAL(w_s, 0);
const int iw_e = MIN_VAL(w_s + win_w, in_w);
const int w_loop = iw_e - iw_s;
const int h_s = h_id * stride_h - pad_h;
const int ih_s = MAX_VAL(h_s, 0);
const int ih_e = MIN_VAL(h_s + win_h, in_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float sum_val = 0.f;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
sum_val += *(in_p + j);
}
in_p += in_w;
}
int pool_size = exclusive ? h_loop * w_loop : win_w * win_h;
pool_size = pool_size == 0 ? 1 : pool_size;
output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size;
}
}
__global__ void adaptive_avg_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int iw_s = floor(static_cast<double>(w_id * in_w) / out_w);
const int iw_e = ceil(static_cast<double>((w_id + 1) * in_w) / out_w);
const int w_loop = iw_e - iw_s;
const int ih_s = floor(static_cast<double>(h_id * in_h) / out_h);
const int ih_e = ceil(static_cast<double>((h_id + 1) * in_h) / out_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float sum_val = 0.f;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
sum_val += *(in_p + j);
}
in_p += in_w;
}
int pool_size = h_loop * w_loop;
pool_size = pool_size == 0 ? 1 : pool_size;
output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size;
}
}
__global__ void global_max_pool_kernel(const float* input,
float* output,
const int in_h,
const int in_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int spatial_in = in_h * in_w;
const float* in_p = input + gid * spatial_in;
int i = 0;
float max_val = -0.f;
// unroll 8
for (; i < spatial_in - 7; i += 8) {
max_val = MAX_VAL(max_val, *(in_p + 0));
max_val = MAX_VAL(max_val, *(in_p + 1));
max_val = MAX_VAL(max_val, *(in_p + 2));
max_val = MAX_VAL(max_val, *(in_p + 3));
max_val = MAX_VAL(max_val, *(in_p + 4));
max_val = MAX_VAL(max_val, *(in_p + 5));
max_val = MAX_VAL(max_val, *(in_p + 6));
max_val = MAX_VAL(max_val, *(in_p + 7));
in_p += 8;
}
for (; i < spatial_in; i++) {
max_val = MAX_VAL(max_val, *in_p);
in_p++;
}
output[gid] = max_val;
}
}
__global__ void global_avg_pool_kernel(const float* input,
float* output,
const int in_h,
const int in_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int spatial_in = in_h * in_w;
const float* in_p = input + gid * spatial_in;
int i = 0;
float sum_val = 0.f;
// unroll 8
for (; i < spatial_in - 7; i += 8) {
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
}
for (; i < spatial_in; i++) {
sum_val += *in_p++;
}
output[gid] = sum_val / spatial_in;
}
}
void PoolCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
auto x_dims = param.x->dims();
auto out_dims = param.output->dims();
const int in_h = x_dims[2];
const int in_w = x_dims[3];
const int out_h = out_dims[2];
const int out_w = out_dims[3];
const int spatial_in = in_h * in_w;
const int spatial_out = out_h * out_w;
const int win_h = param.ksize[0];
const int win_w = param.ksize[1];
const int stride_h = param.strides[0];
const int stride_w = param.strides[1];
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
const int total_threads = out_dims.production();
const int threads = 512;
const int blocks = (total_threads + threads - 1) / threads;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (param.global_pooling) {
if (param.pooling_type == "max") {
global_max_pool_kernel<<<blocks, threads, 0, stream>>>(
input_data, output_data, in_h, in_w, total_threads);
} else {
global_avg_pool_kernel<<<blocks, threads, 0, stream>>>(
input_data, output_data, in_h, in_w, total_threads);
}
} else {
if (!adaptive) {
if (param.pooling_type == "max") {
max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
} else {
avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
exclusive,
total_threads);
}
} else {
if (param.pooling_type == "max") {
adaptive_max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
} else {
adaptive_avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
}
}
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
pool2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::PoolCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class PoolCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::PoolParam;
void Run() override;
virtual ~PoolCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
static int PoolOutputSize(
int input_size, int filter_size, int padding, int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
const auto x_dims = param_->x->dims();
std::vector<int>& ksize = param_->ksize;
if (param_->global_pooling) {
ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
param_->paddings[i] = 0;
ksize[i] = static_cast<int>(x_dims[i + 2]);
}
}
std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
if (param_->adaptive) {
output_shape.insert(
output_shape.end(), param_->ksize.begin(), param_->ksize.end());
} else {
for (size_t i = 0; i < param_->ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(x_dims[i + 2],
param_->ksize[i],
param_->paddings[i],
param_->strides[i],
param_->ceil_mode));
}
}
return output_shape;
}
static void pool_compute_ref(const operators::PoolParam& param) {
auto& in_dims = param.x->dims();
auto& out_dims = param.output->dims();
const float* src_ptr = param.x->data<const float>();
float* dst_ptr = param.output->mutable_data<float>();
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = param.paddings;
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
std::string data_format = param.data_format;
int in_n = in_dims[0];
int in_c = in_dims[1];
int in_h = in_dims[2];
int in_w = in_dims[3];
int size_in_n = in_c * in_h * in_w;
int size_in_c = in_h * in_w;
int out_h = out_dims[2];
int out_w = out_dims[3];
int size_out_n = in_c * out_h * out_w;
int size_out_c = out_h * out_w;
int window_h = ksize[0];
int window_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
if (global_pooling == true) {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * size_in_n + c * size_in_c;
float res = src[0];
if (pooling_type == "max") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res = cur_val > res ? cur_val : res;
}
} else if (pooling_type == "avg") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res += cur_val;
}
res /= size_in_c;
}
dst_ptr[n * size_out_n + c] = res;
}
}
} else {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
for (int h = 0; h < out_h; ++h) {
int sh = h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int w = 0; w < out_w; ++w) {
int sw = w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
int pooling_size = (ew - sw) * (eh - sh);
if (pooling_size == 0) {
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = 0.f;
continue;
}
float res = 0.f;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
res = src_ptr[src_idx];
} else {
if (pooling_type == "max") {
res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx];
}
if (pooling_type == "avg") {
res += src_ptr[src_idx];
}
}
}
}
if (pooling_type == "avg") {
if (exclusive) {
res /= pooling_size;
} else {
res /= window_h * window_w;
}
}
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res;
}
}
}
}
}
}
TEST(pool_cuda, compute) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
PoolCompute pool;
operators::PoolParam param;
pool.SetContext(std::move(ctx));
lite::Tensor x;
lite::Tensor x_cpu;
lite::Tensor output;
lite::Tensor output_cpu;
lite::Tensor output_ref;
for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {true, false}) {
for (auto global_pooling : {true, false}) {
for (auto exclusive : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) {
for (auto n : {1, 2}) {
for (auto c : {1, 3}) {
for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) {
VLOG(3) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.dims().production(); ++i) {
float sign = i % 3 == 0 ? -0.03 : 0.05f;
x_cpu_data[i] = sign * (i % 128);
}
x.Assign<float, DDim, TARGET(kCUDA)>(x_cpu_data,
x_cpu.dims());
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
if (global_pooling) {
param.ksize = {h, w};
} else {
param.ksize = {ksize, ksize};
}
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
if (output_shape[2] * output_shape[3] == 0) continue;
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
output_cpu.Resize(DDim(output_shape));
auto* output_data =
output.mutable_data<float>(TARGET(kCUDA));
auto* output_ref_data =
output_ref.mutable_data<float>();
auto* output_cpu_data =
output_cpu.mutable_data<float>();
// compute
pool.SetParam(param);
pool.Launch();
// compute ref
param.x = &x_cpu;
param.output = &output_ref;
pool_compute_ref(param);
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
// compare
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(
output_cpu_data[i], output_ref_data[i], 1e-4);
}
VLOG(3) << "compare pass";
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/softmax_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
extern __shared__ char tile[];
template <typename dtype>
__global__ void sharemem_softmax_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
int inner_num,
int outer_num,
int axis_size) {
dtype* data = reinterpret_cast<dtype*>(tile) + threadIdx.x;
//! compute thread index and real data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
int blocksize = blockDim.x;
int real_index = idx_outer * inner_num + idx_inner;
int loop_idx = real_index;
//! read all data to sharemem in softmax channel
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
data[i * blocksize] = in_data[loop_idx];
loop_idx += inner_num;
}
//! get maximum value in softmax channel
dtype max_data = data[0];
#pragma unroll
for (int i = 1; i < axis_size; ++i) {
dtype dt = data[i * blocksize];
if (max_data < dt) {
max_data = dt;
}
}
//! subtract then summarize
dtype sum = 0;
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
dtype* dt = data + i * blocksize;
*dt = expf(*dt - max_data);
sum += *dt;
}
//! write back result
loop_idx = real_index;
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
out_data[loop_idx] = data[i * blocksize] / sum;
loop_idx += inner_num;
}
}
}
//! general kernel for softmax
template <typename dtype>
__global__ void softmax_max_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
dtype min_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
//! get maximum data across softmax axis
dtype max_data = min_data;
for (int i = 0; i < axis_size; ++i) {
max_data =
in_data[real_index] > max_data ? in_data[real_index] : max_data;
real_index += inner_num;
}
out_data[idx] = max_data;
}
}
template <typename dtype>
__global__ void softmax_sub_exp_sum_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
const dtype* max_data,
dtype* sum_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
dtype max_data_cur = max_data[idx];
dtype sum_data_cur = 0;
int real_index = idx_outer * inner_num + idx_inner;
//! compute exp and summarize across the softmax axis
for (int i = 0; i < axis_size; ++i) {
dtype sub_data = in_data[real_index] - max_data_cur;
sub_data = expf(sub_data);
sum_data_cur += sub_data;
out_data[real_index] = sub_data;
real_index += inner_num;
}
sum_data[idx] = sum_data_cur;
}
}
template <typename dtype>
__global__ void softmax_divid_output_kernel(int total_size,
dtype* io_data,
const dtype* sum_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
dtype sum_data_cur = 1.f / sum_data[idx];
int real_index = idx_outer * inner_num + idx_inner;
//! compute final result
for (int i = 0; i < axis_size; ++i) {
io_data[real_index] = io_data[real_index] * sum_data_cur;
real_index += inner_num;
}
}
}
void SoftmaxCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto x_dims = param.x->dims();
auto x_rank = x_dims.size();
int axis = param.axis;
if (axis < 0) {
axis += x_rank;
}
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int total_threads = inner_num * outer_num;
int axis_size = x_dims[axis];
int device_id;
const int threads = 512;
const int blocks = (total_threads + threads - 1) / threads;
cudaGetDevice(&device_id);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device_id);
size_t sharedmem_size = deviceProp.sharedMemPerBlock;
int max_dimsize = sharedmem_size / sizeof(float) / threads;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (axis_size <= max_dimsize) {
int use_sharemem_size = axis_size * threads * sizeof(float);
sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
total_threads,
input_data,
output_data,
inner_num,
outer_num,
axis_size);
} else {
//! re_alloc device memory
Tensor tmax_data;
Tensor tsum_data;
tmax_data.Resize({1, 1, 1, outer_num * inner_num});
tsum_data.Resize({1, 1, 1, outer_num * inner_num});
auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA));
auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA));
//! firstly, get maximum data
float min_data = std::numeric_limits<float>::min();
softmax_max_kernel<float><<<blocks, threads, 0, stream>>>(total_threads,
input_data,
max_data,
min_data,
inner_num,
outer_num,
axis_size);
//! then, compute exp and sum data
softmax_sub_exp_sum_kernel<float><<<blocks, threads, 0, stream>>>(
total_threads,
input_data,
output_data,
max_data,
sum_data,
inner_num,
outer_num,
axis_size);
//! last, compute divided output
softmax_divid_output_kernel<float><<<blocks, threads, 0, stream>>>(
total_threads, output_data, sum_data, inner_num, outer_num, axis_size);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(softmax,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SoftmaxCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("axis",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SoftmaxCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::SoftmaxParam;
void Run() override;
virtual ~SoftmaxCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/softmax_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
template <typename dtype>
static void softmax_compute_ref(const operators::SoftmaxParam& param) {
const dtype* x_data = param.x->mutable_data<const dtype>();
dtype* output_data = param.output->mutable_data<dtype>();
DDim x_dims = param.x->dims();
ASSERT_EQ(x_dims.data(), param.output->dims().data());
auto x_rank = x_dims.size();
int axis = param.axis;
if (axis < 0) {
axis += x_rank;
}
int axis_size = x_dims[axis];
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
output_data[offset] = exp(x_data[offset] - max_data);
sum_data += output_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
output_data[offset] /= sum_data;
offset += inner_num;
}
}
}
TEST(softmax_cuda, compute) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
SoftmaxCompute softmax;
operators::SoftmaxParam param;
softmax.SetContext(std::move(ctx));
lite::Tensor x;
lite::Tensor x_cpu;
lite::Tensor output;
lite::Tensor output_cpu;
lite::Tensor output_ref;
for (auto n : {1, 3}) {
for (auto c : {1, 4}) {
for (auto h : {5, 1, 112}) {
for (auto w : {1, 6, 112}) {
for (auto axis : {-2, -1, 0, 1, 2}) {
x.Resize({n, c, h, w});
x_cpu.Resize({n, c, h, w});
output.Resize({n, c, h, w});
output_cpu.Resize({n, c, h, w});
output_ref.Resize({n, c, h, w});
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* output_data = output.mutable_data<float>(TARGET(kCUDA));
auto* output_cpu_data = output_ref.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_cpu_data[i] = i;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data,
x_cpu.dims());
param.x = &x;
param.axis = axis;
param.output = &output;
softmax.SetParam(param);
softmax.Launch();
param.x = &x_cpu;
param.output = &output_ref;
softmax_compute_ref<float>(param);
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/operators/conv_op.h" #include "lite/operators/conv_op.h"
#include <algorithm>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -51,10 +52,41 @@ inline int ConvOutputSize( ...@@ -51,10 +52,41 @@ inline int ConvOutputSize(
return output_size; return output_size;
} }
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_dims[i + 2],
(int64_t)0);
// pad
*(paddings->begin() + i) = pad_sum / 2;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
bool ConvOpLite::InferShape() const { bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims(); const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims(); const auto filter_dims = param_.filter->dims();
UpdatePaddingAndDilation(&param_.paddings,
&param_.dilations,
param_.strides,
padding_algorithm_,
in_dims,
filter_dims);
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < param_.strides.size(); ++i) { for (size_t i = 0; i < param_.strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], output_shape.push_back(ConvOutputSize(in_dims[i + 2],
......
...@@ -93,6 +93,10 @@ class ConvOpLite : public OpLite { ...@@ -93,6 +93,10 @@ class ConvOpLite : public OpLite {
<< "The fused conv only supports fuse with relu and leaky relu"; << "The fused conv only supports fuse with relu and leaky relu";
} }
} }
if (op_desc.HasAttr("padding_algorithm")) {
padding_algorithm_ = op_desc.GetAttr<std::string>("padding_algorithm");
}
// For Int8 // For Int8
if (op_desc.HasAttr("enable_int8")) { if (op_desc.HasAttr("enable_int8")) {
param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8"); param_.enable_int8 = op_desc.GetAttr<bool>("enable_int8");
...@@ -114,6 +118,7 @@ class ConvOpLite : public OpLite { ...@@ -114,6 +118,7 @@ class ConvOpLite : public OpLite {
private: private:
mutable ConvParam param_; mutable ConvParam param_;
std::string padding_algorithm_{""};
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册