Unverified commit f3124b30, authored by Zhaolong Xing, committed by GitHub

add cudnn conv fp32, int8 support (#1974)

* paddle lite cuda init
can run model with leaky_relu

* add the missing file.
test=develop

* add the load from memory interface.
test=develop

* refine this PR: fix review comments
fix CI error
test=develop

* conv impl
fp32:
conv, conv+bias, conv+bias+relu, conv+bias+leaky_relu

int8:
conv, conv+bias+relu (int8 or fp32 output), conv+bias+leaky_relu (int8 or fp32 output)

can run conv+bias+relu using cxx_api
test=develop

* move the lite/cuda/math to backends/cuda/math
test=develop
Parent cfc7af76
......@@ -503,17 +503,14 @@ function(nv_test TARGET_NAME)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog ${os_dependency_modules})
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY})
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
endif()
endfunction(nv_test)
......
......@@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(conv2d, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def);
......
......@@ -5,3 +5,4 @@ endif()
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas SRCS blas.cc)
add_subdirectory(math)
......@@ -17,6 +17,7 @@
#include <cublas_api.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cudnn.h>
#include "lite/utils/cp_logging.h"
/*
......@@ -46,6 +47,15 @@
<< "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \
}
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= (major * 1000 + minor * 100 + patch))
#define CUDNN_CHECK(condition) \
{ \
cudnnStatus_t status = condition; \
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << CudnnGetErrorInfo(status); \
}
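// Illustrative use (a sketch, not part of this change): any call returning
// cudnnStatus_t can be wrapped so that a failure aborts with the readable
// name produced by CudnnGetErrorInfo() below, e.g.
//   cudnnHandle_t handle;  // local name used only for illustration
//   CUDNN_CHECK(cudnnCreate(&handle));
//   CUDNN_CHECK(cudnnDestroy(handle));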
namespace paddle {
namespace lite {
namespace cuda {
......@@ -71,6 +81,44 @@ static const char* CublasErrorInfo(int error) {
}
}
static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
#if CUDNN_VERSION_MIN(6, 0, 0)
case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
#endif
}
return "Unknown cudnn status";
}
} // namespace cuda
} // namespace lite
} // namespace paddle
if(NOT LITE_WITH_CUDA)
return()
endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
set (
math_cuda
cudnn_conv
cuda_activation
cuda_scale
cuda_type_trans
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void relu_kernel(const int num,
const T alpha,
const T* input,
T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index)
: __ldg(input + index) * alpha;
#else
output[index] = input[index] >= 0 ? input[index] : input[index] * alpha;
#endif
}
}
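// The two overloads below fuse per-channel scaling, bias add and (leaky) relu
// for NHWC data packed four channels per float4/char4 element. With
// 0 < alpha < 1, fmaxf(v * alpha, v) yields v for v >= 0 and v * alpha for
// v < 0, i.e. leaky relu; alpha == 0 reduces to plain relu. The first overload
// writes fp32 output, the second requantizes to int8 via from_float<int8_t>.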
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
float4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float4 bias_ptr = bias[bias_idx];
const float4 scale_ptr = scale[bias_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x;
packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x);
packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y;
packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y);
packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z;
packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z);
packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w;
packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w);
out[tid] = packed_val;
}
}
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
char4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float4 bias_ptr = bias[bias_idx];
const float4 scale_ptr = scale[bias_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
char4 result_val;
packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x;
result_val.x =
from_float<int8_t>(fmaxf(packed_val.x * alpha, packed_val.x));
packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y;
result_val.y =
from_float<int8_t>(fmaxf(packed_val.y * alpha, packed_val.y));
packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z;
result_val.z =
from_float<int8_t>(fmaxf(packed_val.z * alpha, packed_val.z));
packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w;
result_val.w =
from_float<int8_t>(fmaxf(packed_val.w * alpha, packed_val.w));
out[tid] = result_val;
}
}
__global__ void relu_int8_nhwc4_kernel(int num,
const float4* in,
float4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x;
packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x);
packed_val.y = in_ptr.y * scale_ptr.y;
packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y);
packed_val.z = in_ptr.z * scale_ptr.z;
packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z);
packed_val.w = in_ptr.w * scale_ptr.w;
packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w);
out[tid] = packed_val;
}
}
__global__ void relu_int8_nhwc4_kernel(int num,
const float4* in,
char4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
char4 result_val;
packed_val.x = in_ptr.x * scale_ptr.x;
result_val.x =
from_float<int8_t>(fmaxf(packed_val.x * alpha, packed_val.x));
packed_val.y = in_ptr.y * scale_ptr.y;
result_val.y =
from_float<int8_t>(fmaxf(packed_val.y * alpha, packed_val.y));
packed_val.z = in_ptr.z * scale_ptr.z;
result_val.z =
from_float<int8_t>(fmaxf(packed_val.z * alpha, packed_val.z));
packed_val.w = in_ptr.w * scale_ptr.w;
result_val.w =
from_float<int8_t>(fmaxf(packed_val.w * alpha, packed_val.w));
out[tid] = result_val;
}
}
template <>
void bias_relu_int8_nhwc4<float>(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<const float4*>(bias),
static_cast<float4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void bias_relu_int8_nhwc4<int8_t>(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<const float4*>(bias),
static_cast<char4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void relu_int8_nhwc4<float>(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<float4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void relu_int8_nhwc4<int8_t>(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<char4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void relu(int, const float*, float*, float, cudaStream_t);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
// fp32
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
// For int8
template <typename out_type>
void bias_relu_int8_nhwc4(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
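A minimal host-side sketch of driving the templated relu declared above; the
function name and the buffer/stream handling are illustrative assumptions, not
part of this patch:

#include <cuda_runtime.h>
#include "lite/backends/cuda/math/activation.h"

// Sketch: apply leaky relu to a buffer already resident on the GPU.
void leaky_relu_demo(const float* d_in, float* d_out, int num,
                     cudaStream_t stream) {
  const float alpha = 0.1f;  // alpha < 1 -> leaky relu; alpha == 0 -> relu
  paddle::lite::cuda::math::relu<float>(num, d_in, d_out, alpha, stream);
  cudaStreamSynchronize(stream);  // only needed if the result is read next
}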
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kw = w_dims[3];
int kh = w_dims[2];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
CHECK(ic % param.groups == 0)
<< "The conv input channel should be divisible by the group number.";
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
if (param.activation_param.has_active && with_relu_act_) {
CUDNN_CHECK(cudnnSetActivationDescriptor(
this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
}
if (ic == param.groups && ic == oc && ic != 1) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else {
CUDNN_CHECK(
cudnnGetConvolutionForwardAlgorithm(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->preference_,
this->workspace_limit_bytes_,
&this->fwd_algo_));
}
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
if (param.bias) {
int dim_bias[] = {1, oc, 1, 1};
int stride_bias[] = {oc, 1, 1, 1};
cudnnSetTensorNdDescriptor(
this->bias_desc_, CUDNN_DATA_FLOAT, 4, dim_bias, stride_bias);
}
return true;
}
template <>
bool CudnnConv2D<PRECISION(kFloat)>::init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
this->workspace_size_inbytes_ = 0;
this->workspace_data_ = NULL;
this->workspace_fwd_sizes_ = 0;
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
this->workspace_ = NULL;
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreateFilterDescriptor(&this->filter_desc_);
cudnnCreateConvolutionDescriptor(&this->conv_desc_);
cudnnCreateTensorDescriptor(&this->bias_desc_);
if (param.activation_param.has_active) {
if (param.activation_param.active_type == lite_api::ActivationType::kRelu) {
cudnnCreateActivationDescriptor(&this->act_desc_);
} else {
this->with_relu_act_ = false;
}
}
return create(param, ctx);
}
template <>
bool CudnnConv2D<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (param.activation_param.has_active && with_relu_act_) {
if (b_data) {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data,
bias_desc_,
b_data,
act_desc_,
output_desc_,
o_data));
} else {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data));
CUDNN_CHECK(cudnnActivationForward(handle_,
act_desc_,
&alpha,
output_desc_,
o_data,
&beta,
output_desc_,
o_data));
}
} else {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data));
if (b_data) {
CUDNN_CHECK(cudnnAddTensor(
handle_, &alpha, bias_desc_, b_data, &alpha, output_desc_, o_data));
}
}
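// Activations other than plain relu are not fused into the cuDNN calls
// above; they are applied below as a separate element-wise kernel.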
if (!with_relu_act_) {
CHECK(param.activation_param.active_type ==
lite_api::ActivationType::kLeakyRelu)
<< "Only support leaky relu now.";
auto out_dims = param.output->dims();
int n = out_dims[0], c = out_dims[1], h = out_dims[2], w = out_dims[3];
int num = n * h * w * c;
float alpha = param.activation_param.Leaky_relu_alpha;
relu(num, o_data, o_data, alpha, this->stream_);
}
return true;
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
int iw = x_dims[2]; // nhwc
int ih = x_dims[1];
int ic = x_dims[3];
int ow = o_dims[2];
int oh = o_dims[1];
int oc = o_dims[3];
int kw = w_dims[2];
int kh = w_dims[1];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
std::vector<float> weight_scale = param.weight_scale;
float input_scale = param.input_scale;
float output_scale = param.output_scale;
CHECK(weight_scale.size() == oc)
<< "The number of weight_scale values should equal the output channel count.";
if (Ptype_out == PRECISION(kInt8)) {
this->temp_tensor_.Resize(o_dims);
this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
for (int i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
}
} else {
for (int i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale);
}
}
this->scale_.Resize({oc});
auto* scale_data = this->scale_.template mutable_data<float>(TARGET(kCUDA));
this->scale_.template Assign<float, lite::DDim, TARGET(kCUDA)>(
weight_scale.data(), this->scale_.dims());
CHECK(ic % param.groups == 0)
<< "The conv input channel should be divisible by the group number.";
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_INT8,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
CUDNN_DATA_INT8,
CUDNN_TENSOR_NHWC,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_INT32));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&(this->workspace_fwd_sizes_)));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
return true;
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
this->workspace_size_inbytes_ = 0;
this->workspace_data_ = NULL;
this->workspace_fwd_sizes_ = 0;
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
this->workspace_ = NULL;
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreateFilterDescriptor(&this->filter_desc_);
cudnnCreateConvolutionDescriptor(&this->conv_desc_);
cudnnCreateTensorDescriptor(&this->bias_desc_);
if (param.activation_param.has_active) {
if (!(param.activation_param.active_type ==
lite_api::ActivationType::kRelu)) {
this->with_relu_act_ = false;
}
}
return create(param, ctx);
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) {
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = param.filter->data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
float* temp_out;
float* scale = this->scale_.template mutable_data<float>(TARGET(kCUDA));
if (Ptype_out == PRECISION(kInt8)) {
temp_out = this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
} else {
temp_out = param.output->mutable_data<float>(TARGET(kCUDA));
}
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
&alpha,
this->input_desc_,
i_data,
this->filter_desc_,
w_data,
this->conv_desc_,
this->fwd_algo_,
this->workspace_,
this->workspace_fwd_sizes_,
&beta,
this->output_desc_,
temp_out));
auto out_dims = param.output->dims();
int n = out_dims[0], h = out_dims[1], w = out_dims[2], c = out_dims[3];
int num = n * h * w * c / 4;
if (!param.activation_param.has_active && !b_data) {
if (Ptype_out == PRECISION(kInt8)) {
auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
fp32_to_int8_nhwc4(num,
static_cast<const void*>(temp_out),
static_cast<void*>(out),
static_cast<const void*>(scale),
n,
c / 4,
h,
w,
this->stream_);
} else {
fp32_scale_nhwc4(num,
static_cast<const void*>(temp_out),
static_cast<void*>(temp_out),
static_cast<const void*>(scale),
n,
c / 4,
h,
w,
this->stream_);
}
return true;
}
if (b_data) {
if (param.activation_param.has_active) {
float alpha = 0.0;
if (!this->with_relu_act_)
alpha = param.activation_param.Leaky_relu_alpha;
if (Ptype_out == PRECISION(kInt8)) {
auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
bias_relu_int8_nhwc4<int8_t>(num,
static_cast<const void*>(temp_out),
static_cast<const void*>(b_data),
static_cast<void*>(out),
n,
c / 4,
h,
w,
static_cast<const void*>(scale),
alpha,
this->stream_);
} else {
bias_relu_int8_nhwc4<float>(num,
static_cast<const void*>(temp_out),
static_cast<const void*>(b_data),
static_cast<void*>(temp_out),
n,
c / 4,
h,
w,
static_cast<const void*>(scale),
alpha,
this->stream_);
}
return true;
}
}
CHECK(false)
<< "Int8 conv only supports conv, conv + bias + relu, and conv + bias + leaky_relu.";
return false;
}
template class CudnnConv2DInt8<PRECISION(kInt8)>;
template class CudnnConv2DInt8<PRECISION(kFloat)>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <PrecisionType Ptype_out>
class CudnnConv2DBase {
public:
CudnnConv2DBase()
: handle_(NULL),
workspace_data_(NULL),
workspace_(NULL),
conv_desc_(NULL),
input_desc_(NULL),
output_desc_(NULL),
filter_desc_(NULL),
act_desc_(NULL),
bias_desc_(NULL),
workspace_fwd_sizes_(0),
workspace_size_inbytes_(0),
fwd_algo_((cudnnConvolutionFwdAlgo_t)0) {}
~CudnnConv2DBase() {
if (conv_desc_) {
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc_));
}
if (input_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
}
if (output_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
}
if (act_desc_) {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
if (bias_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_));
}
if (filter_desc_) {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc_));
}
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
if (workspace_data_ != NULL) {
cudaFree(workspace_data_);
}
}
protected:
cudaStream_t stream_;
cudnnHandle_t handle_;
cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
// activation descriptor
cudnnActivationDescriptor_t act_desc_;
bool with_relu_act_{true};
size_t workspace_fwd_sizes_;
size_t workspace_size_inbytes_; // size of underlying storage
void* workspace_data_; // underlying storage
void* workspace_; // aliases into _workspaceData
const bool use_tensor_core_ = true;
const size_t workspace_limit_bytes_ = 4 * 1024 * 1024;
const cudnnConvolutionFwdPreference_t preference_ =
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
// For int8
Tensor temp_tensor_;
Tensor scale_;
};
template <PrecisionType Ptype_out>
class CudnnConv2D : public CudnnConv2DBase<Ptype_out> {
public:
CudnnConv2D() : CudnnConv2DBase<Ptype_out>() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::ConvParam& param);
};
template <PrecisionType Ptype_out>
class CudnnConv2DInt8 : CudnnConv2DBase<Ptype_out> {
public:
CudnnConv2DInt8() : CudnnConv2DBase<Ptype_out>() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::ConvParam& param);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
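A sketch of the intended lifecycle of the fp32 class above, assuming an already
populated operators::ConvParam (x, filter, output, strides, paddings,
dilations, groups) and a CUDA context with an exec stream; it mirrors the
kernel wrapper further down but is not itself part of the patch:

#include "lite/backends/cuda/math/cudnn_conv.h"

// Sketch only: init() builds the handle/descriptors and picks an algorithm,
// run() enqueues the forward pass on the context's exec stream.
bool run_conv_fp32(const paddle::lite::operators::ConvParam& param,
                   paddle::lite::Context<TARGET(kCUDA)>* ctx) {
  paddle::lite::cuda::math::CudnnConv2D<PRECISION(kFloat)> conv;
  if (!conv.init(param, ctx)) return false;
  return conv.run(param);
}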
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iostream"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
float4* out,
const float4* scale,
int N,
int K,
int H,
int W) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x;
packed_val.y = in_ptr.y * scale_ptr.y;
packed_val.z = in_ptr.z * scale_ptr.z;
packed_val.w = in_ptr.w * scale_ptr.w;
out[tid] = packed_val;
}
}
void fp32_scale_nhwc4(int num,
const void* in,
void* out,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
fp32_scale_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<float4*>(out),
static_cast<const float4*>(scale),
N,
K,
H,
W);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
void fp32_scale_nhwc4(int num,
const void* din,
void* dout,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/type_trans.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
char4* out,
const float4* scale,
int N,
int K,
int H,
int W) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
char4 result_val;
result_val.x = from_float<int8_t>(in_ptr.x * scale_ptr.x);
result_val.y = from_float<int8_t>(in_ptr.y * scale_ptr.y);
result_val.z = from_float<int8_t>(in_ptr.z * scale_ptr.z);
result_val.w = from_float<int8_t>(in_ptr.w * scale_ptr.w);
out[tid] = result_val;
}
}
void fp32_to_int8_nhwc4(int num,
const void* in,
void* out,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
fp32_scale_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<char4*>(out),
static_cast<const float4*>(scale),
N,
K,
H,
W);
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
void fp32_to_int8_nhwc4(int num,
const void* din,
void* dout,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
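A usage note for the declaration above (a sketch under the NHWC4 packing the
conv path assumes): num counts float4 groups, i.e. N*H*W*C/4, K is C/4, and
scale holds one fp32 multiplier per output channel; the helper name below is
illustrative only:

#include <cstdint>
#include "lite/backends/cuda/math/type_trans.h"

// Sketch: requantize an NHWC fp32 buffer to int8 with per-channel scales.
void to_int8_demo(const float* d_fp32, int8_t* d_int8, const float* d_scale,
                  int n, int c, int h, int w, cudaStream_t stream) {
  int num = n * h * w * c / 4;  // c assumed to be a multiple of 4
  paddle::lite::cuda::math::fp32_to_int8_nhwc4(
      num, d_fp32, d_int8, d_scale, n, c / 4, h, w, stream);
}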
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__device__ T from_float(float x);
template <>
__device__ __forceinline__ float from_float<float>(float x) {
return x;
}
template <>
__device__ __forceinline__ half from_float<half>(float x) {
return __float2half(x);
}
template <>
__device__ __forceinline__ int8_t from_float<int8_t>(float x) {
x = fmaxf(x, INT8_MIN);
x = fminf(x, INT8_MAX);
return __float2int_rn(x);
}
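// Note on the int8 specialization above: the value is clamped to
// [INT8_MIN, INT8_MAX] before __float2int_rn, which rounds to the nearest
// integer (ties to even), so e.g. 127.9f -> 127 and -200.0f -> -128.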
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -107,6 +107,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
INIT_FOR(kCUDA, kInt8, kNHWC);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
......
......@@ -70,6 +70,9 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
......
......@@ -7,15 +7,18 @@ message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps}
${math_cuda})
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
set(cuda_kernels
conv2d_cuda
mul_compute_cuda
io_copy_compute_cuda
leaky_relu_compute_cuda
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/conv_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ConvCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
conv_impl_.reset(new lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>);
conv_impl_->init(param, &ctx);
}
void ConvCompute::Run() {
auto& param = this->Param<param_t>();
conv_impl_->run(param);
}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
conv_impl_.reset(new lite::cuda::math::CudnnConv2DInt8<Ptype_out>);
conv_impl_->init(param, &ctx);
}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::Run() {
auto& param = this->Param<param_t>();
conv_impl_->run(param);
}
template class ConvComputeInt8<PRECISION(kInt8)>;
template class ConvComputeInt8<PRECISION(kFloat)>;
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
fp32_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kInt8)>,
int8_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ConvCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConvCompute() = default;
private:
std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_;
};
template <PrecisionType Ptype_out>
class ConvComputeInt8 : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConvComputeInt8() = default;
private:
std::unique_ptr<lite::cuda::math::CudnnConv2DInt8<Ptype_out>> conv_impl_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
float random(float low, float high) {
static std::mt19937 mt(100);
std::uniform_real_distribution<double> dist(low, high);
return dist(mt);
}
TEST(conv_compute, fp32) {
ConvCompute conv_fp32;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.paddings = {1, 1};
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 1, h = 3, w = 3;
int c_o = 1, h_o = 3, w_o = 3;
y.Resize({n, c_o, h_o, w_o});
x_cpu.Resize({n, c, h, w});
filter_cpu.Resize({c_o, c / param.groups, 3, 3});
y_cpu.Resize({n, c_o, h_o, w_o});
bias_cpu.Resize({c_o});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* filter_cpu_data = filter_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
float* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i;
}
std::vector<float> weight = {-0.2209115,
-0.17199445,
-0.2059412,
0.6763207,
-0.12260777,
-0.43123743,
-0.49696392,
-0.27471393,
-0.81017196};
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = weight[i];
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = 0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<float, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
// param.bias = &bias;
conv_fp32.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
conv_fp32.SetContext(std::move(ctx));
conv_fp32.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-0.8, -0.7};
for (int i = 0; i < y.numel(); i++) {
LOG(INFO) << y_cpu_data[i];
}
}
TEST(conv_compute, int8) {
ConvComputeInt8<PRECISION(kFloat)> int8_conv_fp32out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
operators::ConvParam param;
// param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
filter_cpu.Resize({c, 3, 3, c / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(1);
}
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(1);
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
param.weight_scale = {1, 2, 3, 4};
int8_conv_fp32out.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
int8_conv_fp32out.SetContext(std::move(ctx));
int8_conv_fp32out.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {36, 72, 108, 144};
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
}
}
TEST(conv_compute, int8_int8_out) {
ConvComputeInt8<PRECISION(kInt8)> int8_conv_fp32out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
filter_cpu.Resize({c, 3, 3, c / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
auto* y_data = y.mutable_data<int8_t>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
auto* y_cpu_data = y_cpu.mutable_data<int8_t>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
}
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
param.weight_scale = {0.01, 0.02, 0.03, 0.04};
param.bias = &bias;
int8_conv_fp32out.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
int8_conv_fp32out.SetContext(std::move(ctx));
int8_conv_fp32out.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-1, 4, 0, -2};
for (int i = 0; i < y.numel(); i++) {
// EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
LOG(INFO) << static_cast<float>(y_cpu_data[i]);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -89,6 +89,7 @@ class IoCopyCudaToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto mem_size = param.x->memory_size();
LOG(INFO) << "io copy cuda to host " << mem_size;
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
CopyToHostSync(data, param.x->raw_data(), mem_size);
}
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/leaky_relu_compute.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/yolo_box_compute.h"
......@@ -179,7 +178,7 @@ void YoloBoxCompute::Run() {
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
anchors_.Resize(static_cast<int>({anchors.size()}));
anchors_.Resize({static_cast<int64_t>(anchors.size())});
int* d_anchors = anchors_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(d_anchors,
anchors.data(),
......
......@@ -466,7 +466,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
#define SET_DATA_TYPE(precision, type_desc) \
case precision: \
desc.SetDataType(type_desc); \
break
break;
SET_DATA_TYPE(PRECISION(kFloat), VarDescAPI::VarDataType::FP32);
SET_DATA_TYPE(PRECISION(kInt8), VarDescAPI::VarDataType::INT8);
......@@ -487,14 +487,14 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
if (tensor.target() == TARGET(kCUDA)) {
switch (tensor.precision()) {
#define DO(precision, type) \
case precision: \
case precision: { \
std::unique_ptr<type> tmp_buffer(new type[tensor.data_size()]); \
TargetWrapperCuda::MemcpySync(tmp_buffer.get(), \
tensor.data<type>(), \
tensor.data_size(), \
IoDirection::DtoH); \
desc.SetData<type>(tmp_buffer.get(), tensor.data_size()); \
break
} break;
DO(PRECISION(kFloat), float);
DO(PRECISION(kInt8), int8_t);
DO(PRECISION(kInt16), int16_t);
......@@ -512,7 +512,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
#define DO(precision, type) \
case precision: \
desc.SetData<type>(tensor.data<type>(), tensor.data_size()); \
break
break;
DO(PRECISION(kFloat), float);
DO(PRECISION(kInt8), int8_t);
DO(PRECISION(kInt16), int16_t);
......
......@@ -15,8 +15,10 @@
#pragma once
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/core/types.h"
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h"
......@@ -203,6 +205,28 @@ struct ConcatParam {
int axis{0};
};
/// ----------------------- activation operators ----------------------
struct ActivationParam {
const lite::Tensor* X{};
float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param
std::string Prelu_mode{
"channel"}; // prelu param, can be "all", "channel" or "element"
lite::Tensor* Prelu_alpha{}; // prelu param
float Swish_beta; // swish param
lite::Tensor* Out{};
bool has_active{false};
lite_api::ActivationType active_type;
};
struct ActivationGradParam {
const lite::Tensor* X{};
const lite::Tensor* Out{};
// for backward
lite::Tensor* X_grad{};
const lite::Tensor* Out_grad{};
};
// For Convolution op
struct ConvParam {
lite::Tensor* x{};
......@@ -226,6 +250,8 @@ struct ConvParam {
float scale_weights{1.0f}; // only used with mkl-dnn int8
bool force_fp32_output{false}; // only used in mkl-dnn int8
std::string data_format{"Anylayout"};
// for activation
ActivationParam activation_param;
// for int8
WITH_INT8_CONFIG
};
......@@ -320,26 +346,6 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
std::string act_type;
};
/// ----------------------- activation operators ----------------------
struct ActivationParam {
const lite::Tensor* X{};
float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param
std::string Prelu_mode{
"channel"}; // prelu param, can be "all", "channel" or "element"
lite::Tensor* Prelu_alpha{}; // prelu param
float Swish_beta; // swish param
lite::Tensor* Out{};
};
struct ActivationGradParam {
const lite::Tensor* X{};
const lite::Tensor* Out{};
// for backward
lite::Tensor* X_grad{};
const lite::Tensor* Out_grad{};
};
/// ----------------------- mean operators ----------------------
struct MeanParam {
const lite::Tensor* X{};
......