提交 ee7a6401 编写于 作者: V VectorSL

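GPU: update conv kernels for auto-mixed-precision (use float compute type in the cuDNN descriptors; enable tensor-op math and fixed algorithms for half-precision inputs)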

上级 eefb6edd
...@@ -218,7 +218,7 @@ class BinaryOpGpuKernel : public GpuKernel { ...@@ -218,7 +218,7 @@ class BinaryOpGpuKernel : public GpuKernel {
} }
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetOpTensorDescriptor(opTensor_descriptor_, tensor_op_, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN), cudnnSetOpTensorDescriptor(opTensor_descriptor_, tensor_op_, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
"cudnnSetOpTensorDescriptor failed"); "cudnnSetOpTensorDescriptor failed");
return; return;
} }
......
...@@ -142,10 +142,14 @@ class Conv2dGpuFwdKernel : public GpuKernel { ...@@ -142,10 +142,14 @@ class Conv2dGpuFwdKernel : public GpuKernel {
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_,
CUDNN_CROSS_CORRELATION, cudnn_data_type_), CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed"); "cudnnSetConvolution2dDescriptor failed");
input_descriptor_real = input_desc_; input_descriptor_real = input_desc_;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(input_descriptor_real); SelectAlgorithm(input_descriptor_real);
InitSizeLists(); InitSizeLists();
return true; return true;
...@@ -240,7 +244,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { ...@@ -240,7 +244,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_,
dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed"); "cudnnSetConvolution2dDescriptor failed");
} }
...@@ -276,6 +280,9 @@ class Conv2dGpuFwdKernel : public GpuKernel { ...@@ -276,6 +280,9 @@ class Conv2dGpuFwdKernel : public GpuKernel {
"cudnnGetConvolutionForwardAlgorithm_v7 failed"); "cudnnGetConvolutionForwardAlgorithm_v7 failed");
conv_algorithm_ = perf_results.algo; conv_algorithm_ = perf_results.algo;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
}
} }
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t input_desc_;
......
...@@ -141,10 +141,14 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { ...@@ -141,10 +141,14 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_,
CUDNN_CROSS_CORRELATION, cudnn_data_type_), CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"GetConvolution2dDescriptor failed"); "GetConvolution2dDescriptor failed");
x_desc_real = x_desc_; x_desc_real = x_desc_;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(x_desc_real); SelectAlgorithm(x_desc_real);
InitSizeLists(); InitSizeLists();
return true; return true;
...@@ -239,7 +243,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { ...@@ -239,7 +243,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_,
dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed"); "cudnnSetConvolution2dDescriptor failed");
} }
void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
...@@ -258,6 +262,9 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { ...@@ -258,6 +262,9 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
"GetConvolutionBackwardFilterAlgorithm failed"); "GetConvolutionBackwardFilterAlgorithm failed");
algo_ = perf_results.algo; algo_ = perf_results.algo;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
} }
void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) { void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value(); auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value();
......
...@@ -142,10 +142,14 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { ...@@ -142,10 +142,14 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_,
CUDNN_CROSS_CORRELATION, cudnn_data_type_), CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed"); "cudnnSetConvolution2dDescriptor failed");
dx_desc_real = dx_desc_; dx_desc_real = dx_desc_;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(dx_desc_real); SelectAlgorithm(dx_desc_real);
InitSizeLists(); InitSizeLists();
return true; return true;
...@@ -239,7 +243,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { ...@@ -239,7 +243,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_,
dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed"); "cudnnSetConvolution2dDescriptor failed");
} }
void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
...@@ -258,6 +262,9 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { ...@@ -258,6 +262,9 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
algo_ = perf_results.algo; algo_ = perf_results.algo;
} }
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
} }
void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) { void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value(); auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册