Commit 23d83c04 authored by Zhaolong Xing, committed by GitHub

add cudnn conv fp32, int8 support (#1974)

* paddle lite cuda init
can run a model with leaky_relu

* add the missing file.
test=develop

* add the load from memory interface.
test=develop

* refine this PR: address review comments and fix the CI error
test=develop

* conv impl
fp32:
conv, conv + bias, conv + bias + relu, conv + bias + leaky_relu

int8:
conv, conv + bias + relu (int8 or fp32 output), conv + bias + leaky_relu (int8 or fp32 output)

can run conv + bias + relu using cxx_api
test=develop

* move lite/cuda/math to lite/backends/cuda/math
test=develop
Parent 52f933d4
@@ -503,17 +503,14 @@ function(nv_test TARGET_NAME)
  cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
  cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
- target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog ${os_dependency_modules})
- add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog)
+ target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY})
+ add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
  common_link(${TARGET_NAME})
  add_test(${TARGET_NAME} ${TARGET_NAME})
  if (nv_test_SERIAL)
    set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
  endif()
- set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
- set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
- set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
- set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
  endif()
endfunction(nv_test)
......
@@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
  USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
  USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device);
  USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host);
+ USE_LITE_KERNEL(conv2d, kCUDA, kFloat, kNCHW, def);
  USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
  USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def);
  USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def);
......
@@ -5,3 +5,4 @@ endif()
  nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
  nv_library(cuda_blas SRCS blas.cc)
+ add_subdirectory(math)
@@ -17,6 +17,7 @@
  #include <cublas_api.h>
  #include <cublas_v2.h>
  #include <cuda.h>
+ #include <cudnn.h>
  #include "lite/utils/cp_logging.h"
  /*
@@ -46,6 +47,15 @@
      << "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \
  }
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#define CUDNN_CHECK(condition)                                           \
  {                                                                      \
    cudnnStatus_t status = condition;                                    \
    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << CudnnGetErrorInfo(status); \
  }
namespace paddle {
namespace lite {
namespace cuda {
@@ -71,6 +81,44 @@ static const char* CublasErrorInfo(int error) {
  }
}
static const char* CudnnGetErrorInfo(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
#if CUDNN_VERSION_MIN(6, 0, 0)
case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
#endif
}
return "Unknown cudnn status";
}
} // namespace cuda
} // namespace lite
} // namespace paddle
if(NOT LITE_WITH_CUDA)
return()
endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
set (
math_cuda
cudnn_conv
cuda_activation
cuda_scale
cuda_type_trans
)
set(math_cuda "${math_cuda}" CACHE INTERNAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void relu_kernel(const int num,
const T alpha,
const T* input,
T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index)
: __ldg(input + index) * alpha;
#else
output[index] = input[index] >= 0 ? input[index] : input[index] * alpha;
#endif
}
}
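
// The *_int8_nhwc4 kernels below operate on NHWC data with channels packed
// four at a time into float4/char4 elements: K is C/4 and `tid % K` selects
// the pack whose per-channel scale/bias applies. `fmaxf(v * alpha, v)` yields
// relu for alpha == 0 and leaky_relu for 0 < alpha < 1.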
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
float4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float4 bias_ptr = bias[bias_idx];
const float4 scale_ptr = scale[bias_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x;
packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x);
packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y;
packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y);
packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z;
packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z);
packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w;
packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w);
out[tid] = packed_val;
}
}
__global__ void bias_relu_int8_nhwc4_kernel(int num,
const float4* in,
const float4* bias,
char4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float4 bias_ptr = bias[bias_idx];
const float4 scale_ptr = scale[bias_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
char4 result_val;
packed_val.x = in_ptr.x * scale_ptr.x + bias_ptr.x;
result_val.x =
from_float<int8_t>(fmaxf(packed_val.x * alpha, packed_val.x));
packed_val.y = in_ptr.y * scale_ptr.y + bias_ptr.y;
result_val.y =
from_float<int8_t>(fmaxf(packed_val.y * alpha, packed_val.y));
packed_val.z = in_ptr.z * scale_ptr.z + bias_ptr.z;
result_val.z =
from_float<int8_t>(fmaxf(packed_val.z * alpha, packed_val.z));
packed_val.w = in_ptr.w * scale_ptr.w + bias_ptr.w;
result_val.w =
from_float<int8_t>(fmaxf(packed_val.w * alpha, packed_val.w));
out[tid] = result_val;
}
}
__global__ void relu_int8_nhwc4_kernel(int num,
const float4* in,
float4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x;
packed_val.x = fmaxf(packed_val.x * alpha, packed_val.x);
packed_val.y = in_ptr.y * scale_ptr.y;
packed_val.y = fmaxf(packed_val.y * alpha, packed_val.y);
packed_val.z = in_ptr.z * scale_ptr.z;
packed_val.z = fmaxf(packed_val.z * alpha, packed_val.z);
packed_val.w = in_ptr.w * scale_ptr.w;
packed_val.w = fmaxf(packed_val.w * alpha, packed_val.w);
out[tid] = packed_val;
}
}
__global__ void relu_int8_nhwc4_kernel(int num,
const float4* in,
char4* out,
int N,
int K,
int H,
int W,
const float4* scale,
float alpha) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
char4 result_val;
packed_val.x = in_ptr.x * scale_ptr.x;
result_val.x =
from_float<int8_t>(fmaxf(packed_val.x * alpha, packed_val.x));
packed_val.y = in_ptr.y * scale_ptr.y;
result_val.y =
from_float<int8_t>(fmaxf(packed_val.y * alpha, packed_val.y));
packed_val.z = in_ptr.z * scale_ptr.z;
result_val.z =
from_float<int8_t>(fmaxf(packed_val.z * alpha, packed_val.z));
packed_val.w = in_ptr.w * scale_ptr.w;
result_val.w =
from_float<int8_t>(fmaxf(packed_val.w * alpha, packed_val.w));
out[tid] = result_val;
}
}
template <>
void bias_relu_int8_nhwc4<float>(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
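  // ceil-divide: launch enough blocks to cover a trailing partial group; the
  // kernel's `tid < num` guard discards the overshoot.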
bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<const float4*>(bias),
static_cast<float4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void bias_relu_int8_nhwc4<int8_t>(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<const float4*>(bias),
static_cast<char4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void relu_int8_nhwc4<float>(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<float4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <>
void relu_int8_nhwc4<int8_t>(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<char4*>(out),
N,
K,
H,
W,
static_cast<const float4*>(scale),
alpha);
}
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void relu(int, const float*, float*, float, cudaStream_t);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
// fp32
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
// For int8
template <typename out_type>
void bias_relu_int8_nhwc4(int num,
const void* in,
const void* bias,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
template <typename out_type>
void relu_int8_nhwc4(int num,
const void* in,
void* out,
int N,
int K,
int H,
int W,
const void* scale,
float alpha,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <>
bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kw = w_dims[3];
int kh = w_dims[2];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
  CHECK(ic % param.groups == 0)
      << "The conv input channel count should be divisible by the group number.";
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
if (param.activation_param.has_active && with_relu_act_) {
CUDNN_CHECK(cudnnSetActivationDescriptor(
this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
}
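  // Depthwise case (groups == input channels == output channels): choose
  // implicit GEMM directly instead of querying cuDNN's algorithm heuristic.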
if (ic == param.groups && ic == oc && ic != 1) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else {
CUDNN_CHECK(
cudnnGetConvolutionForwardAlgorithm(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->preference_,
this->workspace_limit_bytes_,
&this->fwd_algo_));
}
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&this->workspace_fwd_sizes_));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
if (param.bias) {
int dim_bias[] = {1, oc, 1, 1};
int stride_bias[] = {oc, 1, 1, 1};
cudnnSetTensorNdDescriptor(
this->bias_desc_, CUDNN_DATA_FLOAT, 4, dim_bias, stride_bias);
}
return true;
}
template <>
bool CudnnConv2D<PRECISION(kFloat)>::init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
this->workspace_size_inbytes_ = 0;
this->workspace_data_ = NULL;
this->workspace_fwd_sizes_ = 0;
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
this->workspace_ = NULL;
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreateFilterDescriptor(&this->filter_desc_);
cudnnCreateConvolutionDescriptor(&this->conv_desc_);
cudnnCreateTensorDescriptor(&this->bias_desc_);
if (param.activation_param.has_active) {
if (param.activation_param.active_type == lite_api::ActivationType::kRelu) {
cudnnCreateActivationDescriptor(&this->act_desc_);
} else {
this->with_relu_act_ = false;
}
}
return create(param, ctx);
}
template <>
bool CudnnConv2D<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (param.activation_param.has_active && with_relu_act_) {
if (b_data) {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data,
bias_desc_,
b_data,
act_desc_,
output_desc_,
o_data));
} else {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data));
CUDNN_CHECK(cudnnActivationForward(handle_,
act_desc_,
&alpha,
output_desc_,
o_data,
&beta,
output_desc_,
o_data));
}
} else {
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(handle_,
&alpha,
input_desc_,
i_data,
filter_desc_,
w_data,
conv_desc_,
fwd_algo_,
workspace_,
workspace_fwd_sizes_,
&beta,
output_desc_,
o_data));
if (b_data) {
CUDNN_CHECK(cudnnAddTensor(
handle_, &alpha, bias_desc_, b_data, &alpha, output_desc_, o_data));
}
}
if (!with_relu_act_) {
CHECK(param.activation_param.active_type ==
lite_api::ActivationType::kLeakyRelu)
<< "Only support leaky relu now.";
auto out_dims = param.output->dims();
int n = out_dims[0], c = out_dims[1], h = out_dims[2], w = out_dims[3];
int num = n * h * w * c;
float alpha = param.activation_param.Leaky_relu_alpha;
relu(num, o_data, o_data, alpha, this->stream_);
}
return true;
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
  int iw = x_dims[2];  // nhwc
int ih = x_dims[1];
int ic = x_dims[3];
int ow = o_dims[2];
int oh = o_dims[1];
int oc = o_dims[3];
int kw = w_dims[2];
int kh = w_dims[1];
int sw = param.strides[1];
int sh = param.strides[0];
int pw = param.paddings[1];
int ph = param.paddings[0];
int dw = param.dilations[1];
int dh = param.dilations[0];
std::vector<float> weight_scale = param.weight_scale;
float input_scale = param.input_scale;
float output_scale = param.output_scale;
  CHECK(weight_scale.size() == static_cast<size_t>(oc))
      << "the number of weight_scale values should equal the output channel count.";
if (Ptype_out == PRECISION(kInt8)) {
this->temp_tensor_.Resize(o_dims);
this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
for (int i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
}
} else {
for (int i = 0; i < weight_scale.size(); i++) {
weight_scale[i] = (weight_scale[i] * input_scale);
}
}
this->scale_.Resize({oc});
auto* scale_data = this->scale_.template mutable_data<float>(TARGET(kCUDA));
this->scale_.template Assign<float, lite::DDim, TARGET(kCUDA)>(
weight_scale.data(), this->scale_.dims());
  CHECK(ic % param.groups == 0)
      << "The conv input channel count should be divisible by the group number.";
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_INT8,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
CUDNN_DATA_INT8,
CUDNN_TENSOR_NHWC,
oc,
ic / param.groups,
kh,
kw));
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
ph,
pw,
sh,
sw,
dh,
dw,
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_INT32));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
CUDNN_CHECK(
cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
this->input_desc_,
this->filter_desc_,
this->conv_desc_,
this->output_desc_,
this->fwd_algo_,
&(this->workspace_fwd_sizes_)));
if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
if (this->workspace_data_ != NULL) {
cudaFree(this->workspace_data_);
}
cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
}
return true;
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx) {
  this->workspace_size_inbytes_ = 0;
this->workspace_data_ = NULL;
this->workspace_fwd_sizes_ = 0;
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
this->workspace_ = NULL;
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreateFilterDescriptor(&this->filter_desc_);
cudnnCreateConvolutionDescriptor(&this->conv_desc_);
cudnnCreateTensorDescriptor(&this->bias_desc_);
if (param.activation_param.has_active) {
if (!(param.activation_param.active_type ==
lite_api::ActivationType::kRelu)) {
this->with_relu_act_ = false;
}
}
return create(param, ctx);
}
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) {
const auto* i_data = param.x->data<int8_t>();
const auto* w_data = param.filter->data<int8_t>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
float* temp_out;
float* scale = this->scale_.template mutable_data<float>(TARGET(kCUDA));
if (Ptype_out == PRECISION(kInt8)) {
temp_out = this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
} else {
temp_out = param.output->mutable_data<float>(TARGET(kCUDA));
}
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
&alpha,
this->input_desc_,
i_data,
this->filter_desc_,
w_data,
this->conv_desc_,
this->fwd_algo_,
this->workspace_,
this->workspace_fwd_sizes_,
&beta,
this->output_desc_,
temp_out));
auto out_dims = param.output->dims();
int n = out_dims[0], h = out_dims[1], w = out_dims[2], c = out_dims[3];
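  // The epilogue kernels consume one 4-channel pack (float4/char4) per
  // thread, hence the element count is divided by 4.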
int num = n * h * w * c / 4;
if (!param.activation_param.has_active && !b_data) {
if (Ptype_out == PRECISION(kInt8)) {
auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
fp32_to_int8_nhwc4(num,
static_cast<const void*>(temp_out),
static_cast<void*>(out),
static_cast<const void*>(scale),
n,
c / 4,
h,
w,
this->stream_);
} else {
fp32_scale_nhwc4(num,
static_cast<const void*>(temp_out),
static_cast<void*>(temp_out),
static_cast<const void*>(scale),
n,
c / 4,
h,
w,
this->stream_);
}
return true;
}
if (b_data) {
if (param.activation_param.has_active) {
float alpha = 0.0;
if (!this->with_relu_act_)
alpha = param.activation_param.Leaky_relu_alpha;
if (Ptype_out == PRECISION(kInt8)) {
auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
bias_relu_int8_nhwc4<int8_t>(num,
static_cast<const void*>(temp_out),
static_cast<const void*>(b_data),
static_cast<void*>(out),
n,
c / 4,
h,
w,
static_cast<const void*>(scale),
alpha,
this->stream_);
} else {
bias_relu_int8_nhwc4<float>(num,
static_cast<const void*>(temp_out),
static_cast<const void*>(b_data),
static_cast<void*>(temp_out),
n,
c / 4,
h,
w,
static_cast<const void*>(scale),
alpha,
this->stream_);
}
return true;
}
}
  CHECK(false) << "cuDNN int8 conv supports: conv, conv + bias + relu, "
                  "conv + bias + leaky_relu.";
  return false;
}
template class CudnnConv2DInt8<PRECISION(kInt8)>;
template class CudnnConv2DInt8<PRECISION(kFloat)>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <PrecisionType Ptype_out>
class CudnnConv2DBase {
public:
CudnnConv2DBase()
: handle_(NULL),
workspace_data_(NULL),
workspace_(NULL),
conv_desc_(NULL),
input_desc_(NULL),
output_desc_(NULL),
filter_desc_(NULL),
act_desc_(NULL),
bias_desc_(NULL),
workspace_fwd_sizes_(0),
workspace_size_inbytes_(0),
fwd_algo_((cudnnConvolutionFwdAlgo_t)0) {}
~CudnnConv2DBase() {
if (conv_desc_) {
CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc_));
}
if (input_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
}
if (output_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
}
if (act_desc_) {
CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
}
if (bias_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_));
}
if (filter_desc_) {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc_));
}
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
if (workspace_data_ != NULL) {
cudaFree(workspace_data_);
}
}
protected:
cudaStream_t stream_;
cudnnHandle_t handle_;
cudnnConvolutionFwdAlgo_t fwd_algo_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
// activation descriptor
cudnnActivationDescriptor_t act_desc_;
bool with_relu_act_{true};
size_t workspace_fwd_sizes_;
size_t workspace_size_inbytes_; // size of underlying storage
void* workspace_data_; // underlying storage
void* workspace_; // aliases into _workspaceData
const bool use_tensor_core_ = true;
const size_t workspace_limit_bytes_ = 4 * 1024 * 1024;
const cudnnConvolutionFwdPreference_t preference_ =
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
// For int8
Tensor temp_tensor_;
Tensor scale_;
};
template <PrecisionType Ptype_out>
class CudnnConv2D : public CudnnConv2DBase<Ptype_out> {
public:
CudnnConv2D() : CudnnConv2DBase<Ptype_out>() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::ConvParam& param);
};
template <PrecisionType Ptype_out>
class CudnnConv2DInt8 : CudnnConv2DBase<Ptype_out> {
public:
CudnnConv2DInt8() : CudnnConv2DBase<Ptype_out>() {}
virtual bool init(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::ConvParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::ConvParam& param);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
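The header above exposes an init/create/run lifecycle: init allocates the cuDNN handle and descriptors once, create configures them for the current shapes and picks the forward algorithm and workspace, and run enqueues the convolution. A minimal sketch of driving CudnnConv2D directly (not part of the patch; RunFp32Conv, param, and ctx are illustrative names — ConvCompute in conv_compute.cc further down is the real in-tree caller):

// Sketch only: mirrors what ConvCompute::PrepareForRun() / Run() do.
#include "lite/backends/cuda/math/cudnn_conv.h"

bool RunFp32Conv(const paddle::lite::operators::ConvParam& param,
                 paddle::lite::Context<TARGET(kCUDA)>* ctx) {
  paddle::lite::cuda::math::CudnnConv2D<PRECISION(kFloat)> conv;
  if (!conv.init(param, ctx)) return false;  // handle, descriptors, algo, workspace
  return conv.run(param);                    // fused conv (+ bias + relu / leaky_relu)
}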
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iostream"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
float4* out,
const float4* scale,
int N,
int K,
int H,
int W) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
float4 packed_val;
packed_val.x = in_ptr.x * scale_ptr.x;
packed_val.y = in_ptr.y * scale_ptr.y;
packed_val.z = in_ptr.z * scale_ptr.z;
packed_val.w = in_ptr.w * scale_ptr.w;
out[tid] = packed_val;
}
}
void fp32_scale_nhwc4(int num,
const void* in,
void* out,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
fp32_scale_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<float4*>(out),
static_cast<const float4*>(scale),
N,
K,
H,
W);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
void fp32_scale_nhwc4(int num,
const void* din,
void* dout,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/type_trans.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
char4* out,
const float4* scale,
int N,
int K,
int H,
int W) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int scale_idx = tid % K;
const float4 scale_ptr = scale[scale_idx];
const float4 in_ptr = in[tid];
char4 result_val;
result_val.x = from_float<int8_t>(in_ptr.x * scale_ptr.x);
result_val.y = from_float<int8_t>(in_ptr.y * scale_ptr.y);
result_val.z = from_float<int8_t>(in_ptr.z * scale_ptr.z);
result_val.w = from_float<int8_t>(in_ptr.w * scale_ptr.w);
out[tid] = result_val;
}
}
void fp32_to_int8_nhwc4(int num,
const void* in,
void* out,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
fp32_scale_nhwc4_kernel<<<block, thread, 0, stream>>>(
num,
static_cast<const float4*>(in),
static_cast<char4*>(out),
static_cast<const float4*>(scale),
N,
K,
H,
W);
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
void fp32_to_int8_nhwc4(int num,
const void* din,
void* dout,
const void* scale,
int N,
int K,
int H,
int W,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <cstdint>
#include <string>
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__device__ T from_float(float x);
template <>
__device__ __forceinline__ float from_float<float>(float x) {
return x;
}
template <>
__device__ __forceinline__ half from_float<half>(float x) {
return __float2half(x);
}
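// Saturating conversion: clamp to [INT8_MIN, INT8_MAX], then round to the
// nearest integer with __float2int_rn.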
template <>
__device__ __forceinline__ int8_t from_float<int8_t>(float x) {
x = fmaxf(x, INT8_MIN);
x = fminf(x, INT8_MAX);
return __float2int_rn(x);
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
@@ -107,6 +107,7 @@ KernelRegistry::KernelRegistry()
  INIT_FOR(kCUDA, kFloat, kNCHW);
  INIT_FOR(kCUDA, kAny, kNCHW);
  INIT_FOR(kCUDA, kAny, kAny);
+ INIT_FOR(kCUDA, kInt8, kNHWC);
  INIT_FOR(kHost, kFloat, kNCHW);
  INIT_FOR(kHost, kAny, kNCHW);
......
@@ -70,6 +70,9 @@ class KernelRegistry final {
  KernelRegistryForTarget<TARGET(kCUDA),
                          PRECISION(kInt8),
                          DATALAYOUT(kNCHW)> *,  //
+ KernelRegistryForTarget<TARGET(kCUDA),
+                         PRECISION(kInt8),
+                         DATALAYOUT(kNHWC)> *,  //
  KernelRegistryForTarget<TARGET(kX86),
                          PRECISION(kFloat),
                          DATALAYOUT(kNCHW)> *,  //
......
@@ -7,15 +7,18 @@ message(STATUS "compile with lite CUDA kernels")
  nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
  lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
  nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
+ nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
  nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
- lite_cc_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
- lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
- nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
- lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
+ nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
+ nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
+ nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
+ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
+ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
  set(cuda_kernels
+     conv2d_cuda
      mul_compute_cuda
      io_copy_compute_cuda
      leaky_relu_compute_cuda
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/conv_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ConvCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
conv_impl_.reset(new lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>);
conv_impl_->init(param, &ctx);
}
void ConvCompute::Run() {
auto& param = this->Param<param_t>();
conv_impl_->run(param);
}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
conv_impl_.reset(new lite::cuda::math::CudnnConv2DInt8<Ptype_out>);
conv_impl_->init(param, &ctx);
}
template <PrecisionType Ptype_out>
void ConvComputeInt8<Ptype_out>::Run() {
auto& param = this->Param<param_t>();
conv_impl_->run(param);
}
template class ConvComputeInt8<PRECISION(kInt8)>;
template class ConvComputeInt8<PRECISION(kFloat)>;
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
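
// Three kernels are registered below: an fp32 NCHW kernel, and two int8
// NHWC kernels that differ only in output precision: fp32_out leaves the
// dequantized float result, while int8_out requantizes to int8 via output_scale.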
REGISTER_LITE_KERNEL(
conv2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kFloat)>,
fp32_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
kInt8,
kNHWC,
paddle::lite::kernels::cuda::ConvComputeInt8<PRECISION(kInt8)>,
int8_out)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ConvCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConvCompute() = default;
private:
std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_;
};
template <PrecisionType Ptype_out>
class ConvComputeInt8 : public KernelLite<TARGET(kCUDA), PRECISION(kInt8)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConvComputeInt8() = default;
private:
std::unique_ptr<lite::cuda::math::CudnnConv2DInt8<Ptype_out>> conv_impl_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
float random(float low, float high) {
static std::mt19937 mt(100);
std::uniform_real_distribution<double> dist(low, high);
return dist(mt);
}
TEST(conv_compute, fp32) {
ConvCompute conv_fp32;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.paddings = {1, 1};
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 1, h = 3, w = 3;
int c_o = 1, h_o = 3, w_o = 3;
y.Resize({n, c_o, h_o, w_o});
x_cpu.Resize({n, c, h, w});
filter_cpu.Resize({c_o, c / param.groups, 3, 3});
y_cpu.Resize({n, c_o, h_o, w_o});
bias_cpu.Resize({c_o});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* filter_cpu_data = filter_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
float* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i;
}
std::vector<float> weight = {-0.2209115,
-0.17199445,
-0.2059412,
0.6763207,
-0.12260777,
-0.43123743,
-0.49696392,
-0.27471393,
-0.81017196};
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = weight[i];
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = 0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<float, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
// param.bias = &bias;
conv_fp32.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
conv_fp32.SetContext(std::move(ctx));
conv_fp32.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-0.8, -0.7};
for (int i = 0; i < y.numel(); i++) {
LOG(INFO) << y_cpu_data[i];
}
}
TEST(conv_compute, int8) {
ConvComputeInt8<PRECISION(kFloat)> int8_conv_fp32out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
operators::ConvParam param;
// param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
filter_cpu.Resize({c, 3, 3, c / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
  auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(1);
}
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(1);
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
  bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
param.weight_scale = {1, 2, 3, 4};
int8_conv_fp32out.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
int8_conv_fp32out.SetContext(std::move(ctx));
int8_conv_fp32out.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {36, 72, 108, 144};
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
}
}
TEST(conv_compute, int8_int8_out) {
  ConvComputeInt8<PRECISION(kInt8)> int8_conv_int8out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam act_param;
act_param.has_active = true;
// act_param.active_type = core::ActiveType::Active_relu;
act_param.active_type = lite_api::ActivationType::kLeakyRelu;
act_param.Leaky_relu_alpha = 0.1;
operators::ConvParam param;
param.activation_param = act_param;
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
int n = 1, c = 4, h = 3, w = 3;
y.Resize({1, 1, 1, c});
x_cpu.Resize({n, h, w, c});
filter_cpu.Resize({c, 3, 3, c / param.groups});
y_cpu.Resize({1, 1, 1, c});
bias_cpu.Resize({c});
auto* x_data = x.mutable_data<int8_t>(TARGET(kCUDA));
auto* y_data = y.mutable_data<int8_t>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<int8_t>();
auto* filter_cpu_data = filter_cpu.mutable_data<int8_t>();
  auto* y_cpu_data = y_cpu.mutable_data<int8_t>();
auto* bias_cpu_data = bias_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
}
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
}
x.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
filter.Assign<int8_t, lite::DDim, TARGET(kCUDA)>(filter_cpu_data,
filter_cpu.dims());
  bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data, bias_cpu.dims());
param.x = &x;
param.filter = &filter;
param.output = &y;
param.weight_scale = {0.01, 0.02, 0.03, 0.04};
param.bias = &bias;
  int8_conv_int8out.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
  int8_conv_int8out.SetContext(std::move(ctx));
  int8_conv_int8out.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(int8_t) * y.numel(), IoDirection::DtoH);
std::vector<float> real_results = {-1, 4, 0, -2};
for (int i = 0; i < y.numel(); i++) {
// EXPECT_NEAR(y_cpu_data[i], real_results[i], 1e-5);
LOG(INFO) << float(y_cpu_data[i]);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -89,6 +89,7 @@ class IoCopyCudaToHostCompute
    auto& param = Param<operators::IoCopyParam>();
    CHECK(param.x->target() == TARGET(kCUDA));
    auto mem_size = param.x->memory_size();
+   LOG(INFO) << "io copy cuda to host " << mem_size;
    auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
    CopyToHostSync(data, param.x->raw_data(), mem_size);
  }
......
@@ -12,7 +12,6 @@
  // See the License for the specific language governing permissions and
  // limitations under the License.
- #pragma once
  #include "lite/core/op_registry.h"
  #include "lite/kernels/cuda/leaky_relu_compute.h"
......
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License. */
- #pragma once
  #include <vector>
  #include "lite/core/op_registry.h"
  #include "lite/kernels/cuda/yolo_box_compute.h"
@@ -179,7 +178,7 @@ void YoloBoxCompute::Run() {
  const int an_num = anchors.size() / 2;
  int input_size = downsample_ratio * h;
- anchors_.Resize(static_cast<int>({anchors.size()}));
+ anchors_.Resize({static_cast<int64_t>(anchors.size())});
  int* d_anchors = anchors_.mutable_data<int>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(d_anchors,
                          anchors.data(),
......
@@ -466,7 +466,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
  #define SET_DATA_TYPE(precision, type_desc) \
    case precision:                           \
      desc.SetDataType(type_desc);            \
-     break
+     break;
  SET_DATA_TYPE(PRECISION(kFloat), VarDescAPI::VarDataType::FP32);
  SET_DATA_TYPE(PRECISION(kInt8), VarDescAPI::VarDataType::INT8);
@@ -487,14 +487,14 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
  if (tensor.target() == TARGET(kCUDA)) {
    switch (tensor.precision()) {
  #define DO(precision, type)                                         \
-   case precision:                                                   \
+   case precision: {                                                 \
      std::unique_ptr<type> tmp_buffer(new type[tensor.data_size()]); \
      TargetWrapperCuda::MemcpySync(tmp_buffer.get(),                 \
                                    tensor.data<type>(),              \
                                    tensor.data_size(),               \
                                    IoDirection::DtoH);               \
      desc.SetData<type>(tmp_buffer.get(), tensor.data_size());       \
-     break
+     } break;
  DO(PRECISION(kFloat), float);
  DO(PRECISION(kInt8), int8_t);
  DO(PRECISION(kInt16), int16_t);
@@ -512,7 +512,7 @@ void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
  #define DO(precision, type)                                      \
    case precision:                                                \
      desc.SetData<type>(tensor.data<type>(), tensor.data_size()); \
-     break
+     break;
  DO(PRECISION(kFloat), float);
  DO(PRECISION(kInt8), int8_t);
  DO(PRECISION(kInt16), int16_t);
......
@@ -15,8 +15,10 @@
  #pragma once
  #include <string>
  #include <vector>
+ #include "lite/api/paddle_place.h"
  #include "lite/core/scope.h"
  #include "lite/core/tensor.h"
+ #include "lite/core/types.h"
  #include "lite/model_parser/cpp/block_desc.h"
  #include "lite/model_parser/desc_apis.h"
  #include "lite/utils/all.h"
@@ -203,6 +205,28 @@ struct ConcatParam {
    int axis{0};
  };
+ /// ----------------------- activation operators ----------------------
+ struct ActivationParam {
+   const lite::Tensor* X{};
+   float Leaky_relu_alpha{0};   // leaky_relu param
+   float Relu_clipped_coef{6};  // relu_clipped param
+   std::string Prelu_mode{
+       "channel"};  // prelu param, can be "all", "channel" or "element"
+   lite::Tensor* Prelu_alpha{};  // prelu param
+   float Swish_beta;             // swish param
+   lite::Tensor* Out{};
+   bool has_active{false};
+   lite_api::ActivationType active_type;
+ };
+
+ struct ActivationGradParam {
+   const lite::Tensor* X{};
+   const lite::Tensor* Out{};
+   // for backward
+   lite::Tensor* X_grad{};
+   const lite::Tensor* Out_grad{};
+ };
  // For Convolution op
  struct ConvParam {
    lite::Tensor* x{};
@@ -226,6 +250,8 @@ struct ConvParam {
    float scale_weights{1.0f};      // only used with mkl-dnn int8
    bool force_fp32_output{false};  // only used in mkl-dnn int8
    std::string data_format{"Anylayout"};
+   // for activation
+   ActivationParam activation_param;
    // for int8
    WITH_INT8_CONFIG
  };
@@ -320,26 +346,6 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
    std::string act_type;
  };
- /// ----------------------- activation operators ----------------------
- struct ActivationParam {
-   const lite::Tensor* X{};
-   float Leaky_relu_alpha{0};   // leaky_relu param
-   float Relu_clipped_coef{6};  // relu_clipped param
-   std::string Prelu_mode{
-       "channel"};  // prelu param, can be "all", "channel" or "element"
-   lite::Tensor* Prelu_alpha{};  // prelu param
-   float Swish_beta;             // swish param
-   lite::Tensor* Out{};
- };
-
- struct ActivationGradParam {
-   const lite::Tensor* X{};
-   const lite::Tensor* Out{};
-   // for backward
-   lite::Tensor* X_grad{};
-   const lite::Tensor* Out_grad{};
- };
  /// ----------------------- mean operators ----------------------
  struct MeanParam {
    const lite::Tensor* X{};
......