Unverified commit 3d6d744f, authored by Zhaolong Xing, committed by GitHub

can run yolov3 fp32 on cuda devices (#2092)

* add conv int8 support (for the case where the input or output channel count is not a multiple of 4)
  add add_kernel support for CUDA kernels

* can run yolov3 fp32
test=develop

* 1. fix a bug in the yolov3 run
test=develop
Parent: aa6623b8
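The core of the int8 change is a cuDNN forward-algorithm fallback: IMPLICIT_PRECOMP_GEMM is only kept when both the input and output channel counts are multiples of 4; otherwise the kernel drops to IMPLICIT_GEMM (see the cudnn_conv.cc hunk below). A minimal sketch of that selection, using an illustrative helper name that is not part of the patch:

#include <cudnn.h>

// Illustrative helper, not part of this commit: mirrors the channel check
// added in CudnnConv2DInt8<Ptype_out>::create().
cudnnConvolutionFwdAlgo_t PickInt8FwdAlgo(int ic, int oc) {
  if (ic % 4 == 0 && oc % 4 == 0) {
    return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  }
  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}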
......@@ -301,6 +301,18 @@ function(add_kernel TARGET device level)
set(opencl_kernels "${opencl_kernels};${TARGET}" CACHE INTERNAL "")
endif()
if ("${device}" STREQUAL "CUDA")
if (NOT LITE_WITH_CUDA)
return()
endif()
set(cuda_kernels "${cuda_kernels};${TARGET}" CACHE INTERNAL "")
foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach()
nv_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS})
return()
endif()
# the source list will collect for paddle_use_kernel.h code generation.
foreach(src ${args_SRCS})
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
......
......@@ -147,6 +147,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
core::KernelPickFactor factor;
factor.ConsiderTarget();
factor.ConsiderPrecision();
factor.ConsiderDataLayout();
optimizer_.Run(std::move(program), valid_places, factor, passes);
exec_scope_ = optimizer_.exec_scope();
}
......
......@@ -41,6 +41,10 @@ const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>();
}
template <>
int *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<int>();
}
template <>
float *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<float>();
......
......@@ -5,7 +5,7 @@ endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cuda_transpose SRCS transpose.cu)
nv_library(cuda_transpose SRCS transpose.cu )
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
......
......@@ -37,6 +37,22 @@ __global__ void relu_kernel(const int num,
}
}
template <typename T>
__global__ void bias_relu_kernel(const int num,
const T alpha,
const T* input,
T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index)
: __ldg(input + index) * alpha;
#else
output[index] = input[index] >= 0 ? input[index] : input[index] * alpha;
#endif
}
}
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
......@@ -277,7 +293,23 @@ void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) {
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void bias_relu(int num,
const T* din,
const float* bias,
T* dout,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void relu(int, const float*, float*, float, cudaStream_t);
template void bias_relu(
int, const float*, const float* bias, float*, float, cudaStream_t);
} // namespace math
} // namespace cuda
......
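The launchers added above (relu, bias_relu) all use the same ceiling-division launch configuration: 256 threads per block, enough blocks to cover num elements, and an index guard inside the kernel. A standalone sketch of that pattern, with a toy kernel name that is not in this patch:

#include <cuda_runtime.h>

__global__ void scale_by_two(int num, const float* in, float* out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num) {  // the last block may be only partially filled
    out[idx] = in[idx] * 2.f;
  }
}

void launch_scale_by_two(int num, const float* d_in, float* d_out,
                         cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (num + threads - 1) / threads;  // ceil(num / threads)
  scale_by_two<<<blocks, threads, 0, stream>>>(num, d_in, d_out);
}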
......@@ -26,6 +26,26 @@ namespace math {
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
template <typename T>
void bias_relu(int num,
const T* din,
const float* bias,
T* dout,
float alpha,
cudaStream_t stream);
// For int8
template <typename out_type>
void bias_relu_int8_nhwc4(int num,
......@@ -40,18 +60,6 @@ void bias_relu_int8_nhwc4(int num,
float alpha,
cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -14,6 +14,7 @@
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/conv_op_cache_cudnn.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
......@@ -87,6 +88,56 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
if (ic == param.groups && ic == oc && ic != 1) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else if (1) {
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
int workspace_size_limit = 256 * 1024 * 1024;
auto search_func = [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
this->handle_,
this->input_desc_,
i_data,
this->filter_desc_,
w_data,
this->conv_desc_,
this->output_desc_,
o_data,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
&returned_algo_count,
fwd_perf_stat.data(),
cudnn_workspace,
workspace_size_limit));
};
ResetWorkSpace();
CUDA_CALL(cudaMalloc(&this->workspace_data_, workspace_size_limit));
cudnn_find_func(this->workspace_data_);
ResetWorkSpace();
VLOG(2) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(2) << stat.algo << ": " << stat.status << " " << stat.time << " "
<< stat.memory;
}
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t> algo_cache;
this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(),
w_dims.Vectorize(),
param.strides,
param.paddings,
param.dilations,
0,
search_func);
} else {
CUDNN_CHECK(
cudnnGetConvolutionForwardAlgorithm(this->handle_,
......@@ -108,9 +159,7 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
ResetWorkSpace();
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
......@@ -272,16 +321,21 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
std::vector<float> weight_scale = param.weight_scale;
float input_scale = param.input_scale;
float output_scale = param.output_scale;
CHECK(weight_scale.size() == oc)
CHECK(weight_scale.size() == static_cast<size_t>(oc))
<< "the num of the weight_scale should be equals to the output channel.";
if (Ptype_out == PRECISION(kInt8)) {
this->temp_tensor_.Resize(o_dims);
this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
for (int i = 0; i < weight_scale.size(); i++) {
for (size_t i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
}
auto* b_data = param.bias ? param.bias->mutable_data<float>() : nullptr;
if (b_data) {
scale(param.bias->numel(), b_data, b_data, 1.f / output_scale);
}
} else {
for (int i = 0; i < weight_scale.size(); i++) {
for (size_t i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale);
}
}
......@@ -322,8 +376,11 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
oc,
oh,
ow));
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
if (ic % 4 == 0 && oc % 4 == 0) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
......@@ -331,14 +388,15 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&(this->workspace_fwd_sizes_)));
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
CUDA_CALL(cudaFree(this->workspace_data_));
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
CUDA_CALL(
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_));
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
......
......@@ -66,9 +66,15 @@ class CudnnConv2DBase {
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
ResetWorkSpace();
}
protected:
void ResetWorkSpace() {
if (workspace_data_ != NULL) {
cudaFree(workspace_data_);
CUDA_CALL(cudaFree(workspace_data_));
}
workspace_data_ = NULL;
}
protected:
......
......@@ -21,6 +21,18 @@ namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void scale_kernel(int num, const T* in, T* out, const float scale) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
#if __CUDA_ARCH__ >= 350
out[tid] = __ldg(in + tid) * scale;
#else
out[tid] = in[tid] * scale;
#endif
}
}
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
float4* out,
......@@ -68,6 +80,23 @@ void fp32_scale_nhwc4(int num,
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread>>>(num, in, out, scale);
}
template void scale(int num, const float*, float*, float, cudaStream_t);
template void scale(int num, const float*, float*, float);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -31,6 +31,12 @@ void fp32_scale_nhwc4(int num,
int W,
cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -89,6 +89,7 @@ void BatchTranspose2DCUDAImpl(const int N,
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
......@@ -102,6 +103,7 @@ TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template <typename T>
......@@ -169,8 +171,6 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
size, ndim, d_strides, d_y_dims, X, Y);
// cudaError_t error = cudaGetLastError();
// if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
......@@ -25,13 +26,11 @@ size_t TargetWrapperCuda::num_devices() {
void* TargetWrapperCuda::Malloc(size_t size) {
void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
CUDA_CALL(cudaMalloc(&ptr, size));
return ptr;
}
void TargetWrapperCuda::Free(void* ptr) {
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
void TargetWrapperCuda::Free(void* ptr) { CUDA_CALL(cudaFree(ptr)); }
void TargetWrapperCuda::MemcpySync(void* dst,
const void* src,
......@@ -39,14 +38,13 @@ void TargetWrapperCuda::MemcpySync(void* dst,
IoDirection dir) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
break;
case IoDirection::HtoD:
CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
break;
case IoDirection::DtoH:
CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
CUDA_CALL(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
break;
default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
......@@ -60,16 +58,16 @@ void TargetWrapperCuda::MemcpyAsync(void* dst,
const stream_t& stream) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream));
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream));
break;
case IoDirection::HtoD:
CHECK(cudaSuccess ==
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
break;
case IoDirection::DtoH:
CHECK(cudaSuccess ==
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
CUDA_CALL(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
break;
default:
LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
......
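This file replaces the bare CHECK_EQ(cudaSuccess, ...) assertions with the CUDA_CALL macro pulled in from lite/backends/cuda/cuda_utils.h. The macro itself is outside this diff; a typical formulation of that error-checking pattern is sketched below (the real macro may report through the project's CHECK/LOG facilities instead):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Sketch only; the actual CUDA_CALL lives in lite/backends/cuda/cuda_utils.h.
#define CUDA_CALL_SKETCH(expr)                                      \
  do {                                                              \
    cudaError_t err__ = (expr);                                     \
    if (err__ != cudaSuccess) {                                     \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                   \
              cudaGetErrorString(err__), __FILE__, __LINE__);       \
      abort();                                                      \
    }                                                               \
  } while (0)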
......@@ -160,9 +160,9 @@ class Context<TargetType::kCUDA> {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) {
CHECK_GT(devs.size(), 0)
CHECK_GT(devs.size(), 0UL)
<< "Env is not initialized or current target is not exit!";
if (dev_id >= devs.size()) {
if (dev_id >= static_cast<int>(devs.size())) {
LOG(WARNING) << "device index exceeds the number of devices, set to "
"default device(0)!";
device_id_ = 0;
......
......@@ -24,7 +24,7 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
for (auto act_type : {"relu"}) {
for (auto act_type : {"relu", "leaky_relu"}) {
for (auto has_bias : {true, false}) {
fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
fuser(graph.get());
......
......@@ -73,7 +73,16 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
op_desc.SetAttr("fuse_relu", true);
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("with_act", true);
op_desc.SetAttr("act_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("leaky_relu_alpha", alpha);
}
return op_desc;
}
......
......@@ -28,7 +28,6 @@ class ConvActivationFuser : public FuseBase {
explicit ConvActivationFuser(const std::string& conv_type,
const std::string& act_type,
bool has_bias) {
CHECK(act_type == "relu") << "Only relu activation be supported now";
conv_type_ = conv_type;
act_type_ = act_type;
has_bias_ = has_bias;
......
......@@ -87,6 +87,9 @@ void TypeLayoutTransformPass::AddLayoutInst(
auto layout_output_name =
string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
auto* layout_output_arg = graph->NewArgumentNode(layout_output_name);
layout_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), from.precision(), to.layout());
auto* layout_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
......@@ -110,7 +113,9 @@ void TypeLayoutTransformPass::AddLayoutInst(
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatible(*in_arg_ty, from)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->layout() == to.layout()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
......
......@@ -90,6 +90,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels.
......@@ -118,13 +120,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
// TODO(xg): to optimize this
#ifndef LITE_WITH_FPGA
if (in_arg_ty->precision() == from.precision() &&
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->precision() == to.precision()) {
#else
if (TypeCompatible(*in_arg_ty, from)) {
#endif
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
......
......@@ -87,8 +87,12 @@ void TypeTargetTransformPass::AddIoCopyInst(
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name =
string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(), from.layout()
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
......@@ -114,7 +118,9 @@ void TypeTargetTransformPass::AddIoCopyInst(
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatible(*in_arg_ty, from)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->target() == to.target()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
......
......@@ -58,8 +58,8 @@ class VariablePlaceInferencePass : public DebugPass {
void SetWeightType(Node* w, const LiteType& type) {
// TODO(xg) to optimize this
#ifndef LITE_WITH_FPGA
w->AsArg().type =
LiteType::GetTensorTy(TARGET(kHost), type.precision(), type.layout());
w->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), type.precision(), DATALAYOUT(kNCHW));
#else
w->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
......
......@@ -105,6 +105,7 @@ KernelRegistry::KernelRegistry()
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kFloat, kNHWC);
INIT_FOR(kCUDA, kInt8, kNCHW);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
......
......@@ -70,6 +70,9 @@ class KernelRegistry final {
variant<KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
......
......@@ -113,9 +113,6 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
void RuntimeProgram::Run() {
for (auto& inst : instructions_) {
VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst.kernel()->target());
inst.Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
......@@ -192,9 +189,14 @@ void Instruction::Run() {
CHECK(op_->CheckShape());
}
if (op_->run_once() && has_run_) return;
if (op_->run_once() && has_run_) {
return;
}
VLOG(4) << "kernel launch";
op_->InferShape();
VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
<< TargetToStr(kernel_->target());
kernel_->Launch();
has_run_ = true;
}
......
......@@ -4,18 +4,20 @@ endif()
message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
nv_library(transpose_compute_cuda SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps})
nv_library(elementwise_add_compute_cuda SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps})
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(concat_compute_cuda CUDA basic SRCS concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps})
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
......@@ -23,20 +25,4 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
set(cuda_kernels
conv2d_cuda
mul_compute_cuda
io_copy_compute_cuda
leaky_relu_compute_cuda
nearest_interp_compute_cuda
concat_compute_cuda
elementwise_add_compute_cuda
yolo_box_compute_cuda
transpose_compute_cuda
)
set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
......@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
......
......@@ -56,10 +56,20 @@ template class ConvComputeInt8<PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
......@@ -70,13 +80,19 @@ REGISTER_LITE_KERNEL(
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
fp32_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
......
......@@ -105,6 +105,7 @@ TEST(conv_compute, fp32) {
LOG(INFO) << y_cpu_data[i];
}
}
/*
TEST(conv_compute, int8) {
ConvComputeInt8<PRECISION(kFloat)> int8_conv_fp32out;
......@@ -177,18 +178,19 @@ TEST(conv_compute, int8_int8_out) {
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.active_type = lite_api::ActivationType::kRelu;
// act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
int c_i = 3, h_i = 3, w_i = 3;
int n = 1, c = 4;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
filter_cpu.Resize({c, 3, 3, c / param.groups});
x_cpu.Resize({n, h_i, w_i, c_i});
filter_cpu.Resize({c, 3, 3, c_i / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
......@@ -198,14 +200,19 @@ TEST(conv_compute, int8_int8_out) {
auto* y_cpu_data = x_cpu.mutable_data<int8_t>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
std::cout << "input" << std::endl;
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
std::cout << float(x_cpu_data[i]) << std::endl;
}
std::cout << "filter" << std::endl;
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
std::cout << float(filter_cpu_data[i]) << std::endl;
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
// bias_cpu_data[i] = 0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
......@@ -218,6 +225,7 @@ TEST(conv_compute, int8_int8_out) {
param.filter = &filter;
param.output = &y;
param.weight_scale = {0.01, 0.02, 0.03, 0.04};
param.output_scale = 2;
param.bias = &bias;
int8_conv_fp32out.SetParam(param);
......@@ -232,12 +240,13 @@ TEST(conv_compute, int8_int8_out) {
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-1, 4, 0, -2};
std::vector<float> real_results = {0, 7, 8, 1};
for (int i = 0; i < y.numel(); i++) {
// EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
LOG(INFO) << float(y_cpu_data[i]);
}
}
*/
} // namespace cuda
} // namespace kernels
......
......@@ -26,7 +26,11 @@ __global__ void KeElementwiseAdd(const float* x_data,
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < total; tid += stride) {
#if __CUDA_ARCH__ >= 350
out_data[tid] = __ldg(x_data + tid) + __ldg(y_data + tid);
#else
out_data[tid] = x_data[tid] + y_data[tid];
#endif
}
}
......@@ -51,7 +55,7 @@ void ElementwiseAddCompute::Run() {
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
int threads = 512;
int threads = 1024;
int blocks = (pixel_num + threads - 1) / threads;
blocks = blocks > 8 ? 8 : blocks;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/feed_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void FeedCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
VLOG(4) << "feed_list.size: " << param.feed_list->size();
const lite::Tensor& feed_item = (*param.feed_list)[param.col];
int num = static_cast<int>(feed_item.numel());
auto input = feed_item.data<float>();
param.out->Resize(feed_item.dims());
auto output = param.out->mutable_data<float>(TARGET(kCUDA));
VLOG(4) << "col: " << param.col << " num:" << num;
TargetW::MemcpyAsync(
output, input, num * sizeof(float), IoDirection::HtoD, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
feed, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::FeedCompute, nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
feed, kCUDA, kFloat, kNHWC, paddle::lite::kernels::cuda::FeedCompute, nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class FeedCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FeedParam;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
void Run() override;
virtual ~FeedCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -51,7 +51,7 @@ class IoCopyHostToCudaCompute
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kX86));
auto mem_size = param.x->memory_size();
LOG(INFO) << "copy size " << mem_size;
VLOG(4) << "copy size " << mem_size;
auto* data = param.y->mutable_data(TARGET(kCUDA), mem_size);
CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
......@@ -89,7 +89,7 @@ class IoCopyCudaToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto mem_size = param.x->memory_size();
LOG(INFO) << "io copy cuda to host " << mem_size;
VLOG(4) << "io copy cuda to host " << mem_size;
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
CopyToHostSync(data, param.x->raw_data(), mem_size);
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/layout_compute.h"
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename Dtype>
void NCHWToNHWCCompute<Dtype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto input = param.x->template data<Dtype>();
auto input_dim = param.x->dims();
CHECK(input_dim.size() == 4)
<< "NCHW to NHWC should guarantee that the input dims should be 4";
auto output = param.y->template mutable_data<Dtype>(TARGET(kCUDA));
int n = input_dim[0];
int c = input_dim[1];
int h = input_dim[2];
int w = input_dim[3];
lite::cuda::math::NCHW2NHWC<Dtype>(n, c, h * w, input, output, &ctx);
}
template <typename Dtype>
void NHWCToNCHWCompute<Dtype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto input = param.x->template data<Dtype>();
auto output = param.y->template mutable_data<Dtype>(TARGET(kCUDA));
auto input_dim = param.x->dims();
CHECK(input_dim.size() == 4)
<< "NHWC to NCHW should guarantee that the input dims should be 4";
int n = input_dim[0];
int h = input_dim[1];
int w = input_dim[2];
int c = input_dim[3];
lite::cuda::math::NHWC2NCHW<Dtype>(n, c, h * w, input, output, &ctx);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(layout,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::NCHWToNHWCCompute<float>,
nchw2nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::NHWCToNCHWCompute<float>,
nhwc2nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kInt8,
kNCHW,
paddle::lite::kernels::cuda::NCHWToNHWCCompute<int8_t>,
nchw2nhwc)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::NHWCToNCHWCompute<int8_t>,
nhwc2nchw)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.Finalize();
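For reference, the nchw2nhwc/nhwc2nchw kernels registered above implement the usual index remapping between the two layouts; a naive CPU sketch of the NCHW-to-NHWC direction (illustrative only, not part of the patch):

#include <vector>

// in is laid out as N x C x H x W; out becomes N x H x W x C.
template <typename T>
void nchw_to_nhwc_ref(int n, int c, int h, int w,
                      const std::vector<T>& in, std::vector<T>* out) {
  out->resize(in.size());
  for (int ni = 0; ni < n; ++ni)
    for (int ci = 0; ci < c; ++ci)
      for (int hi = 0; hi < h; ++hi)
        for (int wi = 0; wi < w; ++wi) {
          const int src = ((ni * c + ci) * h + hi) * w + wi;
          const int dst = ((ni * h + hi) * w + wi) * c + ci;
          (*out)[dst] = in[src];
        }
}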
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename Dtype>
class LayOutCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~LayOutCompute() = default;
};
template <typename Dtype>
class NCHWToNHWCCompute : public LayOutCompute<Dtype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NCHWToNHWCCompute() = default;
};
template <typename Dtype>
class NHWCToNCHWCompute : public LayOutCompute<Dtype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NHWCToNCHWCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -74,6 +74,17 @@ REGISTER_LITE_KERNEL(transpose,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::TransposeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// REGISTER_LITE_KERNEL(transpose2,
// kCUDA,
// kFloat,
......
......@@ -76,8 +76,22 @@ class ConvOpLite : public OpLite {
}
}
}
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
if (op_desc.HasAttr("with_act") && op_desc.GetAttr<bool>("with_act")) {
param_.activation_param.has_active = true;
auto act_type = op_desc.GetAttr<std::string>("act_type");
if (act_type == "relu") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu;
param_.fuse_relu = true;
} else if (act_type == "leaky_relu") {
param_.activation_param.active_type =
lite_api::ActivationType::kLeakyRelu;
param_.activation_param.Leaky_relu_alpha =
op_desc.GetAttr<float>("leaky_relu_alpha");
} else {
CHECK(false)
<< "The fused conv only supports fuse with relu and leaky relu";
}
}
// For Int8
if (op_desc.HasAttr("enable_int8")) {
......