提交 3aa331d9 编写于 作者: L liym27 提交者: Guo Sheng

fix conv2d and conv3d: (#20042)

1.support asymmetric padding;
    2.support padding algorithm:"SAME" and "VALID";
    3.support channel_last: data_format NHWC and NDHWC;
    4.change doc of python API and c++;

    test=develop, test=document_preview
上级 02c6edc0
......@@ -9,6 +9,7 @@ function(op_library TARGET)
set(miopen_hip_cc_srcs)
set(cu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(cudnn_cu_srcs)
set(CUDNN_FILE)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
......@@ -44,6 +45,9 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu)
list(APPEND cudnn_cu_srcs ${CUDNN_FILE}.cu)
endif()
if(WITH_AMD_GPU)
string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
......@@ -60,6 +64,8 @@ function(op_library TARGET)
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.hip.cu$")
list(APPEND hip_cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu$")
list(APPEND cudnn_cu_srcs ${src})
elseif (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
......@@ -97,7 +103,7 @@ function(op_library TARGET)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif()
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
......@@ -160,6 +166,12 @@ function(op_library TARGET)
endif()
endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len)
if (WITH_GPU AND ${cudnn_cu_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN
if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
......
......@@ -140,8 +140,8 @@ paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=N
paddle.fluid.layers.square_error_cost (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'bbb9e708bab250359864fefbdf48e9d9'))
paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types', 'seq_length'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'b02844e0ad4bd713c5fe6802aa13219c'))
paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'padding_start', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, True, None, None, None, None, None)), ('document', '2bf23e7884c380c3b27f2709aa322cb9'))
paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '06de9adb5994f6f8cb806c75b55550af'))
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '71b09227709475fa178c1739dff64af6'))
paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None, 'NCHW')), ('document', 'b8da17862ba02b5297a37d2edd571d76'))
paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '73a15322d460ef9aa90d4d237b0bc5d5'))
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82'))
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'eaa9d0bbd3d4e017c8bc4ecdac483711'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5'))
......
......@@ -9,11 +9,14 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the spopecific language governing permissions and
limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
......@@ -56,18 +59,91 @@ static inline void GetNCDHW(const framework::DDim& dims,
}
}
static inline bool IsSymmetricPadding(const std::vector<int>& paddings,
const int data_dim) {
bool is_sys_pad = true;
if (paddings.size() == data_dim * 2) {
for (size_t i = 0; i < data_dim; ++i) {
if (paddings[2 * i] != paddings[2 * i + 1]) {
is_sys_pad = false;
return is_sys_pad;
}
}
}
return is_sys_pad;
}
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T, size_t D>
static void PadFunction(const framework::ExecutionContext& context,
const std::vector<int>& pads,
const framework::Tensor& src, T pad_value,
framework::Tensor* out) {
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = pads[i * 2];
paddings[i].second = pads[i * 2 + 1];
}
auto src_tensor = EigenTensor<T, D>::From(src);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_tensor.device(place) = src_tensor.pad(paddings, pad_value);
}
template <typename DeviceContext, typename T, size_t D>
static void Slice_2(const framework::ExecutionContext& context,
const Tensor* input, Tensor* out,
const std::vector<int>& starts,
const std::vector<int>& axes) {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto in_dims = input->dims();
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = new_out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
offsets[axes[i]] = start;
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, new_out_dims);
out_t.device(place) = in_t.slice(offsets, extents);
}
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
const Tensor* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
output->mutable_data<T>(ctx.GetPlace());
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
......@@ -79,23 +155,121 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// ------------ transformed tensor -----------
Tensor transformed_input_channel(input->type());
Tensor transformed_output(output->type());
T* output_data = nullptr;
if (channel_last) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, output,
&transformed_output);
} else {
transformed_input_channel = *input;
transformed_output = *output;
}
output_data = transformed_output.data<T>();
// update padding and dilation
auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
if (!is_sys_pad) {
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
0);
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_input =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
const int rank = transformed_input_channel.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
}
} else {
transformed_input = transformed_input_channel;
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* input_data = input->data<T>();
const T* input_data = transformed_input.data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ConvArgs args{input, filter, output, strides, paddings, dilations};
ConvArgs args{&transformed_input, filter, &transformed_output, strides,
padding_common, dilations};
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto dtype = platform::CudnnDataType<T>::type;
DataLayout layout = DataLayout::kNCHW;
if (input->dims().size() == 5) {
if (transformed_input_channel.dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
auto layout_format = GetCudnnTensorFormat(layout);
args.handle = handle;
args.cdesc.set(dtype, paddings, strides, dilations);
args.cdesc.set(dtype, padding_common, strides, dilations);
#if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it manually
// FIXME(typhoonzero): find a better way to disable groups
......@@ -104,13 +278,17 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
args.cdesc.desc(), groups));
groups = 1;
#endif
args.idesc.set(*input, groups);
args.idesc.set(transformed_input, groups);
args.wdesc.set(*filter, layout_format, groups);
args.odesc.set(*output, groups);
args.odesc.set(transformed_output, groups);
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(output->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);
GetNCDHW(transformed_output.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d,
&o_h, &o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
......@@ -138,6 +316,11 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
},
workspace_size);
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_output, output);
}
}
};
......@@ -146,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
......@@ -154,13 +337,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const T* input_data = input->data<T>();
const T* output_grad_data = output_grad->data<T>();
const T* filter_data = filter->data<T>();
if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace());
}
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
......@@ -170,14 +358,141 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
"Can't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
Tensor transformed_input_channel(input->type());
Tensor transformed_output_grad_channel(output_grad->type());
Tensor transformed_input_grad_channel(input->type());
if (channel_last) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, output_grad, &transformed_output_grad_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, output_grad, &transformed_output_grad_channel);
if (input_grad) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input_grad, &transformed_input_grad_channel);
}
} else {
transformed_input_channel = *input;
transformed_output_grad_channel = *output_grad;
if (input_grad) {
transformed_input_grad_channel.ShareDataWith(*input_grad);
}
}
// update paddings
auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
// cuDNN only supports padding the same amount on every dimension.
// So we create a new padded input tensor.
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input(input->type());
Tensor transformed_input_grad(input->type());
std::vector<int> padding_common(data_dim, 0);
std::vector<int> input_pad(transformed_input_channel.dims().size() * 2, 0);
if (!is_sys_pad) {
// get pad
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
transformed_input_grad.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_input =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
if (input_grad) {
transformed_input_grad =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
}
// pad for input
const int rank = transformed_input_channel.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
}
} else {
transformed_input.ShareDataWith(transformed_input_channel);
if (input_grad) {
transformed_input_grad.ShareDataWith(transformed_input_grad_channel);
}
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* input_data = transformed_input.data<T>();
const T* output_grad_data = transformed_output_grad_channel.data<T>();
T* filter_grad_data = nullptr;
T* input_grad_data = nullptr;
ConvArgs args1{input_grad, filter, output_grad,
strides, paddings, dilations};
ConvArgs args2{input, filter_grad, output_grad,
strides, paddings, dilations};
// conv_cudnn_helper.h
T* transformed_input_grad_data = nullptr;
ConvArgs args1{&transformed_input_grad,
filter,
&transformed_output_grad_channel,
strides,
padding_common,
dilations};
ConvArgs args2{&transformed_input,
filter_grad,
&transformed_output_grad_channel,
strides,
padding_common,
dilations};
auto handle = dev_ctx.cudnn_handle();
auto dtype = platform::CudnnDataType<T>::type;
DataLayout layout = DataLayout::kNCHW;
......@@ -188,10 +503,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(output_grad->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h,
&o_w);
GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNCHW, &o_n,
&o_c, &o_d, &o_h, &o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
......@@ -212,12 +528,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (input_grad) {
// ------------------- cudnn descriptors ---------------------
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
input_grad_data = input_grad->data<T>();
transformed_input_grad_data = transformed_input_grad.data<T>();
args1.handle = handle;
args1.idesc.set(*input_grad, iwo_groups);
args1.idesc.set(transformed_input_grad, iwo_groups);
args1.wdesc.set(*filter, layout_tensor, iwo_groups);
args1.odesc.set(*output_grad, iwo_groups);
args1.cdesc.set(dtype, paddings, strides, dilations, c_groups);
args1.odesc.set(transformed_output_grad_channel, iwo_groups);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
......@@ -228,12 +545,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) {
// ------------------- cudnn descriptors ---------------------
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
filter_grad_data = filter_grad->data<T>();
args2.handle = handle;
args2.idesc.set(*input, iwo_groups);
args2.idesc.set(transformed_input, iwo_groups);
args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
args2.odesc.set(*output_grad, iwo_groups);
args2.cdesc.set(dtype, paddings, strides, dilations, c_groups);
args2.odesc.set(transformed_output_grad_channel, iwo_groups);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
......@@ -254,10 +571,35 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_data + i * group_offset_filter, args1.odesc.desc(),
output_grad_data + i * group_offset_out, args1.cdesc.desc(),
data_algo, cudnn_workspace_ptr, workspace_size, &beta,
args1.idesc.desc(), input_grad_data + i * group_offset_in));
args1.idesc.desc(),
transformed_input_grad_data + i * group_offset_in));
},
workspace_size);
}
std::vector<int> starts(transformed_input_channel.dims().size(), 0);
std::vector<int> axes(transformed_input_channel.dims().size(), 0);
for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) {
starts[i] = input_pad[2 * i];
axes[i] = i;
}
transformed_input_grad_channel.mutable_data(ctx.GetPlace());
if (transformed_input_channel.dims().size() == 4) {
Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
ctx, &transformed_input_grad, &transformed_input_grad_channel,
starts, axes);
} else {
Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_input_grad, &transformed_input_grad_channel,
starts, axes);
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_input_grad_channel, input_grad);
}
}
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
......@@ -291,7 +633,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
auto X = ctx.Input<Tensor>("Input");
auto W = ctx.Input<Tensor>("Filter");
......@@ -302,8 +644,17 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto ddO = ctx.Output<Tensor>("DDOutput");
auto dW = ctx.Output<Tensor>("DFilter");
auto dX = ctx.Output<Tensor>("DInput");
if (ddO) {
ddO->mutable_data<T>(ctx.GetPlace());
}
if (dW) {
dW->mutable_data<T>(ctx.GetPlace());
}
if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
}
const T* x = X->data<T>();
// const T* x = X->data<T>();
const T* dy = dO->data<T>();
const T* w = W->data<T>();
......@@ -311,10 +662,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
const T* ddw = nullptr;
T *dw, *dx, *ddy;
dw = dx = ddy = nullptr;
T* transformed_dx = nullptr;
const std::vector<int>& strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int>& paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int>& dilations = ctx.Attr<std::vector<int>>("dilations");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
......@@ -324,6 +674,154 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
"Can't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensors to channel first-----------
Tensor transformed_X_channel(X->type());
Tensor transformed_dO_channel(dO->type());
Tensor transformed_ddX_channel(ddX->type());
Tensor transformed_ddO_channel(dO->type());
Tensor transformed_dX_channel(X->type());
if (channel_last) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, X, &transformed_X_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, X, &transformed_X_channel);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, dO, &transformed_dO_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, dO, &transformed_dO_channel);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, ddX, &transformed_ddX_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, ddX, &transformed_ddX_channel);
if (ddO) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, ddO, &transformed_ddO_channel);
}
if (dX) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, dX, &transformed_dX_channel);
transformed_dX_channel.mutable_data<T>(ctx.GetPlace());
}
} else {
transformed_X_channel = *X;
transformed_dO_channel = *dO;
transformed_ddX_channel = *ddX;
if (ddO) {
transformed_ddO_channel.ShareDataWith(*ddO);
}
if (dX) {
transformed_dX_channel.ShareDataWith(*dX);
}
}
auto in_dims = transformed_X_channel.dims();
auto filter_dims = W->dims();
framework::DDim in_data_dims =
framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = IsSymmetricPadding(paddings, data_dim);
Tensor transformed_X(X->type());
Tensor transformed_ddX(X->type());
Tensor transformed_dX(X->type());
std::vector<int> padding_common(data_dim, 0);
std::vector<int> input_pad(X->dims().size() * 2, 0);
if (!is_sys_pad) {
// get pad
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_X_channel.dims()[0];
new_input_shape_vec[1] = transformed_X_channel.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_X_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec));
transformed_X.Resize(new_input_shape);
transformed_ddX.Resize(new_input_shape);
transformed_dX.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_X =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
transformed_ddX =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
if (dX) {
transformed_dX =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
}
// pad for input
const int rank = X->dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
} break;
case 5: {
PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
} break;
default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
}
} else {
transformed_X.ShareDataWith(transformed_X_channel);
transformed_ddX.ShareDataWith(transformed_ddX_channel);
if (dX) {
transformed_dX.ShareDataWith(transformed_dX_channel);
}
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* x = transformed_X.data<T>();
int iwo_group = groups;
int c_group = 1;
......@@ -335,10 +833,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle();
ConvArgs args1{ddX, W, ddO, strides, paddings, dilations};
ConvArgs args2{X, ddW, ddO, strides, paddings, dilations};
ConvArgs args3{ddX, dW, dO, strides, paddings, dilations};
ConvArgs args4{dX, ddW, dO, strides, paddings, dilations};
ConvArgs args1{&transformed_ddX, W,
&transformed_ddO_channel, strides,
padding_common, dilations};
ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides,
padding_common, dilations};
ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides,
padding_common, dilations};
ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides,
padding_common, dilations};
cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
......@@ -353,14 +856,17 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
// ddo = conv(ddI, W) + conv(I, ddW)
size_t workspace_size = 0;
T* transformed_ddy_channel = nullptr;
if (ddO) {
ddy = ddO->mutable_data<T>(ctx.GetPlace());
ddy = ddO->data<T>();
transformed_ddy_channel = transformed_ddO_channel.data<T>();
if (ddX) {
args1.handle = handle;
args1.idesc.set(*ddX, iwo_group);
args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(*ddO, iwo_group);
args1.cdesc.set(dtype, paddings, strides, dilations, c_group);
args1.odesc.set(transformed_ddO_channel, iwo_group);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
......@@ -370,10 +876,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
if (ddW) {
ddw = ddW->data<T>();
args2.handle = handle;
args2.idesc.set(*X, iwo_group);
args2.idesc.set(transformed_X, iwo_group);
args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(*ddO, iwo_group);
args2.cdesc.set(dtype, paddings, strides, dilations, c_group);
args2.odesc.set(transformed_ddO_channel, iwo_group);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
......@@ -383,12 +891,14 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
}
if (dW && ddX) {
dw = dW->mutable_data<T>(ctx.GetPlace());
dw = dW->data<T>();
args3.handle = handle;
args3.idesc.set(*ddX, iwo_group);
args3.idesc.set(transformed_ddX, iwo_group);
args3.wdesc.set(*dW, layout, iwo_group);
args3.odesc.set(*dO, iwo_group);
args3.cdesc.set(dtype, paddings, strides, dilations, c_group);
args3.odesc.set(transformed_dO_channel, iwo_group);
args3.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
......@@ -398,12 +908,13 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
}
if (ddW && dX) {
dx = dX->mutable_data<T>(ctx.GetPlace());
transformed_dx = transformed_dX.data<T>();
args4.handle = handle;
args4.idesc.set(*dX, iwo_group);
args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(*dO, iwo_group);
args4.cdesc.set(dtype, paddings, strides, dilations, c_group);
args4.odesc.set(transformed_dO_channel, iwo_group);
args4.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
......@@ -413,9 +924,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(X->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
GetNCDHW(transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h,
&i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(dO->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);
GetNCDHW(transformed_dO_channel.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d,
&o_h, &o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
......@@ -426,7 +940,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
if (ddO) {
if (ddX) {
ddx = ddX->data<T>();
ddx = transformed_ddX.data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
......@@ -435,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
ddx + i * group_offset_in, args1.wdesc.desc(),
w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1,
workspace_ptr, workspace_size, &beta, args1.odesc.desc(),
ddy + i * group_offset_out));
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
......@@ -449,21 +963,27 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.wdesc.desc(), ddw + i * group_offset_filter,
args2.cdesc.desc(), fwd_algo2, workspace_ptr,
workspace_size, &alpha, args2.odesc.desc(),
ddy + i * group_offset_out));
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_ddO_channel, ddO);
}
}
T* transformed_dy_channel = nullptr;
if (dW && ddX) {
ddx = ddX->data<T>();
ddx = transformed_ddX.data<T>();
transformed_dy_channel = transformed_dO_channel.data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in,
args3.odesc.desc(), dy + i * group_offset_out,
args3.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args3.cdesc.desc(), filter_algo, workspace_ptr,
workspace_size, &beta, args3.wdesc.desc(),
dw + i * group_offset_filter));
......@@ -480,12 +1000,33 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, args4.wdesc.desc(),
ddw + i * group_offset_filter, args4.odesc.desc(),
dy + i * group_offset_out, args4.cdesc.desc(), data_algo,
workspace_ptr, workspace_size, &beta, args4.idesc.desc(),
dx + i * group_offset_in));
transformed_dy_channel + i * group_offset_out,
args4.cdesc.desc(), data_algo, workspace_ptr, workspace_size,
&beta, args4.idesc.desc(),
transformed_dx + i * group_offset_in));
},
workspace_size);
}
// reverse padded input
std::vector<int> starts(X->dims().size(), 0);
std::vector<int> axes(X->dims().size(), 0);
for (size_t i = 0; i < X->dims().size(); ++i) {
starts[i] = input_pad[2 * i];
axes[i] = i;
}
if (X->dims().size() == 4) {
Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
} else {
Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_dX_channel, dX);
}
}
}
};
......@@ -514,8 +1055,7 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
paddle::operators::CUDNNConvGradOpKernel<double>);
REGISTER_OP_KERNEL(
conv3d_grad_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
......
......@@ -31,11 +31,11 @@ namespace paddle {
namespace operators {
void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of ConvOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
......@@ -43,41 +43,64 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
const std::string data_format = ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
"Conv intput should be 4-D or 5-D tensor, get %u",
in_dims.size());
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same.");
PADDLE_ENFORCE(
in_dims.size() - strides.size() == 2U,
"Conv input dimension and strides dimension should be consistent.");
PADDLE_ENFORCE_EQ(
paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension should be the same.");
in_dims.size() - strides.size() == 2U, true,
"Conv input dimension and strides dimension should be consistent.");
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
framework::DDim in_data_dims;
if (channel_last) {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (!channel_last) {
output_shape.push_back(filter_dims[0]);
}
for (size_t i = 0; i < in_data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
output_shape.push_back(ConvOutputSize(in_data_dims[i], filter_dims[i + 2],
dilations[i], paddings[2 * i],
paddings[2 * i + 1], strides[i]));
}
}
if (channel_last) {
output_shape.push_back(filter_dims[0]);
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
}
......@@ -89,7 +112,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
framework::LibraryType library{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto input_data_type = ctx.Input<Tensor>("Input")->type();
std::string data_format = ctx.Attr<std::string>("data_format");
std::string data_format =
"AnyLayout"; // todo enable data layout when it's ready
framework::DataLayout layout = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
......@@ -142,10 +166,10 @@ void Conv2DOpMaker::Make() {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput(
"Input",
AddInput("Input",
"(Tensor) The input tensor of convolution operator. "
"The format of input tensor is NCHW, where N is batch size, C is the "
"The format of input tensor is NCHW or NHWC, where N is batch size, "
"C is the "
"number of channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput("Filter",
......@@ -167,7 +191,7 @@ void Conv2DOpMaker::Make() {
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
"It has same data fromat and data type as the Input.");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......@@ -175,9 +199,16 @@ void Conv2DOpMaker::Make() {
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0, 0}), the "
"paddings(h_pad, w_pad) of "
"paddings(pad_height_top, pad_height_bottom, "
"pad_width_left, pad_wifth_right) of "
"convolution operator.")
.SetDefault({0, 0});
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
......@@ -254,7 +285,7 @@ void Conv2DOpMaker::Make() {
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
.SetDefault("NCHW");
// TODO(dzhwinter): need to registered layout transform function
AddAttr<int>("workspace_size_MB",
"Only used in cudnn kernel. Need set use_cudnn to true."
......@@ -269,13 +300,14 @@ void Conv2DOpMaker::Make() {
"convolution, whether enable exhaustive search "
"for cuDNN convolution or not, default is False.")
.SetDefault(false);
AddComment(R"DOC(
Convolution Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and Output(Output) are in NCHW format. Where N is batch
Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature.
Filters(Input) is MCHW format. Where M is the number of output image channels, C is
......@@ -293,8 +325,8 @@ Example:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
Apply();
......@@ -308,7 +340,8 @@ void Conv3DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is the "
"The format of input tensor is NCDHW or NDHWC. Where N is batch size, C "
"is the "
"number of channels, D is the depth of the feature, H is the height of "
"the feature, "
"and W is the width of the feature.");
......@@ -327,17 +360,25 @@ void Conv3DOpMaker::Make() {
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.");
"It has same data fromat and data type as the Input.");
AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("paddings",
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default:{0, 0, 0}), the "
"paddings(d_pad, h_pad, w_pad) of convolution "
"paddings(pad_depth_front, pad_depth_back, pad_height_top, "
"pad_height_bottom, pad_width_left, pad_width_right) of convolution "
"operator.")
.SetDefault({0, 0, 0});
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
......@@ -375,11 +416,11 @@ void Conv3DOpMaker::Make() {
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"(string, default NCDHW) Only used in "
"An optional string from: \"NDHWC\", \"NCDHW\". "
"Defaults to \"NDHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
.SetDefault("NCDHW");
AddAttr<bool>("force_fp32_output",
"(bool, default false) Only used in mkldnn INT8 kernel")
.SetDefault(false);
......@@ -402,7 +443,7 @@ Convolution3D Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW format, where N is batch
Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels,D is the depth of the feature, H is the height of
the feature, and W is the width of the feature.
Filters(Input) is MCDHW format, where M is the number of output image channels,
......@@ -420,9 +461,9 @@ Example:
Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
Where
$$
D_{out}= \frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\
H_{out}= \frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\
W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
D_{out}= \frac{(D_{in} + pad_depth_front + pad_depth_back - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\
H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\
W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
$$
)DOC");
Apply();
......@@ -445,7 +486,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = ctx.Attr<std::string>("data_format");
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
......@@ -623,7 +664,7 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -39,8 +40,8 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
PADDLE_ENFORCE(
output_size > 0,
PADDLE_ENFORCE_GT(
output_size, 0,
"Due to the settings of padding(%d), filter_size(%d), dilation(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
......@@ -48,6 +49,62 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
return output_size;
}
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding_1, int padding_2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(output_size, 0,
"Due to the settings of padding(%d, %d), filter_size(%d), "
"dilation(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding_1, padding_2, filter_size, dilation, stride,
input_size);
return output_size;
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilation,
const std::string padding_algorithm,
const framework::DDim data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
// set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims);
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
PADDLE_ENFORCE_EQ(
data_dims.size() * 2, paddings->size(),
"Paddings size should be the same or twice as the input data size.");
}
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[0];
int pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilation->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
......@@ -59,9 +116,80 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
if (paddings.size() != strides.size()) {
for (size_t j = 0; j < paddings.size(); ++j) {
padding_0 = padding_0 && (paddings[j] == 0);
}
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 3, 1, 2};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 1};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
}
}
// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -131,39 +259,82 @@ class GemmConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
const int groups = context.Attr<int>("groups");
const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const std::string data_format = context.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output,
&transformed_output);
} else {
transformed_input = *input;
transformed_output = *output;
}
// update padding and dilation
auto trans_in_dims = transformed_input.dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims =
framework::slice_ddim(trans_in_dims, 2, trans_in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
auto& dev_ctx = context.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int batch_size = static_cast<int>(transformed_input.dims()[0]);
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
// filter_shape_vec:
// {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
// output_shape_vec:
// {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(
framework::vectorize(transformed_output.dims()));
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
// col_shape_vec:
// {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w,
// o_d,o_h, o_w}
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
col_shape_vec[0] = trans_in_dims[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
// o_h * o_w)
// size:
// (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * o_h *
// o_w)
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
framework::flatten_to_2d(col_shape, data_dim);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
......@@ -175,28 +346,31 @@ class GemmConvKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape =
framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim in_matrix_shape = framework::slice_ddim(
transformed_input.dims(), 1, transformed_input.dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
transformed_output.dims()[1],
transformed_output.numel() /
(transformed_output.dims()[0] * transformed_output.dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch =
transformed_input.Slice(i, i + 1).Resize(in_matrix_shape);
Tensor out_batch =
transformed_output.Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
......@@ -206,13 +380,12 @@ class GemmConvKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
}
......@@ -223,6 +396,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
T(0.0));
}
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_output,
output);
}
}
};
......@@ -245,11 +422,44 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (!input_grad && !filter_grad) return;
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const std::string data_format = context.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
Tensor transformed_input(input->type());
Tensor transformed_output_grad(output_grad->type());
const int batch_size = static_cast<int>(input->dims()[0]);
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
TransToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
} else {
transformed_input = *input;
transformed_output_grad = *output_grad;
}
// update padding and dilation
auto in_dims = transformed_input.dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims =
framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(transformed_input.dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
......@@ -257,14 +467,14 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(
framework::vectorize(output_grad->dims()));
framework::vectorize(transformed_output_grad.dims()));
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
col_shape_vec[0] = transformed_input.dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
......@@ -278,24 +488,25 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
framework::DDim input_shape =
framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim input_shape = framework::slice_ddim(
transformed_input.dims(), 1, transformed_input.dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output_grad->dims()[1],
output_grad->numel() /
(output_grad->dims()[0] * output_grad->dims()[1])};
transformed_output_grad.dims()[1],
transformed_output_grad.numel() / (transformed_output_grad.dims()[0] *
transformed_output_grad.dims()[1])};
// convolution backward input operator: gemm + col2im(or col2vol)
// convolution backward weight operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_output_grad.dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
......@@ -312,19 +523,27 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
Tensor transformed_input_grad(input_grad->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
&transformed_input_grad);
} else {
transformed_input_grad = *input_grad;
}
// if is_expand is false, the operation of set_zero is unnecessary,
// because math::matmul will reset input_grad.
if (is_expand) {
set_zero(dev_ctx, input_grad, static_cast<T>(0));
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
}
math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch =
transformed_input_grad.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
......@@ -343,14 +562,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&in_grad_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice);
}
}
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
input_grad);
}
}
if (filter_grad) {
......@@ -362,8 +585,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = transformed_input.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
......@@ -376,9 +599,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
}
......@@ -412,21 +636,60 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
"Cannot find input Filter(%s) in scope)",
ctx.Inputs("Filter")[0]);
if (!ddY && !dW && !dX) return;
int groups = ctx.Attr<int>("groups");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const int groups = ctx.Attr<int>("groups");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
Tensor transformed_X(X->type());
Tensor transformed_dY(dY->type());
Tensor transformed_ddX(ddX->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
TransToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
ResizeToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
TransToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
} else {
transformed_X = *X;
transformed_dY = *dY;
transformed_ddX = *ddX;
}
// update padding and dilation
auto in_dims = transformed_X.dims();
auto filter_dims = W.dims();
framework::DDim in_data_dims =
framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(X->dims()[0]);
const int batch_size = static_cast<int>(transformed_X.dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(W.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(dY->dims()));
std::vector<int64_t> output_shape_vec(
framework::vectorize(transformed_dY.dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
// col_shape [in_channel/group, kh, kw, oh, ow]
col_shape_vec[0] = X->dims()[1] / groups;
col_shape_vec[0] = transformed_X.dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
......@@ -436,17 +699,19 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
// input_shape [Cin, H, W]
framework::DDim input_shape =
framework::slice_ddim(X->dims(), 1, X->dims().size());
framework::DDim input_shape = framework::slice_ddim(
transformed_X.dims(), 1, transformed_X.dims().size());
// filter_matrix_shape [Cout, Cin * kh * kw]
framework::DDim filter_matrix_shape = {W.dims()[0],
W.numel() / W.dims()[0]};
W.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
dY->dims()[1], dY->numel() / (dY->dims()[0] * dY->dims()[1])};
int in_step = static_cast<int>(X->dims()[1]) / groups;
int out_step = static_cast<int>(dY->dims()[1]) / groups;
transformed_dY.dims()[1],
transformed_dY.numel() /
(transformed_dY.dims()[0] * transformed_dY.dims()[1])};
int in_step = static_cast<int>(transformed_X.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_dY.dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
......@@ -466,19 +731,28 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
if (dX && ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
dX->mutable_data<T>(ctx.GetPlace());
Tensor transformed_dX(dX->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(ctx, dX, &transformed_dX);
} else {
transformed_dX = *dX;
}
// if is_expand is false, the operation of set_zero is unnecessary
// because math::matmul will reset dx
if (is_expand) {
set_zero(dev_ctx, dX, static_cast<T>(0));
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
}
math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
for (int i = 0; i < batch_size; i++) {
Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor dx_batch = dX->Slice(i, i + 1).Resize(input_shape);
Tensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
......@@ -493,14 +767,17 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&dx_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
}
}
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(ctx, &transformed_dX, dX);
}
}
// dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
......@@ -514,8 +791,9 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
Tensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddx_batch = transformed_ddX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; ++g) {
// im2col
Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
......@@ -526,8 +804,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx, ddx_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
......@@ -545,55 +823,62 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
// ddy convolution double grad: im2col(vol2col) + gemm
if (ddY) {
ddY->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, ddY, static_cast<T>(0));
Tensor transformed_ddY(ddY->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(ctx, ddY, &transformed_ddY);
} else {
transformed_ddY = *ddY;
}
set_zero(dev_ctx, &transformed_ddY, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddy_batch =
transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; ++g) {
// gemm
Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
if (ddX) {
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
Tensor ddx_batch =
transformed_ddX.Slice(i, i + 1).Resize(input_shape);
Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(dev_ctx, ddx_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
}
// gemm
Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
T(0.0));
}
if (ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape);
Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
if (!is_expand) {
col.ShareDataWith(x_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(dev_ctx, x_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
}
......@@ -604,6 +889,9 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
}
}
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(ctx, &transformed_ddY, ddY);
}
}
}
};
......@@ -617,22 +905,76 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(
output->dims()[1] % input->dims()[1], 0,
"The output channels must be a multiple of the input channels");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const std::string data_format = context.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
if (channel_last) {
PADDLE_ENFORCE_EQ(
output->dims()[output->dims().size() - 1] %
input->dims()[input->dims().size() - 1],
0, "The output channels must be a multiple of the input channels");
} else {
PADDLE_ENFORCE_EQ(
output->dims()[1] % input->dims()[1], 0,
"The output channels must be a multiple of the input channels");
}
// transform tensor
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output,
&transformed_output);
} else {
transformed_input = *input;
transformed_output = *output;
}
// update padding and dilation
auto in_dims = transformed_input.dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true;
if (!is_sys_pad) {
for (size_t i = 0; i < strides.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
auto& dev_ctx = context.template device_context<DeviceContext>();
if (fuse_relu) {
math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
output);
depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
dilations, &transformed_output);
} else {
math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
dilations, &transformed_output);
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_output,
output);
}
}
......@@ -657,24 +999,81 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
const std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const std::string data_format = context.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
Tensor transformed_input(input->type());
Tensor transformed_output_grad(output_grad->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
TransToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
} else {
transformed_input = *input;
transformed_output_grad = *output_grad;
}
// update padding and dilation
auto in_dims = transformed_input.dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true;
if (!is_sys_pad) {
for (size_t i = 0; i < strides.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
Tensor transformed_input_grad(input_grad->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
&transformed_input_grad);
} else {
transformed_input_grad = *input_grad;
}
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
if (fuse_relu) {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
paddings, dilations, input_grad);
depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
transformed_output_grad, strides, paddings,
dilations, &transformed_input_grad);
} else {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
paddings, dilations, input_grad);
depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
transformed_output_grad, strides, paddings,
dilations, &transformed_input_grad);
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
input_grad);
}
}
......@@ -684,13 +1083,15 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
if (fuse_relu) {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
paddings, dilations, filter_grad);
depthwiseConvFilterGrad(dev_ctx, transformed_input,
transformed_output_grad, strides, paddings,
dilations, filter_grad);
} else {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
paddings, dilations, filter_grad);
depthwiseConvFilterGrad(dev_ctx, transformed_input,
transformed_output_grad, strides, paddings,
dilations, filter_grad);
}
}
}
......
......@@ -33,15 +33,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col->dims().size() == 5);
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
dilation[1] == 1) {
if (padding[0] == 0 && padding[1] == 0) {
if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
padding[3] == 0) {
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
return;
} else if (padding[0] == 1 && padding[1] == 1) {
} else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
padding[3] == 1) {
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
return;
}
......@@ -65,8 +68,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......@@ -136,8 +140,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col->dims().size() == 5);
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
......@@ -198,8 +203,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......
......@@ -34,9 +34,10 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col->dims().size() == 7);
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
......@@ -50,28 +51,35 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
// changed
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"mismatching.");
const T* vol_data = vol.data<T>();
T* col_data = col->data<T>();
......@@ -81,11 +89,11 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
......@@ -120,9 +128,10 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
......@@ -136,21 +145,29 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
......@@ -166,11 +183,11 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
......
......@@ -92,27 +92,34 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
int output_height = col->dims()[5];
int output_width = col->dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
"mismatching.");
int num_outputs =
input_channels * output_depth * output_height * output_width;
......@@ -122,9 +129,8 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
vol2col<T><<<blocks, threads, 0, context.stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], paddings[0],
paddings[1], paddings[2], output_depth, output_height, output_width,
col->data<T>());
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, col->data<T>());
}
};
......@@ -218,27 +224,35 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
int output_height = col.dims()[5];
int output_width = col.dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
"mismatching.");
int num_kernels = input_channels * input_depth * input_height * input_width;
......@@ -248,9 +262,8 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
col2vol<T><<<blocks, threads, 0, context.stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], paddings[0],
paddings[1], paddings[2], output_depth, output_height, output_width,
vol->data<T>());
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, vol->data<T>());
}
};
......
......@@ -2259,11 +2259,12 @@ def conv2d(input,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
name=None,
data_format="NCHW"):
"""
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
Output are in NCHW or NHWC format, where N is batch size, C is the number of
channels, H is the height of the feature, and W is the width of the feature.
Filter is in MCHW format, where M is the number of output image channels,
C is the number of input image channels, H is the height of the filter,
......@@ -2284,7 +2285,7 @@ def conv2d(input,
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`X`: Input value, a tensor with NCHW or NHWC format.
* :math:`W`: Filter value, a tensor with MCHW format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
......@@ -2314,7 +2315,7 @@ def conv2d(input,
padding mode is 'SAME' and 'VALID' can reference this link<https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleGAN/network/base_network.py#L181>`_
Args:
input (Variable): The input image with [N, C, H, W] format.
input (Variable): The input image with [N, C, H, W] or [N, H, W, C] format.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple): The filter size. If filter_size
......@@ -2324,9 +2325,14 @@ def conv2d(input,
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_height, stride_width). Otherwise,
stride_height = stride_width = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_height, padding_width). Otherwise,
padding_height = padding_width = padding. Default: padding = 0.
padding (string|int|list|tuple): The padding size. If `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
it could be in three forms: `[pad_height, pad_width]` or
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`,
`padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
when `data_format` is `"NHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_height, dilation_width). Otherwise,
dilation_height = dilation_width = dilation. Default: dilation = 1.
......@@ -2350,7 +2356,10 @@ def conv2d(input,
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
will be named automatically. Default: None.
data_format (str): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -2368,8 +2377,23 @@ def conv2d(input,
conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
"""
num_channels = input.shape[1]
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s. " % str(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
channel_last = (data_format == "NHWC")
num_channels = input.shape[3] if channel_last else input.shape[1]
if num_channels < 0:
raise ValueError(
"The channel dimmention of the input(%s) should be defined. "
"Received: %s." % (str(input.shape), str(num_channels)))
assert param_attr is not False, "param_attr should not be False here."
l_type = 'conv2d'
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
......@@ -2382,18 +2406,61 @@ def conv2d(input,
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
raise ValueError(
"The number of input channels must be divisible by Attr(groups). "
"Received: number of channels(%s), groups(%s)." %
(str(num_channels), str(groups)))
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
# padding
def _update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 4:
if is_list_or_tuple(padding[0]) and (data_format == "NCHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:4]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"):
if not (padding[0] == [0, 0] and padding[3] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:3]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding')
else:
padding = utils.convert_to_list(padding, 2, 'padding')
padding = [padding[0], padding[0], padding[1], padding[1]]
return padding
padding_algorithm = "EXPLICIT"
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." %
str(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0, 0, 0, 0]
elif padding == "SAME":
padding_algorithm = "SAME"
padding = [0, 0, 0, 0]
padding = _update_padding(padding, data_format)
input_shape = input.shape
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
......@@ -2423,7 +2490,9 @@ def conv2d(input,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......@@ -2442,13 +2511,14 @@ def conv3d(input,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
name=None,
data_format="NCDHW"):
"""
**Convlution3D Layer**
The convolution3D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are in NCDHW format. Where N is batch size C is the number of
Output(Output) are in NCDHW or NDHWC format. Where N is batch size C is the number of
channels, D is the depth of the feature, H is the height of the feature,
and W is the width of the feature. Convlution3D is similar with Convlution2D
but adds one dimension(depth). If bias attribution and activation type are
......@@ -2463,7 +2533,7 @@ def conv3d(input,
In the above equation:
* :math:`X`: Input value, a tensor with NCDHW format.
* :math:`X`: Input value, a tensor with NCDHW or NDHWC format.
* :math:`W`: Filter value, a tensor with MCDHW format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
......@@ -2490,7 +2560,7 @@ def conv3d(input,
W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1
Args:
input (Variable): The input image with [N, C, D, H, W] format.
input (Variable): The input image with [N, C, D, H, W] or [N, D, H, W, C]format.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple): The filter size. If filter_size is a tuple,
......@@ -2500,9 +2570,15 @@ def conv3d(input,
stride (int|tuple): The stride size. If stride is a tuple, it must
contain three integers, (stride_depth, stride_height, stride_width). Otherwise,
stride_depth = stride_height = stride_width = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain three integers, (padding_depth, padding_height, padding_width). Otherwise,
padding_depth = padding_height = padding_width = padding. Default: padding = 0.
padding (string|int|list|tuple): The padding size. f `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
it could be in three forms: `[pad_depth, pad_height, pad_width]` or
`[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form
`[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
when `data_format` is `"NDHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain three integers, (dilation_depth, dilation_height, dilation_width). Otherwise,
dilation_depth = dilation_height = dilation_width = dilation. Default: dilation = 1.
......@@ -2527,6 +2603,9 @@ def conv3d(input,
Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
data_format (str): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -2549,22 +2628,85 @@ def conv3d(input,
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
num_channels = input.shape[1]
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s. " % str(use_cudnn))
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
"Attr(data_format): %s." % str(data_format))
channel_last = (data_format == "NDHWC")
num_channels = input.shape[4] if channel_last else input.shape[1]
if num_channels < 0:
raise ValueError(
"The channel dimmention of the input(%s) should be defined. "
"Received: %s." % (str(input.shape), str(num_channels)))
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
raise ValueError(
"The number of input channels must be divisible by Attr(groups). "
"Received: number of channels(%s), groups(%s)." %
(str(num_channels), str(groups)))
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 3, 'filter_size')
stride = utils.convert_to_list(stride, 3, 'stride')
padding = utils.convert_to_list(padding, 3, 'padding')
dilation = utils.convert_to_list(dilation, 3, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
def _update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 5:
if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:5]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"):
if not (padding[0] == [0, 0] and padding[4] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding')
elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding')
else:
padding = utils.convert_to_list(padding, 3, 'padding')
padding = [
padding[0], padding[0], padding[1], padding[1], padding[2],
padding[2]
]
return padding
padding_algorithm = "EXPLICIT"
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." %
str(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0, 0, 0, 0, 0, 0]
elif padding == "SAME":
padding_algorithm = "SAME"
padding = [0, 0, 0, 0, 0, 0]
padding = _update_padding(padding, data_format)
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
......@@ -2596,7 +2738,9 @@ def conv3d(input,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False
'use_mkldnn': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......
......@@ -19,29 +19,87 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
def conv2d_forward_naive(input, filter, group, conv_param):
def conv2d_forward_naive(input,
filter,
group,
conv_param,
padding_algorithm='EXPLICIT',
data_format='NCHW'):
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Unknown Attr(data_format): '%s' ."
"It can only be 'NCHW' or 'NHWC'." % str(data_format))
channel_last = (data_format == "NHWC")
if channel_last:
input = np.transpose(input, [0, 3, 1, 2])
in_n, in_c, in_h, in_w = input.shape
out_c, f_c, f_h, f_w = filter.shape
f_n, f_c, f_h, f_w = filter.shape
out_n = in_n
out_c = f_n
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
sub_out_c = out_c // group
sub_f_n = f_n // group
stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
'dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
out = np.zeros((in_n, out_c, out_h, out_w))
# update pad and dilation
def _get_padding_with_SAME(input_shape, pool_size, pool_stride):
padding = []
for input_size, filter_size, stride_size in zip(input_shape, pool_size,
pool_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
ksize = filter.shape[2:4]
if padding_algorithm == "VALID":
pad = [0, 0, 0, 0]
elif padding_algorithm == "SAME":
dilation = [1, 1]
input_data_shape = []
if data_format == "NCHW":
input_data_shape = input.shape[2:4]
elif data_format == "NHWC":
input_data_shape = input.shape[1:3]
pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
pad_h_0, pad_h_1 = pad[0], pad[0]
pad_w_0, pad_w_1 = pad[1], pad[1]
if len(pad) == 4:
pad_h_0, pad_h_1 = pad[0], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[3]
out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[0] *
(f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[1] *
(f_w - 1) + 1)) // stride[1]
out = np.zeros((out_n, out_c, out_h, out_w))
d_bolck_h = (dilation[0] * (f_h - 1) + 1)
d_bolck_w = (dilation[1] * (f_w - 1) + 1)
input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )),
input_pad = np.pad(input, ((0, 0), (0, 0), (pad_h_0, pad_h_1),
(pad_w_0, pad_w_1)),
mode='constant',
constant_values=0)
filter_dilation = np.zeros((out_c, f_c, d_bolck_h, d_bolck_w))
filter_dilation = np.zeros((f_n, f_c, d_bolck_h, d_bolck_w))
filter_dilation[:, :, 0:d_bolck_h:dilation[0], 0:d_bolck_w:dilation[
1]] = filter
......@@ -53,16 +111,156 @@ def conv2d_forward_naive(input, filter, group, conv_param):
i * stride[0]:i * stride[0] + d_bolck_h,
j * stride[1]:j * stride[1] + d_bolck_w]
f_sub = filter_dilation[g * sub_out_c:(g + 1) *
sub_out_c, :, :, :]
f_sub = filter_dilation[g * sub_f_n:(g + 1) * sub_f_n, :, :, :]
# sub_f_n == sub_out_c
for k in range(sub_out_c):
# Multiplication of Corresponding Elements, then sum all
out[:, g * sub_out_c + k, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :],
axis=(1, 2, 3))
if channel_last:
out = np.transpose(out, [0, 2, 3, 1])
return out, in_n, out_h, out_w, out_c
def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase
def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestConv2DCUDNNFp16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
TestConv2DCUDNNFp16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNFp16
def create_test_channel_last_class(parent):
class TestChannelLastCase(parent):
def init_data_format(self):
self.data_format = "NHWC"
def init_test_case_2(self):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
TestChannelLastCase.__name__ = cls_name
globals()[cls_name] = TestChannelLastCase
def create_test_cudnn_channel_last_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCudnnChannelLastCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_data_format(self):
self.data_format = "NHWC"
def init_test_case_2(self):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast")
TestCudnnChannelLastCase.__name__ = cls_name
globals()[cls_name] = TestCudnnChannelLastCase
def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
self.pad = [0, 0]
self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
TestPaddingSMAECase.__name__ = cls_name
globals()[cls_name] = TestPaddingSMAECase
def create_test_padding_VALID_class(parent):
class TestPaddingVALIDCase(parent):
def init_paddings(self):
self.pad = [1, 1]
self.padding_algorithm = "VALID"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
TestPaddingVALIDCase.__name__ = cls_name
globals()[cls_name] = TestPaddingVALIDCase
def create_test_cudnn_padding_SAME_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNPaddingSMAECase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_paddings(self):
self.pad = [1, 1]
self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingSAMEOp")
TestCUDNNPaddingSMAECase.__name__ = cls_name
globals()[cls_name] = TestCUDNNPaddingSMAECase
def create_test_cudnn_padding_VALID_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNPaddingVALIDCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_paddings(self):
self.pad = [1, 1]
self.padding_algorithm = "VALID"
cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingVALIDOp")
TestCUDNNPaddingVALIDCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNPaddingVALIDCase
class TestConv2dOp(OpTest):
def setUp(self):
self.op_type = "conv2d"
......@@ -95,6 +293,7 @@ class TestConv2dOp(OpTest):
else:
input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
conv2d_param)
output = output.astype(self.dtype)
......@@ -160,6 +359,9 @@ class TestConv2dOp(OpTest):
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_test_case_2(self):
pass
def init_dilation(self):
self.dilations = [1, 1]
......@@ -281,19 +483,6 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):
#----------------Conv2dCUDNN----------------
def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase
create_test_cudnn_class(TestConv2dOp)
create_test_cudnn_class(TestWithPad)
create_test_cudnn_class(TestWithStride)
......@@ -301,45 +490,7 @@ create_test_cudnn_class(TestWithGroup)
create_test_cudnn_class(TestWith1x1)
create_test_cudnn_class(TestWithInput1x1Filter1x1)
#----------------Conv2dCUDNN----------------
def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestConv2DCUDNNFp16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
TestConv2DCUDNNFp16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNFp16
#----------------Conv2dCUDNN fp16----------------
create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False)
create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
......@@ -348,7 +499,7 @@ create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
# -------TestDepthwiseConv
#----------------TestDepthwiseConv -----
class TestDepthwiseConv(TestConv2dOp):
......@@ -502,5 +653,704 @@ class TestCUDNNExhaustiveSearch(TestConv2dOp):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
# ---- test asymmetric padding ----
class TestConv2dOp_v2(OpTest):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
self.init_dilation()
self.init_data_format()
self.init_test_case()
self.init_paddings()
self.init_test_case_2()
conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False
if self.fuse_relu_before_depthwise_conv:
input = input - 0.5
input -= (input < 0) * 0.1
input += (input >= 0) * 0.1
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(
input2, filter, self.groups, conv2d_param, self.padding_algorithm,
self.data_format)
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format,
'fuse_relu_before_depthwise_conv':
self.fuse_relu_before_depthwise_conv,
'exhaustive_search': self.exhaustive_search
}
self.outputs = {'Output': output}
def has_cuda(self):
return core.is_compiled_with_cuda() and (self.use_cudnn or
self.use_cuda)
def test_check_output(self):
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_grad_with_place(
place, {'Input', 'Filter'}, 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace()
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
def init_kernel_type(self):
pass
def init_paddings(self):
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
def init_data_format(self):
self.data_format = "NCHW"
def init_test_case_2(self):
pass
class TestConv2dOp_AsyPadding(TestConv2dOp_v2):
def init_paddings(self):
self.pad = [0, 0, 1, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithPad_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_paddings(self):
self.pad = [2, 1, 3, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithStride_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_paddings(self):
self.pad = [2, 1, 3, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithGroup_AsyPadding(TestConv2dOp_v2):
def init_group(self):
self.groups = 3
class TestWith1x1_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [2, 2, 4, 0]
self.padding_algorithm = "EXPLICIT"
class TestWithDepthWise3x3_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [3, 4, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [8, f_c, 3, 3]
def init_dilation(self):
self.dilations = [2, 2]
def init_group(self):
self.groups = 4
def init_paddings(self):
self.pad = [1, 3, 2, 1]
self.padding_algorithm = "EXPLICIT"
class TestWithDepthWise5x5_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 4, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [8, f_c, 5, 5]
def init_group(self):
self.groups = 4
def init_paddings(self):
self.pad = [0, 1, 1, 0]
self.padding_algorithm = "EXPLICIT"
class TestWithDepthWise7x7_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [2, 2]
self.input_size = [2, 8, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [16, f_c, 7, 7]
def init_group(self):
self.groups = 8
def init_paddings(self):
self.pad = [1, 3, 4, 1]
self.padding_algorithm = "EXPLICIT"
class TestWithDilation_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 3, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [2, 2]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [0, 1, 3, 0]
self.padding_algorithm = "EXPLICIT"
class TestWithInput1x1Filter1x1_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [0, 3, 4, 0]
self.padding_algorithm = "EXPLICIT"
create_test_cudnn_class(TestConv2dOp_AsyPadding)
create_test_cudnn_class(TestWithPad_AsyPadding)
create_test_cudnn_class(TestWithStride_AsyPadding)
create_test_cudnn_class(TestWithGroup_AsyPadding)
create_test_cudnn_class(TestWith1x1_AsyPadding)
create_test_cudnn_class(TestWithInput1x1Filter1x1_AsyPadding)
class TestDepthwiseConv_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.use_cuda = True
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 1, 0, 1]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConv2_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.use_cuda = True
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [0, 1, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConv3_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.use_cuda = True
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 1, 0, 0]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConvWithDilation_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 1, 2, 1]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConvWithDilation2_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [0, 1, 1, 0]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConvandFuse_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [2, 1, 2, 3]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConv2andFuse_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 1, 1, 2]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConv3andFuse_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 2, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConvWithDilationandFuse_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [2, 1, 1, 0]
self.padding_algorithm = "EXPLICIT"
class TestDepthwiseConvWithDilation2andFuse_AsyPadding(TestConv2dOp_v2):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
def init_paddings(self):
self.pad = [1, 3, 1, 3]
self.padding_algorithm = "EXPLICIT"
#---------- test SAME VALID -----------
create_test_padding_SAME_class(TestConv2dOp_AsyPadding)
create_test_padding_SAME_class(TestWithPad_AsyPadding)
create_test_padding_SAME_class(TestWithStride_AsyPadding)
create_test_padding_SAME_class(TestWithGroup_AsyPadding)
create_test_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding)
create_test_padding_VALID_class(TestConv2dOp_AsyPadding)
create_test_padding_VALID_class(TestWithPad_AsyPadding)
create_test_padding_VALID_class(TestWithStride_AsyPadding)
create_test_padding_VALID_class(TestWithGroup_AsyPadding)
create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding)
create_test_cudnn_padding_SAME_class(TestConv2dOp_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWithPad_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWithStride_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWithGroup_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding)
create_test_cudnn_padding_VALID_class(TestConv2dOp_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWithPad_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWithStride_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWithGroup_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding)
# depthwise conv2d
create_test_padding_SAME_class(TestDepthwiseConv_AsyPadding)
create_test_padding_SAME_class(TestDepthwiseConvWithDilation_AsyPadding)
create_test_padding_SAME_class(TestDepthwiseConvandFuse_AsyPadding)
create_test_padding_SAME_class(TestDepthwiseConvWithDilationandFuse_AsyPadding)
create_test_padding_VALID_class(TestDepthwiseConv_AsyPadding)
create_test_padding_VALID_class(TestDepthwiseConvWithDilation_AsyPadding)
create_test_padding_VALID_class(TestDepthwiseConvandFuse_AsyPadding)
create_test_padding_VALID_class(TestDepthwiseConvWithDilationandFuse_AsyPadding)
# ------------ test channel last ---------
create_test_channel_last_class(TestConv2dOp_AsyPadding)
create_test_channel_last_class(TestWithPad_AsyPadding)
create_test_channel_last_class(TestWithGroup_AsyPadding)
create_test_channel_last_class(TestWith1x1_AsyPadding)
create_test_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding)
create_test_channel_last_class(TestDepthwiseConv_AsyPadding)
create_test_channel_last_class(TestDepthwiseConvWithDilation2_AsyPadding)
create_test_channel_last_class(TestDepthwiseConvandFuse_AsyPadding)
create_test_channel_last_class(TestDepthwiseConvWithDilationandFuse_AsyPadding)
create_test_cudnn_channel_last_class(TestConv2dOp_AsyPadding)
create_test_cudnn_channel_last_class(TestWithPad_AsyPadding)
create_test_cudnn_channel_last_class(TestWithStride_AsyPadding)
create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding)
create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding)
# --------- test python API ---------------
class TestConv2dAPI(OpTest):
def test_api(self):
input_NHWC = fluid.layers.data(
name="input_NHWC",
shape=[2, 5, 5, 3],
append_batch_size=False,
dtype="float32")
input_NCHW = fluid.layers.data(
name="input_NCHW",
shape=[2, 3, 5, 5],
append_batch_size=False,
dtype="float32")
fluid.layers.conv2d(
input=input_NHWC,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=1,
data_format="NCHW")
fluid.layers.conv2d(
input=input_NCHW,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=[1, 2, 1, 0],
dilation=[1, 1],
groups=1,
data_format="NCHW")
fluid.layers.conv2d(
input=input_NCHW,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=[[0, 0], [0, 0], [1, 1], [1, 1]],
dilation=[1, 1],
groups=1,
data_format="NCHW")
fluid.layers.conv2d(
input=input_NHWC,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
dilation=[1, 1],
groups=1,
data_format="NHWC")
fluid.layers.conv2d(
input=input_NCHW,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding="SAME",
dilation=[1, 1],
groups=1,
data_format="NCHW")
fluid.layers.conv2d(
input=input_NCHW,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding="VALID",
dilation=[1, 1],
groups=1,
data_format="NCHW")
class TestConv2dAPI_Error(OpTest):
def test_api(self):
input = fluid.layers.data(
name="input",
shape=[2, 5, 5, 5],
append_batch_size=False,
dtype="float32")
# ValueError: cudnn
def run_1():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=1,
use_cudnn=[0],
data_format="NCHW")
self.assertRaises(ValueError, run_1)
# ValueError: data_format
def run_2():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=1,
use_cudnn=False,
data_format="NCHWC")
self.assertRaises(ValueError, run_2)
# ValueError: padding
def run_3():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding="SAMEE",
dilation=[1, 1],
groups=1,
use_cudnn=False,
data_format="NCHW")
self.assertRaises(ValueError, run_3)
def run_4():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=[[0, 1], [0, 1], [0, 1], [0, 1]],
dilation=[1, 1],
groups=1,
use_cudnn=False,
data_format="NCHW")
self.assertRaises(ValueError, run_4)
def run_5():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=[[0, 1], [0, 1], [0, 1], [0, 1]],
dilation=[1, 1],
groups=1,
use_cudnn=False,
data_format="NHWC")
self.assertRaises(ValueError, run_5)
# ValueError: channel dimmention
x = fluid.layers.data(
name="x",
shape=[2, 5, 5, -1],
append_batch_size=False,
dtype="float32")
def run_6():
fluid.layers.conv2d(
input=x,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=1,
use_cudnn=False,
data_format="NHWC")
self.assertRaises(ValueError, run_6)
# ValueError: groups
def run_7():
fluid.layers.conv2d(
input=input,
num_filters=3,
filter_size=[3, 3],
stride=[1, 1],
padding=0,
dilation=[1, 1],
groups=3,
use_cudnn=False,
data_format="NHWC")
self.assertRaises(ValueError, run_7)
if __name__ == '__main__':
unittest.main()
......@@ -19,21 +19,83 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
def conv3d_forward_naive(input, filter, group, conv_param):
def conv3d_forward_naive(input,
filter,
group,
conv_param,
padding_algorithm='EXPLICIT',
data_format="NCDHW"):
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError("Unknown Attr(data_format): '%s' ."
"It can only be 'NCDHW' or 'NDHWC'." %
str(data_format))
channel_last = (data_format == "NDHWC")
if channel_last:
input = np.transpose(input, [0, 4, 1, 2, 3])
in_n, in_c, in_d, in_h, in_w = input.shape
out_c, f_c, f_d, f_h, f_w = filter.shape
f_n, f_c, f_d, f_h, f_w = filter.shape
out_n = in_n
out_c = f_n
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
sub_out_c = out_c // group
sub_f_n = f_n // group
stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
'dilations']
out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) // stride[0]
out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) // stride[1]
out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) // stride[2]
# update pad and dilation
def _get_padding_with_SAME(input_shape, pool_size, pool_stride):
padding = []
for input_size, filter_size, stride_size in zip(input_shape, pool_size,
pool_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
ksize = filter.shape[2:5]
if padding_algorithm == "VALID":
pad = [0, 0, 0, 0, 0, 0]
elif padding_algorithm == "SAME":
dilation = [1, 1, 1]
input_data_shape = []
if data_format == "NCDHW":
input_data_shape = input.shape[2:5]
elif data_format == "NDHWC":
input_data_shape = input.shape[1:4]
pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
pad_d_0, pad_d_1 = pad[0], pad[0]
pad_h_0, pad_h_1 = pad[1], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[2]
if len(pad) == 6:
pad_d_0, pad_d_1 = pad[0], pad[1]
pad_h_0, pad_h_1 = pad[2], pad[3]
pad_w_0, pad_w_1 = pad[4], pad[5]
out_d = 1 + (in_d + pad_d_0 + pad_d_1 - (dilation[0] *
(f_d - 1) + 1)) // stride[0]
out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[1] *
(f_h - 1) + 1)) // stride[1]
out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[2] *
(f_w - 1) + 1)) // stride[2]
out = np.zeros((in_n, out_c, out_d, out_h, out_w))
......@@ -41,12 +103,12 @@ def conv3d_forward_naive(input, filter, group, conv_param):
d_bolck_h = (dilation[1] * (f_h - 1) + 1)
d_bolck_w = (dilation[2] * (f_w - 1) + 1)
input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ),
(pad[2], )),
input_pad = np.pad(input, ((0, 0), (0, 0), (pad_d_0, pad_d_1),
(pad_h_0, pad_h_1), (pad_w_0, pad_w_1)),
mode='constant',
constant_values=0)
filter_dilation = np.zeros((out_c, f_c, d_bolck_d, d_bolck_h, d_bolck_w))
filter_dilation = np.zeros((f_n, f_c, d_bolck_d, d_bolck_h, d_bolck_w))
filter_dilation[:, :, 0:d_bolck_d:dilation[0], 0:d_bolck_h:dilation[1], 0:
d_bolck_w:dilation[2]] = filter
......@@ -60,16 +122,114 @@ def conv3d_forward_naive(input, filter, group, conv_param):
i * stride[1]:i * stride[1] + d_bolck_h,
j * stride[2]:j * stride[2] + d_bolck_w]
f_sub = filter_dilation[g * sub_out_c:(g + 1) *
sub_out_c, :, :, :, :]
f_sub = filter_dilation[g * sub_f_n:(g + 1) *
sub_f_n, :, :, :, :]
for k in range(sub_out_c):
out[:, g * sub_out_c + k, d, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :, :],
axis=(1, 2, 3, 4))
if channel_last:
out = np.transpose(out, [0, 2, 3, 4, 1])
return out
def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase
def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
self.pad = [0, 0, 0]
self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp")
TestPaddingSMAECase.__name__ = cls_name
globals()[cls_name] = TestPaddingSMAECase
def create_test_padding_VALID_class(parent):
class TestPaddingVALIDCase(parent):
def init_paddings(self):
self.pad = [1, 1, 1]
self.padding_algorithm = "VALID"
cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp")
TestPaddingVALIDCase.__name__ = cls_name
globals()[cls_name] = TestPaddingVALIDCase
def create_test_cudnn_padding_SAME_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNPaddingSMAECase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_paddings(self):
self.pad = [1, 1, 1]
self.padding_algorithm = "SAME"
cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingSAMEOp")
TestCUDNNPaddingSMAECase.__name__ = cls_name
globals()[cls_name] = TestCUDNNPaddingSMAECase
def create_test_cudnn_padding_VALID_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNPaddingVALIDCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_paddings(self):
self.pad = [1, 1, 1]
self.padding_algorithm = "VALID"
cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingVALIDOp")
TestCUDNNPaddingVALIDCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNPaddingVALIDCase
def create_test_channel_last_class(parent):
class TestChannelLastCase(parent):
def init_data_format(self):
self.data_format = "NDHWC"
def init_test_case_2(self):
N, C, D, H, W = self.input_size
self.input_size = [N, D, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast")
TestChannelLastCase.__name__ = cls_name
globals()[cls_name] = TestChannelLastCase
def create_test_cudnn_channel_last_class(parent):
class TestCudnnChannelLastCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
def init_data_format(self):
self.data_format = "NDHWC"
def init_test_case_2(self):
N, C, D, H, W = self.input_size
self.input_size = [N, D, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast")
TestCudnnChannelLastCase.__name__ = cls_name
globals()[cls_name] = TestCudnnChannelLastCase
class TestConv3dOp(OpTest):
def setUp(self):
self.op_type = "conv3d"
......@@ -90,8 +250,11 @@ class TestConv3dOp(OpTest):
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(input, filter, self.groups,
conv3d_param).astype(self.dtype)
output = conv3d_forward_naive(
input,
filter,
self.groups,
conv3d_param, ).astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
......@@ -150,6 +313,9 @@ class TestConv3dOp(OpTest):
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_test_case_2(self):
pass
def init_dilation(self):
self.dilations = [1, 1, 1]
......@@ -184,7 +350,7 @@ class TestWith1x1(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCHW
self.input_size = [2, 3, 4, 4, 4]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1, 1]
......@@ -200,7 +366,7 @@ class TestWithInput1x1Filter1x1(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 1, 1, 1] # NCHW
self.input_size = [2, 3, 1, 1, 1]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1, 1]
......@@ -216,7 +382,7 @@ class TestWithDilation(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 6, 6, 6] # NCDHW
self.input_size = [2, 3, 6, 6, 6]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 2, 2, 2]
......@@ -228,7 +394,9 @@ class TestWithDilation(TestConv3dOp):
self.groups = 3
#----------------Conv3dCUDNN----------------
#---------------- Conv3dCUDNN ----------------
class TestCUDNN(TestConv3dOp):
def init_kernel_type(self):
self.use_cudnn = True
......@@ -320,11 +488,435 @@ class TestCUDNNExhaustiveSearch(TestCUDNN):
self.exhaustive_search = True
# ---- test asymmetric padding ----
class TestConv3dOp_2(OpTest):
def setUp(self):
self.op_type = "conv3d"
self.use_cudnn = False
self.use_mkldnn = False
self.data_format = "NCDHW"
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
self.init_dilation()
self.init_data_format()
self.init_test_case()
self.init_paddings()
self.init_test_case_2()
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(input, filter, self.groups, conv3d_param,
self.padding_algorithm,
self.data_format).astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}
self.outputs = {'Output': output}
def has_cudnn(self):
return core.is_compiled_with_cuda() and self.use_cudnn
def test_check_output(self):
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
self.check_grad_with_place(
place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_test_case_2(self):
pass
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_group(self):
self.groups = 1
def init_kernel_type(self):
pass
def init_paddings(self):
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
def init_data_format(self):
self.data_format = "NCDHW"
class TestConv3dOp_AsyPadding(TestConv3dOp_2):
def init_paddings(self):
self.pad = [1, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestCase1_AsyPadding(TestConv3dOp_2):
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_paddings(self):
self.pad = [0, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithGroup1_AsyPadding(TestConv3dOp_2):
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [1, 1, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithGroup2_AsyPadding(TestConv3dOp_2):
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [1, 1, 0, 1, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestWith1x1_AsyPadding(TestConv3dOp_2):
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [0, 0, 1, 0, 0, 2]
self.padding_algorithm = "EXPLICIT"
class TestWithDilation_AsyPadding(TestConv3dOp_2):
def init_test_case(self):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 6, 6, 6]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 2, 2, 2]
def init_dilation(self):
self.dilations = [2, 2, 2]
def init_group(self):
self.groups = 3
def init_paddings(self):
self.pad = [0, 0, 1, 0, 1, 0]
self.padding_algorithm = "EXPLICIT"
create_test_cudnn_class(TestConv3dOp_AsyPadding)
create_test_cudnn_class(TestWithGroup1_AsyPadding)
create_test_cudnn_class(TestWithGroup2_AsyPadding)
create_test_cudnn_class(TestWith1x1_AsyPadding)
create_test_cudnn_class(TestWithDilation_AsyPadding)
create_test_padding_SAME_class(TestConv3dOp_AsyPadding)
create_test_padding_SAME_class(TestWithGroup1_AsyPadding)
create_test_padding_SAME_class(TestWith1x1_AsyPadding)
create_test_padding_VALID_class(TestConv3dOp_AsyPadding)
create_test_padding_VALID_class(TestWithGroup1_AsyPadding)
create_test_padding_VALID_class(TestWith1x1_AsyPadding)
create_test_cudnn_padding_SAME_class(TestConv3dOp_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWithGroup1_AsyPadding)
create_test_cudnn_padding_SAME_class(TestWith1x1_AsyPadding)
create_test_cudnn_padding_VALID_class(TestConv3dOp_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWithGroup1_AsyPadding)
create_test_cudnn_padding_VALID_class(TestWith1x1_AsyPadding)
create_test_channel_last_class(TestConv3dOp_AsyPadding)
create_test_channel_last_class(TestWithGroup1_AsyPadding)
create_test_channel_last_class(TestWith1x1_AsyPadding)
create_test_channel_last_class(TestConv3dOp_AsyPadding)
create_test_channel_last_class(TestWithGroup1_AsyPadding)
create_test_channel_last_class(TestWith1x1_AsyPadding)
create_test_cudnn_channel_last_class(TestConv3dOp_AsyPadding)
create_test_cudnn_channel_last_class(TestWithGroup1_AsyPadding)
create_test_cudnn_channel_last_class(TestWith1x1_AsyPadding)
create_test_cudnn_channel_last_class(TestConv3dOp_AsyPadding)
create_test_cudnn_channel_last_class(TestWithGroup1_AsyPadding)
create_test_cudnn_channel_last_class(TestWith1x1_AsyPadding)
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv3d"
# --------- test python API ---------------
class TestConv3dAPI(OpTest):
def test_api(self):
input_NDHWC = fluid.layers.data(
name="input_NDHWC",
shape=[2, 5, 5, 5, 3],
append_batch_size=False,
dtype="float32")
input_NCDHW = fluid.layers.data(
name="input_NCDHW",
shape=[2, 3, 5, 5, 3],
append_batch_size=False,
dtype="float32")
fluid.layers.conv3d(
input=input_NDHWC,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=0,
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[1, 2, 1, 0, 1, 0],
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[[0, 0], [0, 0], [1, 1], [1, 1], [1, 1]],
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NDHWC,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=[[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]],
dilation=[1, 1, 1],
groups=1,
data_format="NDHWC")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding="SAME",
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
fluid.layers.conv3d(
input=input_NCDHW,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding="VALID",
dilation=[1, 1, 1],
groups=1,
data_format="NCDHW")
class TestConv3dAPI_Error(OpTest):
def test_api(self):
input = fluid.layers.data(
name="input",
shape=[2, 5, 5, 5, 4],
append_batch_size=False,
dtype="float32")
# ValueError: cudnn
def run_1():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=[0],
data_format="NCDHW")
self.assertRaises(ValueError, run_1)
# ValueError: data_format
def run_2():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=[3, 3, 3],
stride=[1, 1, 1],
padding=0,
dilation=[1, 1, 1],
groups=1,
use_cudnn=False,
data_format="NCHWC")
self.assertRaises(ValueError, run_2)
# ValueError: padding
def run_3():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding="SAMEE",
dilation=1,
groups=1,
use_cudnn=False,
data_format="NCDHW")
self.assertRaises(ValueError, run_3)
def run_4():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=[[0, 1], [0, 0], [0, 1], [0, 1], [0, 1]],
dilation=1,
groups=1,
use_cudnn=False,
data_format="NCDHW")
self.assertRaises(ValueError, run_4)
def run_5():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=0,
stride=0,
padding=[[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
dilation=1,
groups=1,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_5)
# ValueError: channel dimmention
x = fluid.layers.data(
name="x",
shape=[2, 5, 5, 5, -1],
append_batch_size=False,
dtype="float32")
def run_6():
fluid.layers.conv3d(
input=x,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_6)
# ValueError: groups
def run_7():
fluid.layers.conv3d(
input=input,
num_filters=3,
filter_size=3,
stride=1,
padding=0,
dilation=1,
groups=3,
use_cudnn=False,
data_format="NDHWC")
self.assertRaises(ValueError, run_7)
if __name__ == '__main__':
unittest.main()
......@@ -28,11 +28,38 @@ from decorator_helper import prog_scope
class TestConvDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 7, 8]
shape = [2, 4, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 4, 1, bias_attr=False)
y = layers.conv2d(x, 2, 1, groups=1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
places = []
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConvDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 2, 1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
......@@ -53,11 +80,11 @@ class TestConvDoubleGradCheck(unittest.TestCase):
class TestConvDoubleGradCheckTest1(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 4, 5]
shape = [2, 3, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 4, 1, padding=1, bias_attr=False)
y = layers.conv2d(x, 2, 1, padding=1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
......@@ -82,7 +109,7 @@ class TestConv3DDoubleGradCheck(unittest.TestCase):
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(x, 4, 1, bias_attr=False)
y = layers.conv3d(x, 2, 1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
......@@ -107,7 +134,326 @@ class TestConv3DDoubleGradCheckTest1(unittest.TestCase):
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(x, 4, 1, padding=1, bias_attr=False)
y = layers.conv3d(x, 2, 1, padding=1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv2DoubleGradCheck_AsyPadding(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 0, 0, 1],
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv2DoubleGradCheck_PaddingSAME(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(
input=x,
num_filters=2,
filter_size=1,
padding="SAME",
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv2DoubleGradCheck_PaddingVALID(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(
input=x,
num_filters=2,
filter_size=1,
padding="VALID",
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv2DoubleGradCheck_ChannelLast(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 1],
bias_attr=False,
use_cudnn=True,
groups=1,
data_format="NHWC")
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv2DoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 0, 1, 0],
bias_attr=False,
use_cudnn=True,
groups=1,
data_format="NHWC")
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DDoubleGradCheck_AsyPadding(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 2, 2, 2]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 0, 0, 1, 1, 2],
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DoubleGradCheck_PaddingSAME(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 2, 2, 2]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(
input=x,
num_filters=2,
filter_size=1,
padding="SAME",
groups=1,
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DoubleGradCheck_PaddingVALID(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 3, 3, 2]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(
input=x,
num_filters=2,
filter_size=1,
padding="VALID",
bias_attr=False,
use_cudnn=True)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DDoubleGradCheck_ChannelLast(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 2, 2, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 1, 1],
bias_attr=False,
use_cudnn=True,
groups=1,
data_format="NDHWC")
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3dDoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 2, 2, 2, 3]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(
input=x,
num_filters=2,
filter_size=1,
padding=[1, 0, 1, 0, 1, 0],
bias_attr=False,
use_cudnn=True,
groups=1,
data_format="NDHWC")
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册