Commit c4b5e32c authored by Zhaolong Xing, committed by GitHub

can run yolov3 fp32 on cuda devices (#2092)

* add conv int8 support (for cases where the input or output channel count is not a multiple of 4)
add add_kernel for cuda.

* can run yolov3 fp32
test=develop

* 1. fix bug with yolov3 run
test=develop
Parent b82a9eec
...@@ -301,6 +301,18 @@ function(add_kernel TARGET device level)
set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "")
endif()
if ("${device}" STREQUAL "CUDA")
if (NOT LITE_WITH_CUDA)
return()
endif()
set(cuda_kernels "${cuda_kernels};${TARGET}" CACHE INTERNAL "")
foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach()
nv_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS})
return()
endif()
# the source list will collect for paddle_use_kernel.h code generation.
foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
...
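The sources appended to kernels_src_list feed the generated paddle_use_kernel.h, so CUDA kernels built through this new branch stay discoverable the same way as the other backends. Purely as an illustration (the real header is emitted by the build scripts, and the include path below is an assumption), a linked binary would reference the kernels registered later in this diff roughly like this:

```cpp
// Illustrative only: generated paddle_use_kernel.h-style references to CUDA
// kernels registered in this commit (op/target/precision/layout/alias values
// are taken from the REGISTER_LITE_KERNEL calls in the diff).
#include "paddle_lite_factory_helper.h"  // assumed location of USE_LITE_KERNEL

USE_LITE_KERNEL(conv2d, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(layout, kCUDA, kFloat, kNCHW, nchw2nhwc);
USE_LITE_KERNEL(feed, kCUDA, kFloat, kNCHW, nchw);
```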
...@@ -147,6 +147,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
core::KernelPickFactor factor;
factor.ConsiderTarget();
factor.ConsiderPrecision();
factor.ConsiderDataLayout();
optimizer_.Run(std::move(program), valid_places, factor, passes);
exec_scope_ = optimizer_.exec_scope();
}
...
...@@ -41,6 +41,10 @@ const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>();
}
template <>
int *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<int>();
}
template <>
float *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<float>();
...
...@@ -5,7 +5,7 @@ endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cuda_transpose SRCS transpose.cu)
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
...
...@@ -37,6 +37,22 @@ __global__ void relu_kernel(const int num,
}
}
template <typename T>
__global__ void bias_relu_kernel(const int num,
const T alpha,
const T* input,
T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index)
: __ldg(input + index) * alpha;
#else
output[index] = input[index] >= 0 ? input[index] : input[index] * alpha;
#endif
}
}
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
...@@ -277,7 +293,23 @@ void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) {
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void bias_relu(int num,
const T* din,
const float* bias,
T* dout,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void relu(int, const float*, float*, float, cudaStream_t);
template void bias_relu(
int, const float*, const float* bias, float*, float, cudaStream_t);
} // namespace math
} // namespace cuda
...
...@@ -26,6 +26,26 @@ namespace math {
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
template <typename T>
void bias_relu(int num,
const T* din,
const float* bias,
T* dout,
float alpha,
cudaStream_t stream);
// For int8
template <typename out_type>
void bias_relu_int8_nhwc4(int num,
...@@ -40,18 +60,6 @@ void bias_relu_int8_nhwc4(int num,
float alpha,
cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
...
...@@ -14,6 +14,7 @@
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/conv_op_cache_cudnn.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
...@@ -87,6 +88,56 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
if (ic == param.groups && ic == oc && ic != 1) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else if (1) {
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
int workspace_size_limit = 256 * 1024 * 1024;
auto search_func = [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
this->handle_,
this->input_desc_,
i_data,
this->filter_desc_,
w_data,
this->conv_desc_,
this->output_desc_,
o_data,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
&returned_algo_count,
fwd_perf_stat.data(),
cudnn_workspace,
workspace_size_limit));
};
ResetWorkSpace();
CUDA_CALL(cudaMalloc(&this->workspace_data_, workspace_size_limit));
cudnn_find_func(this->workspace_data_);
ResetWorkSpace();
VLOG(2) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(2) << stat.algo << ": " << stat.status << " " << stat.time << " "
<< stat.memory;
}
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t> algo_cache;
this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(),
w_dims.Vectorize(),
param.strides,
param.paddings,
param.dilations,
0,
search_func);
} else {
CUDNN_CHECK(
cudnnGetConvolutionForwardAlgorithm(this->handle_,
...@@ -108,9 +159,7 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
ResetWorkSpace();
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
...@@ -272,16 +321,21 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
std::vector<float> weight_scale = param.weight_scale;
float input_scale = param.input_scale;
float output_scale = param.output_scale;
CHECK(weight_scale.size() == oc)
CHECK(weight_scale.size() == static_cast<size_t>(oc))
<< "the num of the weight_scale should be equals to the output channel.";
if (Ptype_out == PRECISION(kInt8)) {
this->temp_tensor_.Resize(o_dims);
this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
for (int i = 0; i < weight_scale.size(); i++) {
for (size_t i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
}
auto* b_data = param.bias ? param.bias->mutable_data<float>() : nullptr;
if (b_data) {
scale(param.bias->numel(), b_data, b_data, 1.f / output_scale);
}
} else {
for (int i = 0; i < weight_scale.size(); i++) {
for (size_t i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale);
}
}
...@@ -322,8 +376,11 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
oc,
oh,
ow));
if (ic % 4 == 0 && oc % 4 == 0) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
...@@ -331,14 +388,15 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&(this->workspace_fwd_sizes_)));
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
CUDA_CALL(cudaFree(this->workspace_data_));
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
CUDA_CALL(
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_));
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
...
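The exhaustive search above relies on AlgorithmsCache from the newly included lite/backends/cuda/math/conv_op_cache_cudnn.h, which this diff does not show. A minimal sketch consistent with the call site (hash the shapes and conv attributes, run the cudnnFind search functor only on a cache miss) could look like the following; it is illustrative, not the committed implementation:

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <vector>

// Sketch of a shape-keyed algorithm cache, assuming the interface used above:
// GetAlgorithm(x_dims, w_dims, strides, paddings, dilations, flags, search_func).
template <typename AlgoT>
class AlgorithmsCacheSketch {
 public:
  AlgoT GetAlgorithm(const std::vector<int64_t>& x_dims,
                     const std::vector<int64_t>& w_dims,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations,
                     int algorithm_flags,
                     const std::function<AlgoT()>& search_func) {
    int64_t seed = 0;
    auto hash_all = [&seed](const auto& v) {
      for (auto e : v) {
        seed ^= std::hash<int64_t>()(static_cast<int64_t>(e)) + 0x9e3779b9 +
                (seed << 6) + (seed >> 2);
      }
    };
    hash_all(x_dims);
    hash_all(w_dims);
    hash_all(strides);
    hash_all(paddings);
    hash_all(dilations);
    seed ^= std::hash<int64_t>()(algorithm_flags);
    auto it = map_.find(seed);
    if (it != map_.end()) return it->second;  // cache hit: reuse the algo
    AlgoT algo = search_func();               // cache miss: run cudnnFind* once
    map_[seed] = algo;
    return algo;
  }

 private:
  std::map<int64_t, AlgoT> map_;
};
```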
...@@ -66,9 +66,15 @@ class CudnnConv2DBase {
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
ResetWorkSpace();
}
protected:
void ResetWorkSpace() {
if (workspace_data_ != NULL) {
cudaFree(workspace_data_);
CUDA_CALL(cudaFree(workspace_data_));
}
workspace_data_ = NULL;
}
protected:
...
...@@ -21,6 +21,18 @@ namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void scale_kernel(int num, const T* in, T* out, const float scale) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
#if __CUDA_ARCH__ >= 350
out[tid] = __ldg(in + tid) * scale;
#else
out[tid] = in[tid] * scale;
#endif
}
}
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
float4* out,
...@@ -68,6 +80,23 @@ void fp32_scale_nhwc4(int num,
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread>>>(num, in, out, scale);
}
template void scale(int num, const float*, float*, float, cudaStream_t);
template void scale(int num, const float*, float*, float);
} // namespace math
} // namespace cuda
} // namespace lite
...
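The new element-wise scale() overloads back the bias rescaling added in cudnn_conv.cc (the bias is divided by output_scale when the convolution emits int8). A hedged usage sketch, assuming the bias already lives in device memory:

```cpp
#include "lite/backends/cuda/math/scale.h"

// Illustrative helper (not part of the commit): rescale an on-device float
// bias of n elements in place by 1 / output_scale, as cudnn_conv.cc does.
void RescaleBiasInPlace(float* d_bias, int n, float output_scale) {
  paddle::lite::cuda::math::scale(n, d_bias, d_bias, 1.f / output_scale);
}
```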
...@@ -31,6 +31,12 @@ void fp32_scale_nhwc4(int num,
int W,
cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale);
} // namespace math
} // namespace cuda
} // namespace lite
...
...@@ -89,6 +89,7 @@ void BatchTranspose2DCUDAImpl(const int N,
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
...@@ -102,6 +103,7 @@ TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template <typename T>
...@@ -169,8 +171,6 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
size, ndim, d_strides, d_y_dims, X, Y);
// cudaError_t error = cudaGetLastError();
// if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \
...
...@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
...@@ -25,13 +26,11 @@ size_t TargetWrapperCuda::num_devices() {
void* TargetWrapperCuda::Malloc(size_t size) {
void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
CUDA_CALL(cudaMalloc(&ptr, size));
return ptr;
}
void TargetWrapperCuda::Free(void* ptr) {
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
void TargetWrapperCuda::Free(void* ptr) { CUDA_CALL(cudaFree(ptr)); }
void TargetWrapperCuda::MemcpySync(void* dst,
const void* src,
...@@ -39,14 +38,13 @@ void TargetWrapperCuda::MemcpySync(void* dst,
IoDirection dir) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
break;
case IoDirection::HtoD:
CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
break;
case IoDirection::DtoH:
CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
break;
default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
...@@ -60,16 +58,16 @@ void TargetWrapperCuda::MemcpyAsync(void* dst,
const stream_t& stream) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream));
break;
case IoDirection::HtoD:
CHECK(cudaSuccess ==
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
break;
case IoDirection::DtoH:
CHECK(cudaSuccess ==
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
break;
default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
...
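The CHECK-based error handling here is replaced by the CUDA_CALL macro pulled in from lite/backends/cuda/cuda_utils.h, which is not shown in this diff. Such a macro conventionally wraps the runtime call and stops on failure with the decoded error string; the sketch below is an assumption about its general shape, not the actual definition (the real one most likely reports through glog's CHECK/LOG(FATAL)):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Conventional shape of a CUDA error-checking macro (illustrative only).
#define CUDA_CALL_SKETCH(call)                                             \
  do {                                                                     \
    cudaError_t err__ = (call);                                            \
    if (err__ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,   \
                   cudaGetErrorString(err__));                             \
      std::abort();                                                        \
    }                                                                      \
  } while (0)
```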
...@@ -160,9 +160,9 @@ class Context<TargetType::kCUDA> {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) {
CHECK_GT(devs.size(), 0)
CHECK_GT(devs.size(), 0UL)
<< "Env is not initialized or current target is not exit!";
if (dev_id >= devs.size()) {
if (dev_id >= static_cast<int>(devs.size())) {
LOG(WARNING) << "device index exceeds the number of devices, set to "
"default device(0)!";
device_id_ = 0;
...
...@@ -24,7 +24,7 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
for (auto act_type : {"relu"}) {
for (auto act_type : {"relu", "leaky_relu"}) {
for (auto has_bias : {true, false}) {
fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
fuser(graph.get());
...
...@@ -73,7 +73,16 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
op_desc.SetAttr("fuse_relu", true);
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("with_act", true);
op_desc.SetAttr("act_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("leaky_relu_alpha", alpha);
}
return op_desc;
}
...
...@@ -28,7 +28,6 @@ class ConvActivationFuser : public FuseBase {
explicit ConvActivationFuser(const std::string& conv_type,
const std::string& act_type,
bool has_bias) {
CHECK(act_type == "relu") << "Only relu activation be supported now";
conv_type_ = conv_type;
act_type_ = act_type;
has_bias_ = has_bias;
...
...@@ -87,6 +87,9 @@ void TypeLayoutTransformPass::AddLayoutInst(
auto layout_output_name =
string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
auto* layout_output_arg = graph->NewArgumentNode(layout_output_name);
layout_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), from.precision(), to.layout());
auto* layout_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
...@@ -110,7 +113,9 @@ void TypeLayoutTransformPass::AddLayoutInst(
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatible(*in_arg_ty, from)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->layout() == to.layout()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
...
...@@ -90,6 +90,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels.
...@@ -118,13 +120,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
// TODO(xg): to optimize this
#ifndef LITE_WITH_FPGA
if (in_arg_ty->precision() == from.precision() &&
out_arg_ty->precision() == to.precision()) {
#else
if (TypeCompatible(*in_arg_ty, from)) {
#endif
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->precision() == to.precision()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
...
...@@ -87,8 +87,12 @@ void TypeTargetTransformPass::AddIoCopyInst(
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name =
string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(), from.layout()
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
...@@ -114,7 +118,9 @@ void TypeTargetTransformPass::AddIoCopyInst(
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatible(*in_arg_ty, from)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->target() == to.target()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
...
...@@ -58,8 +58,8 @@ class VariablePlaceInferencePass : public DebugPass {
void SetWeightType(Node* w, const LiteType& type) {
// TODO(xg) to optimize this
#ifndef LITE_WITH_FPGA
w->AsArg().type =
LiteType::GetTensorTy(TARGET(kHost), type.precision(), type.layout());
w->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), type.precision(), DATALAYOUT(kNCHW));
#else
w->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
...
...@@ -105,6 +105,7 @@ KernelRegistry::KernelRegistry()
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kFloat, kNHWC);
INIT_FOR(kCUDA, kInt8, kNCHW);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
...
...@@ -70,6 +70,9 @@ class KernelRegistry final {
variant<KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
...
...@@ -113,9 +113,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
void RuntimeProgram::Run() {
for (auto& inst : instructions_) {
VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst.kernel()->target());
inst.Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
...@@ -192,9 +189,14 @@ void Instruction::Run() {
CHECK(op_->CheckShape());
}
if (op_->run_once() && has_run_) return;
if (op_->run_once() && has_run_) {
return;
}
VLOG(4) << "kernel launch";
op_->InferShape();
VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
<< TargetToStr(kernel_->target());
kernel_->Launch();
has_run_ = true;
}
...
...@@ -4,18 +4,20 @@ endif()
message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
nv_library(transpose_compute_cuda SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps})
nv_library(elementwise_add_compute_cuda SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps})
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(concat_compute_cuda CUDA basic SRCS concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps})
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
...@@ -23,20 +25,4 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
set(cuda_kernels
conv2d_cuda
mul_compute_cuda
io_copy_compute_cuda
leaky_relu_compute_cuda
nearest_interp_compute_cuda
concat_compute_cuda
elementwise_add_compute_cuda
yolo_box_compute_cuda
transpose_compute_cuda
)
set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
...@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
...
...@@ -56,10 +56,20 @@ template class ConvComputeInt8<PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
...@@ -70,13 +80,19 @@ REGISTER_LITE_KERNEL(
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
fp32_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
...
...@@ -105,6 +105,7 @@ TEST(conv_compute, fp32) {
LOG(INFO) << y_cpu_data[i];
}
}
/*
TEST(conv_compute, int8) {
ConvComputeInt8<PRECISION(kFloat)> int8_conv_fp32out;
...@@ -177,18 +178,19 @@ TEST(conv_compute, int8_int8_out) {
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kRelu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
// act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
int c_i = 3, h_i = 3, w_i = 3;
int n = 1, c = 4;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
x_cpu.Resize({n, h_i, w_i, c_i});
filter_cpu.Resize({c, 3, 3, c / param.groups});
filter_cpu.Resize({c, 3, 3, c_i / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
...@@ -198,14 +200,19 @@ TEST(conv_compute, int8_int8_out) {
auto* y_cpu_data = x_cpu.mutable_data<int8_t>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
std::cout << "input" << std::endl;
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
std::cout << float(x_cpu_data[i]) << std::endl;
}
std::cout << "filter" << std::endl;
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
std::cout << float(filter_cpu_data[i]) << std::endl;
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
// bias_cpu_data[i] = 0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
...@@ -218,6 +225,7 @@ TEST(conv_compute, int8_int8_out) {
param.filter = &filter;
param.output = &y;
param.weight_scale = {0.01, 0.02, 0.03, 0.04};
param.output_scale = 2;
param.bias = &bias;
int8_conv_fp32out.SetParam(param);
...@@ -232,12 +240,13 @@ TEST(conv_compute, int8_int8_out) {
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-1, 4, 0, -2};
std::vector<float> real_results = {0, 7, 8, 1};
for (int i = 0; i < y.numel(); i++) {
// EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
LOG(INFO) << float(y_cpu_data[i]);
}
}
*/
} // namespace cuda
} // namespace kernels
...
...@@ -26,7 +26,11 @@ __global__ void KeElementwiseAdd(const float* x_data,
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < total; tid += stride) {
#if __CUDA_ARCH__ >= 350
out_data[tid] = __ldg(x_data + tid) + __ldg(y_data + tid);
#else
out_data[tid] = x_data[tid] + y_data[tid];
#endif
}
}
...@@ -51,7 +55,7 @@ void ElementwiseAddCompute::Run() {
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
int threads = 512;
int threads = 1024;
int blocks = (pixel_num + threads - 1) / threads;
blocks = blocks > 8 ? 8 : blocks;
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/feed_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void FeedCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
VLOG(4) << "feed_list.size: " << param.feed_list->size();
const lite::Tensor& feed_item = (*param.feed_list)[param.col];
int num = static_cast<int>(feed_item.numel());
auto input = feed_item.data<float>();
param.out->Resize(feed_item.dims());
auto output = param.out->mutable_data<float>(TARGET(kCUDA));
VLOG(4) << "col: " << param.col << " num:" << num;
TargetW::MemcpyAsync(
output, input, num * sizeof(float), IoDirection::HtoD, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
feed, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::FeedCompute, nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
feed, kCUDA, kFloat, kNHWC, paddle::lite::kernels::cuda::FeedCompute, nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class FeedCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FeedParam;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
void Run() override;
virtual ~FeedCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -51,7 +51,7 @@ class IoCopyHostToCudaCompute
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kX86));
auto mem_size = param.x->memory_size();
LOG(INFO) << "copy size " << mem_size;
VLOG(4) << "copy size " << mem_size;
auto* data = param.y->mutable_data(TARGET(kCUDA), mem_size);
CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
...@@ -89,7 +89,7 @@ class IoCopyCudaToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto mem_size = param.x->memory_size();
LOG(INFO) << "io copy cuda to host " << mem_size;
VLOG(4) << "io copy cuda to host " << mem_size;
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
CopyToHostSync(data, param.x->raw_data(), mem_size);
}
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/layout_compute.h"
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename Dtype>
void NCHWToNHWCCompute<Dtype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto input = param.x->template data<Dtype>();
auto input_dim = param.x->dims();
CHECK(input_dim.size() == 4)
<< "NCHW to NHWC should guarantee that the input dims should be 4";
auto output = param.y->template mutable_data<Dtype>(TARGET(kCUDA));
int n = input_dim[0];
int c = input_dim[1];
int h = input_dim[2];
int w = input_dim[3];
lite::cuda::math::NCHW2NHWC<Dtype>(n, c, h * w, input, output, &ctx);
}
template <typename Dtype>
void NHWCToNCHWCompute<Dtype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto input = param.x->template data<Dtype>();
auto output = param.y->template mutable_data<Dtype>(TARGET(kCUDA));
auto input_dim = param.x->dims();
CHECK(input_dim.size() == 4)
<< "NHWC to NCHW should guarantee that the input dims should be 4";
int n = input_dim[0];
int h = input_dim[1];
int w = input_dim[2];
int c = input_dim[3];
lite::cuda::math::NHWC2NCHW<Dtype>(n, c, h * w, input, output, &ctx);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(layout,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::NCHWToNHWCCompute<float>,
nchw2nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::NHWCToNCHWCompute<float>,
nhwc2nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kInt8,
kNCHW,
paddle::lite::kernels::cuda::NCHWToNHWCCompute<int8_t>,
nchw2nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::NHWCToNCHWCompute<int8_t>,
nhwc2nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.Finalize();
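The layout kernels registered above delegate to the NCHW2NHWC / NHWC2NCHW helpers in lite/backends/cuda/math/transpose.h (now also instantiated for int8_t). For reference, the index mapping they implement is the standard one; a CPU sketch, illustrative only and not the CUDA path the kernel actually takes:

```cpp
#include <vector>

// Reference semantics of the nchw2nhwc layout kernel:
// out[n][h][w][c] = in[n][c][h][w].
template <typename T>
void NCHW2NHWCReference(int N, int C, int H, int W,
                        const std::vector<T>& in, std::vector<T>* out) {
  out->resize(in.size());
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          (*out)[((n * H + h) * W + w) * C + c] =
              in[((n * C + c) * H + h) * W + w];
}
```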
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename Dtype>
class LayOutCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~LayOutCompute() = default;
};
template <typename Dtype>
class NCHWToNHWCCompute : public LayOutCompute<Dtype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NCHWToNHWCCompute() = default;
};
template <typename Dtype>
class NHWCToNCHWCompute : public LayOutCompute<Dtype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NHWCToNCHWCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -74,6 +74,17 @@ REGISTER_LITE_KERNEL(transpose,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::TransposeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// REGISTER_LITE_KERNEL(transpose2,
// kCUDA,
// kFloat,
...
...@@ -76,8 +76,22 @@ class ConvOpLite : public OpLite {
}
}
}
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
if (op_desc.HasAttr("with_act") && op_desc.GetAttr<bool>("with_act")) {
param_.activation_param.has_active = true;
auto act_type = op_desc.GetAttr<std::string>("act_type");
if (act_type == "relu") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu;
param_.fuse_relu = true;
} else if (act_type == "leaky_relu") {
param_.activation_param.active_type =
lite_api::ActivationType::kLeakyRelu;
param_.activation_param.Leaky_relu_alpha =
op_desc.GetAttr<float>("leaky_relu_alpha");
} else {
CHECK(false)
<< "The fused conv only supports fuse with relu and leaky relu";
}
}
// For Int8
if (op_desc.HasAttr("enable_int8")) {
...
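The fuser change earlier in this diff and this AttachImpl change share a small attribute contract on the fused conv op. A hypothetical summary (attribute names taken from the diff; the struct itself is only illustrative):

```cpp
#include <string>

// Hypothetical summary of the attributes ConvActivationFuser::GenOpDesc
// writes and ConvOpLite::AttachImpl reads back.
struct FusedConvActAttrs {
  bool with_act = true;          // always set by the fuser
  std::string act_type;          // "relu" or "leaky_relu"
  bool fuse_relu = false;        // set only when act_type == "relu"
  float leaky_relu_alpha = 0.f;  // copied from the activation op's "alpha"
};
```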