提交 74912c7d 编写于 作者: C chengduoZH

fix data layout

上级 c5330fa7
......@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
// N, M, H, W
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// N, C, O_h, O_w
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// M, C, K_h, K_w
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
......@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
// Input: (N, M, H, W)
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_H, O_W)
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter (M, C, K_H, K_W)
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -63,9 +60,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} while (false)
enum class DataLayout {
enum class DataLayout { // Not use
kNHWC,
kNCHW,
kNCDHW,
kNCHW_VECT_C,
};
......@@ -107,12 +105,15 @@ class CudnnDataType<double> {
}
};
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
inline cudnnTensorFormat_t GetCudnnTensorFormat(
const DataLayout& order) { // Not use
switch (order) {
case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
......@@ -139,7 +140,7 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// FIXME(typhoonzero): Assume using NCHW order
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册