提交 7d7b8134 编写于 作者: V VectorSL

gpu fix conv bug

上级 3449abd7
......@@ -134,7 +134,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
cudnnTensorDescriptor_t x_desc_real = nullptr;
int padA[2];
int strideA[2] = {stride_[0], stride_[1]};
int dilaA[2] = {dilation_[0], dilation_[1]};
int dilaA[2] = {dilation_[2], dilation_[3]};
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
pad_height_ = pad_list[0] + pad_list[1];
pad_width_ = pad_list[2] + pad_list[3];
......
......@@ -135,7 +135,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
cudnnTensorDescriptor_t dx_desc_real = nullptr;
int padA[2];
int strideA[2] = {stride_[0], stride_[1]};
int dilaA[2] = {dilation_[0], dilation_[1]};
int dilaA[2] = {dilation_[2], dilation_[3]};
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
pad_height_ = pad_list[0] + pad_list[1];
pad_width_ = pad_list[2] + pad_list[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册