提交 00191223 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!677 GPU fix codex for conv2d

Merge pull request !677 from VectorSL/fix-codex-for-gpu-conv2d
......@@ -114,23 +114,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_;
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
}
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[2];
dilation_ = dilation_ori[2];
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t input_descriptor_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(in_shape, kernel_node);
......@@ -284,6 +268,24 @@ class Conv2dGpuFwdKernel : public GpuKernel {
conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
}
}
void SetStrideAndDilation(const CNodePtr &kernel_node) {
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
}
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "conv2d only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[2];
dilation_ = dilation_ori[2];
}
cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
......
......@@ -117,19 +117,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_;
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t x_desc_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(in_shape, kernel_node);
......@@ -288,7 +276,21 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
"SetTensor4dDescriptor failed");
}
void SetStrideAndDilation(const CNodePtr &kernel_node) {
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
}
cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t dw_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
......
......@@ -118,19 +118,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
pad_height_ = GetAttr<int>(kernel_node, "pad");
pad_width_ = pad_height_;
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t dx_desc_real = nullptr;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad(input_shape, kernel_node);
......@@ -286,6 +274,21 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
input_shape[2], input_shape[3]),
"SetTensor4dDescriptor failed");
}
void SetStrideAndDilation(const CNodePtr &kernel_node) {
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal stride, and stride must be 2d!";
}
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal dilation, and dilation must be 4d!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!";
}
stride_ = stride_ori[0];
dilation_ = dilation_ori[2];
}
cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t w_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册