提交 cf6919bf 编写于 作者: Z Zhang Ting 提交者: hong

conv_transpose supports channel_last input, test=develop, test=document_preview (#20072)

上级 c9139c3d
......@@ -153,8 +153,8 @@ paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'moment
paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', '02972097e089629efdb0ed9404fd36ae'))
paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787'))
paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'ab58296b567bf0c686084add7f3280a4'))
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'fe15dbfb17d97d3d29b2fa7ee6390ee6'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCHW')), ('document', '9391d75358b6cba0cc5d22a01a223420'))
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '74bce3cd4224e6ff133d54508dc7f150'))
paddle.fluid.layers.sequence_expand (ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '10e122eb755c2bd1f78ef2332b28f1a0'))
paddle.fluid.layers.sequence_expand_as (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '858c432e7cbd8bb952cc2eb555457d50'))
paddle.fluid.layers.sequence_pad (ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'df08b9c499ab3a90f95d08ab5b6c6c62'))
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
......@@ -29,26 +31,189 @@ using DataLayout = platform::DataLayout;
static constexpr size_t kConvCUDNNWorkspaceLimitBytes = 1024 * 1024 * 1024;
template <typename T, int D>
static void DataTranspose(const framework::ExecutionContext& ctx,
const Tensor* input, Tensor* output,
const std::vector<int>& axis, int flag = 0) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::Transpose<platform::CUDADeviceContext, T, D> transpose;
auto in_dims = input->dims();
std::vector<int64_t> input_transpose_vec;
for (size_t i = 0; i < axis.size(); ++i) {
if (flag == 0)
input_transpose_vec.push_back(in_dims[axis[i]]);
else
input_transpose_vec.push_back(in_dims[i]);
}
framework::DDim input_transpose_dims(
framework::make_ddim(input_transpose_vec));
output->mutable_data<T>(input_transpose_dims, ctx.GetPlace());
transpose(dev_ctx, *input, output, axis);
}
static inline bool IsSymmetricPadding(const std::vector<int>& paddings,
const int data_dim) {
bool is_sys_pad = true;
if (paddings.size() == data_dim * 2) {
for (size_t i = 0; i < data_dim; ++i) {
if (paddings[2 * i] != paddings[2 * i + 1]) {
is_sys_pad = false;
return is_sys_pad;
}
}
}
return is_sys_pad;
}
template <typename T>
class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first
Tensor input_transpose;
std::vector<int> input_vec = framework::vectorize<int>(input->dims());
std::vector<int> output_vec = framework::vectorize<int>(output->dims());
if (data_layout == DataLayout::kNHWC) {
if (strides.size() == 2U) {
std::vector<int> axis = {0, 3, 1, 2};
for (size_t i = 0; i < axis.size(); ++i) {
input_vec[i] = input->dims()[axis[i]];
output_vec[i] = output->dims()[axis[i]];
}
DataTranspose<T, 4>(ctx, input, &input_transpose, axis);
} else if (strides.size() == 3U) {
std::vector<int> axis = {0, 4, 1, 2, 3};
for (size_t i = 0; i < axis.size(); ++i) {
input_vec[i] = input->dims()[axis[i]];
output_vec[i] = output->dims()[axis[i]];
}
DataTranspose<T, 5>(ctx, input, &input_transpose, axis);
}
} else {
input_transpose = *input;
}
// update padding and dilation
auto in_dims = input_transpose.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = IsSymmetricPadding(paddings, data_dim);
std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
if (!is_sys_pad) {
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = input_transpose.dims()[0];
new_input_shape_vec[1] = input_transpose.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
input_transpose.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_input =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
const int rank = input_transpose.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, input_transpose, pad_value, &transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, input_transpose, pad_value, &transformed_input);
} break;
default:
PADDLE_ENFORCE_EQ(
rank == 4 || rank == 5, true,
"Op(ConvTranspose) only supports 4-D or 5-D input Tensor.");
}
} else {
transformed_input = input_transpose;
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
std::vector<int64_t> starts(data_dim, 0);
std::vector<int64_t> ends(data_dim, 0);
std::vector<int64_t> axes(data_dim, 0);
for (size_t i = 0; i < data_dim; ++i) {
starts[i] = input_pad[2 * i + 4] * (strides[i] + 1);
ends[i] = starts[i] + output_vec[i + 2];
axes[i] = i + 2;
}
const T* input_data = transformed_input.data<T>();
input_vec = framework::vectorize<int>(transformed_input.dims());
std::vector<int> transformed_output_vec = output_vec;
for (size_t i = 0; i < data_dim; ++i) {
transformed_output_vec[i + 2] =
output_vec[i + 2] +
(input_pad[2 * i + 4] + input_pad[2 * i + 5]) * strides[i] -
2 * padding_common[i] + paddings[2 * i] + paddings[2 * i + 1];
}
Tensor transformed_output;
if (!is_sys_pad) {
DDim transformed_output_shape(
framework::make_ddim(transformed_output_vec));
transformed_output.mutable_data<T>(transformed_output_shape,
ctx.GetPlace());
} else {
output->mutable_data<T>(ctx.GetPlace());
transformed_output.ShareDataWith(*output);
transformed_output.Resize(framework::make_ddim(transformed_output_vec));
}
T* transformed_output_data = transformed_output.data<T>();
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
......@@ -63,16 +228,16 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
}
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_input_desc =
input_desc.descriptor<T>(layout, input_vec, groups);
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(output->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_desc =
output_desc.descriptor<T>(layout, transformed_output_vec, groups);
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize<int>(filter->dims()), groups);
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
conv_desc.descriptor<T>(padding_common, strides, dilations);
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
......@@ -99,8 +264,10 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
cudnn_output_desc, algo, &workspace_size_in_bytes));
// ------------------- cudnn conv transpose forward ---------------------
int input_offset = input->numel() / input->dims()[0] / groups;
int output_offset = output->numel() / output->dims()[0] / groups;
int input_offset =
transformed_input.numel() / transformed_input.dims()[0] / groups;
int output_offset =
transformed_output.numel() / transformed_output.dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
......@@ -110,10 +277,34 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
cudnn_output_desc, transformed_output_data + output_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
if (!is_sys_pad && strides.size() == 2U) {
Slice<paddle::platform::CUDADeviceContext, T, 4>(
ctx, &transformed_output, output, starts, ends, axes);
} else if (!is_sys_pad && strides.size() == 3U) {
Slice<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_output, output, starts, ends, axes);
}
if (data_layout == DataLayout::kNHWC) {
Tensor output_transpose;
Tensor output_nchw;
output_nchw.ShareDataWith(*output);
output_nchw.Resize(framework::make_ddim(output_vec));
if (strides.size() == 2U) {
std::vector<int> axis = {0, 2, 3, 1};
DataTranspose<T, 4>(ctx, &output_nchw, &output_transpose, axis);
*output = output_transpose;
} else if (strides.size() == 3U) {
std::vector<int> axis = {0, 2, 3, 4, 1};
DataTranspose<T, 5>(ctx, &output_nchw, &output_transpose, axis);
*output = output_transpose;
}
}
}
};
......@@ -128,8 +319,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const T* input_data = input->data<T>();
const T* output_grad_data = output_grad->data<T>();
const T* filter_data = filter->data<T>();
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
......@@ -137,27 +326,141 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first
Tensor input_transpose;
Tensor output_grad_transpose;
std::vector<int> input_vec = framework::vectorize<int>(input->dims());
std::vector<int> output_vec =
framework::vectorize<int>(output_grad->dims());
if (data_layout == DataLayout::kNHWC) {
if (strides.size() == 2U) {
std::vector<int> axis = {0, 3, 1, 2};
for (size_t i = 0; i < axis.size(); ++i) {
input_vec[i] = input->dims()[axis[i]];
output_vec[i] = output_grad->dims()[axis[i]];
}
DataTranspose<T, 4>(ctx, input, &input_transpose, axis);
DataTranspose<T, 4>(ctx, output_grad, &output_grad_transpose, axis);
} else if (strides.size() == 3U) {
std::vector<int> axis = {0, 4, 1, 2, 3};
for (size_t i = 0; i < axis.size(); ++i) {
input_vec[i] = input->dims()[axis[i]];
output_vec[i] = output_grad->dims()[axis[i]];
}
DataTranspose<T, 5>(ctx, input, &input_transpose, axis);
DataTranspose<T, 5>(ctx, output_grad, &output_grad_transpose, axis);
}
} else {
input_transpose = *input;
output_grad_transpose = *output_grad;
}
// update padding and dilation
auto in_dims = input_transpose.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = IsSymmetricPadding(paddings, data_dim);
std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
Tensor transformed_output_grad;
std::vector<int> padding_common(data_dim, 0);
if (!is_sys_pad) {
std::vector<int> padding_diff(data_dim);
std::vector<int> new_output_grad_shape_vec(data_dim + 2);
new_output_grad_shape_vec[0] = output_grad_transpose.dims()[0];
new_output_grad_shape_vec[1] = output_grad_transpose.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_output_grad_shape_vec[i + 2] =
output_grad_transpose.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_output_grad_shape(
framework::make_ddim(new_output_grad_shape_vec));
transformed_output_grad.Resize(new_output_grad_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_output_grad =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_output_grad_shape, dev_ctx);
const int rank = input_transpose.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, output_grad_transpose, pad_value,
&transformed_output_grad);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, output_grad_transpose, pad_value,
&transformed_output_grad);
} break;
default:
PADDLE_ENFORCE_EQ(
rank == 4 || rank == 5, true,
"Op(ConvTranspose) only supports 4-D or 5-D input Tensor.");
}
} else {
transformed_output_grad = output_grad_transpose;
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* input_data = input_transpose.data<T>();
const T* output_grad_data = transformed_output_grad.data<T>();
output_vec = framework::vectorize<int>(transformed_output_grad.dims());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_input_desc =
input_desc.descriptor<T>(layout, input_vec, groups);
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(output_grad->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_desc =
output_desc.descriptor<T>(layout, output_vec, groups);
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize<int>(filter->dims()), groups);
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
conv_desc.descriptor<T>(padding_common, strides, dilations);
// ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionFwdAlgo_t data_algo;
......@@ -204,8 +507,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int input_offset = input->numel() / input->dims()[0] / groups;
int output_grad_offset =
output_grad->numel() / output_grad->dims()[0] / groups;
int output_grad_offset = transformed_output_grad.numel() /
transformed_output_grad.dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
......@@ -223,6 +526,24 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
if (data_layout == DataLayout::kNHWC) {
Tensor input_grad_transpose;
Tensor input_grad_nchw;
input_grad_nchw.ShareDataWith(*input_grad);
input_grad_nchw.Resize(framework::make_ddim(input_vec));
if (strides.size() == 2U) {
std::vector<int> axis = {0, 2, 3, 1};
DataTranspose<T, 4>(ctx, &input_grad_nchw, &input_grad_transpose,
axis);
*input_grad = input_grad_transpose;
} else if (strides.size() == 3U) {
std::vector<int> axis = {0, 2, 3, 4, 1};
DataTranspose<T, 5>(ctx, &input_grad_nchw, &input_grad_transpose,
axis);
*input_grad = input_grad_transpose;
}
}
}
// ------------------- cudnn conv backward filter ---------------------
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -25,13 +26,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
using DataLayout = framework::DataLayout;
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of ConvTransposeOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
......@@ -41,52 +44,75 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_format"));
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
"ConvTransposeOp input dimension and filter dimension "
"should be the same.");
PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
if (output_size.size())
PADDLE_ENFORCE_EQ(output_size.size(), strides.size(),
"ConvTransposeOp output_size dimension and strides "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"In ConvTransposeOp, The number of input channels should "
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
const int64_t C =
(data_layout == DataLayout::kNCHW ? in_dims[1]
: in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ(
C, filter_dims[0],
"The number of input channels of Op(ConvTransposeOp) should "
"be equal to the number of filter's channels.");
framework::DDim in_data_dims;
if (data_layout == DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (data_layout == DataLayout::kNCHW) {
output_shape.push_back(filter_dims[1] * groups);
}
const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1);
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
auto infer_shape =
(in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + filter_extent;
auto infer_shape = (in_dims[i + offset] - 1) * strides[i] -
paddings[2 * i] - paddings[2 * i + 1] + filter_extent;
if (output_size.size()) {
PADDLE_ENFORCE((output_size[i] >= infer_shape &&
output_size[i] < infer_shape + strides[i]),
"ConvTransposeOp output_size should be "
"in appropriate range.");
PADDLE_ENFORCE_EQ((output_size[i] >= infer_shape &&
output_size[i] < infer_shape + strides[i]),
true,
"output_size of Op(ConvTransposeOp) should be "
"in appropriate range.");
output_shape.push_back(output_size[i]);
} else {
output_shape.push_back(infer_shape);
}
}
if (data_layout == DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
......@@ -115,12 +141,11 @@ void Conv2DTransposeOpMaker::Make() {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H is the height of the feature, and "
"W is the width of the feature.");
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW or NHWC. Where N is batch size, "
"C is the number of input channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput(
"Filter",
"(Tensor) The filter tensor of convolution transpose operator. "
......@@ -137,7 +162,7 @@ void Conv2DTransposeOpMaker::Make() {
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is also NCHW.");
"The format of output tensor is the same as input tensor.");
AddAttr<std::vector<int>>("output_size",
"(vector<int> default: []), the "
"size of the output tensor")
......@@ -182,10 +207,15 @@ void Conv2DTransposeOpMaker::Make() {
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>("workspace_size_MB",
"Used in cudnn kernel only. workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......@@ -199,7 +229,7 @@ Convolution2D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter
and dilations, strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the
Input(Input) and output(Output) are in NCHW or NHWC format. Where N is batchsize, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Filter(Input) is in MCHW format. Where M is the number of input feature channels,
C is the number of output feature channels, H is the height of the filter,
......@@ -216,19 +246,19 @@ For an example:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
H_{out} = (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1
$$
)DOC");
}
void Conv3DTransposeOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and "
"W is the width of the feature.");
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW or NDHWC. Where N is batch "
"size, C is the number of channels, D is the depth of the feature, "
"H is the height of the feature, and W is the width of the feature.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is MCDHW, where M is the number of "
......@@ -240,7 +270,7 @@ void Conv3DTransposeOpMaker::Make() {
"the convolution3d transpose scenario.");
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is also NCDHW."
"The format of output tensor is the same as input tensor."
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature.");
......@@ -278,10 +308,15 @@ void Conv3DTransposeOpMaker::Make() {
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>("workspace_size_MB",
"Used in cudnn kernel only. workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......@@ -295,7 +330,7 @@ Convolution3D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter
and dilations, strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the
Input(Input) and output(Output) are in NCDHW or NDHWC format. Where N is batch size, C is the
number of channels, D is the depth of the feature, H is the height of the feature,
and W is the width of the feature.
Filter(Input) is in MCDHW format. Where M is the number of input feature channels,
......@@ -313,9 +348,9 @@ Example:
Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
Where
$$
D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\
H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1
D_{out} = (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\
H_{out} = (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1
$$
)DOC");
}
......@@ -348,8 +383,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
......
......@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
......@@ -27,6 +30,94 @@ namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
const Tensor* input, Tensor* out,
const std::vector<int64_t>& begin_vec,
const std::vector<int64_t>& end_vec,
const std::vector<int64_t>& axes_vec) {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto in_dims = input->dims();
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = in_dims[i];
}
std::vector<int64_t> out_shape_vec = framework::vectorize(in_dims);
for (size_t i = 0; i < axes_vec.size(); ++i) {
offsets[axes_vec[i]] = begin_vec[i];
extents[axes_vec[i]] = end_vec[i] - begin_vec[i];
out_shape_vec[axes_vec[i]] = end_vec[i] - begin_vec[i];
}
framework::DDim out_dims(framework::make_ddim(out_shape_vec));
out->mutable_data<T>(out_dims, context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims);
out_t.device(place) = in_t.slice(offsets, extents);
out->Resize(out_dims);
}
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
const Tensor* input, Tensor* out, int64_t begin_idx,
int64_t end_idx, int64_t axes) {
std::vector<int64_t> begin_vec = {begin_idx};
std::vector<int64_t> end_vec = {end_idx};
std::vector<int64_t> axes_vec = {axes};
Slice<DeviceContext, T, D>(context, input, out, begin_vec, end_vec, axes_vec);
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilation,
const std::string padding_algorithm,
const framework::DDim data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
// set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims);
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
PADDLE_ENFORCE_EQ(
data_dims.size() * 2, paddings->size(),
"Paddings size should be the same or twice as the input data size.");
}
// when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[0];
int pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilation->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -63,6 +154,10 @@ template <typename DeviceContext, typename T>
class GemmConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
// The filter will be reshaped, so it should not be constant pointer
Tensor filter = *context.Input<Tensor>("Filter");
......@@ -72,28 +167,54 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
int groups = context.Attr<int>("groups");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
auto out_dims = output->dims();
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
// input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c/g, k_h, k_w, h, w} or {c/g, k_d, k_h, k_w, d, h, w}
// col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = output->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
if (data_layout == framework::DataLayout::kNCHW) {
col_shape_vec[0] = out_dims[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
}
} else {
col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
}
}
DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w)
// size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
Tensor col;
......@@ -105,15 +226,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
// output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
// output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
DDim output_shape =
framework::slice_ddim(output->dims(), 1, output->dims().size());
// input matrix size: (m, h * w) or (m, d * h * w)
DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
}
// filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
} else {
filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
}
filter.Resize(filter_matrix_shape);
output->mutable_data<T>(context.GetPlace());
......@@ -122,43 +255,84 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
set_zero(dev_ctx, output, static_cast<T>(0));
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
int in_step =
(data_layout == framework::DataLayout::kNCHW
? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int out_step =
(data_layout == framework::DataLayout::kNCHW
? static_cast<int>(out_dims[1]) / groups
: static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
math::Col2VolFunctor<DeviceContext, T> col2vol;
math::ConcatFunctor<DeviceContext, T> concat_functor;
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input)
size_t D = input->dims().size();
for (int i = 0; i < batch_size; i++) {
// batch with size (m, h * w) or (m, d * h * w)
// batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
// batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
// output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
// output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
std::vector<Tensor> output_batch_vec;
for (int g = 0; g < groups; g++) {
Tensor in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
int64_t start = g * in_step;
int64_t end = (g + 1) * in_step;
int axes = (data_layout == framework::DataLayout::kNCHW ? 0 : 1);
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice, out_slice;
// col_matrix = filter_slice * input_slice
// of shape (c/g * k_h * k_w, h * w)
// or (c/g * k_d * k_h * k_w, d * h * w)
blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
// of shape (o_c/g * k_h * k_w, h * w)
// or (o_c/g * k_d * k_h * k_w, d * h * w)
if (data_layout == framework::DataLayout::kNCHW) {
in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
} else {
Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
end, axes);
start = g * out_step;
end = (g + 1) * out_step;
axes = D - 2;
if (D == 4U) {
Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
start, end, axes);
} else if (D == 5U) {
Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
start, end, axes);
}
blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
}
if (data_dim == 2U) {
// col2im: col_matrix -> dy
// from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w)
// from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
// o_c/g)
col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&out_slice);
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&out_slice, data_layout);
} else if (data_dim == 3U) {
// col2vol: col_matrix -> dy
// from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w)
col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice);
// from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
// or (o_d, o_h, o_w, o_c/g)
col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
data_layout);
}
output_batch_vec.push_back(out_slice);
}
if (data_layout == framework::DataLayout::kNHWC) {
concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
&output_batch);
}
}
}
......@@ -168,6 +342,10 @@ template <typename DeviceContext, typename T>
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
......@@ -185,41 +363,84 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
int groups = context.Attr<int>("groups");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
auto out_grad_dims = output_grad->dims();
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
// input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
// filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
// col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = output_grad->dims()[1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
if (data_layout == framework::DataLayout::kNCHW) {
col_shape_vec[0] = out_grad_dims[1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
}
} else {
col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
}
}
DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
// size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
// output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
// output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
output_grad->dims().size());
// input matrix size: (m, h * w) or (m, d * h * w)
DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
}
// filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0] / groups};
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
} else {
filter_matrix_shape = {in_dims[in_dims.size() - 1],
col_matrix_shape[0] / groups};
}
filter.Resize(filter_matrix_shape);
int in_step = static_cast<int>(input->dims()[1]) / groups;
int in_step =
(data_layout == framework::DataLayout::kNCHW
? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
// convolution transpose grad on input:
......@@ -242,75 +463,136 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::ConcatFunctor<DeviceContext, T> concat_functor;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
}
if (filter_grad) { // filter size (m, c/g, k_h, k_w)
if (filter_grad) { // filter_grad_ size (i_c, o_c/g, k_h, k_w)
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
}
size_t D = input->dims().size();
for (int i = 0; i < batch_size; i++) {
// batch with size (c, o_h * o_w)
// batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
// channel_first
// batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
// channel_last
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
if (data_dim == 2U) {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
// from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
// channel_first
// from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
// channel_last
im2col(dev_ctx, output_grad_batch, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
std::vector<int>{paddings[0], paddings[2], paddings[1],
paddings[3]},
&col, data_layout);
} else if (data_dim == 3U) {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
// from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
// i_w) for channel_first
// from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
// k_w) for channel_last
vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
&col);
&col, data_layout);
}
if (input_grad) {
// batch with size (m, h, w)
// batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
Tensor input_grad_batch =
input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
// gemm: dx = filter * dy
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
// (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
// * i_w)
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
// (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
// i_w) -> (i_c,
// i_d, i_h, i_w)
// gemm: dx = dy^T * filter^T for channel_last
std::vector<Tensor> input_grad_batch_vec;
for (int g = 0; g < groups; g++) {
Tensor input_grad_slice =
input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
// input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
// for channel_first
// input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
// for channel_last
// filter_slice: (i_c/g, o_c/g * k_h * k_w)
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
// col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
// k_h * k_w, d * h * w)
Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step);
blas.MatMul(filter_slice, false, col_matrix_slice, false,
static_cast<T>(1.0), &input_grad_slice,
static_cast<T>(0.0));
if (data_layout == framework::DataLayout::kNCHW) {
Tensor input_grad_slice =
input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(filter_slice, false, col_matrix_slice, false,
static_cast<T>(1.0), &input_grad_slice,
static_cast<T>(0.0));
} else {
Tensor input_grad_slice;
Slice<DeviceContext, T, 2>(context, &input_grad_batch,
&input_grad_slice, g * in_step,
(g + 1) * in_step, 1);
blas.MatMul(col_matrix_slice, true, filter_slice, true,
static_cast<T>(1.0), &input_grad_slice,
static_cast<T>(0.0));
DDim input_grad_slice_shape;
if (data_dim == 2U) {
input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
} else {
input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
in_step};
}
input_grad_slice =
input_grad_slice.Resize(input_grad_slice_shape);
input_grad_batch_vec.push_back(input_grad_slice);
}
}
if (data_layout == framework::DataLayout::kNHWC) {
concat_functor(dev_ctx, input_grad_batch_vec,
static_cast<int>(D - 2), &input_grad_batch);
}
}
if (filter_grad) {
// input batch
// input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c)
Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// gemm: d_filter = x * dy^T
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
// (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
// * k_w)
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
// -> (i_c, o_c * k_d *
// k_h * k_w)
// gemm: d_filter = x^T * dy^T for channel_last
for (int g = 0; g < groups; g++) {
Tensor in_batch_slice =
in_batch.Slice(g * in_step, (g + 1) * in_step);
Tensor filter_grad_slice =
filter_grad_.Slice(g * in_step, (g + 1) * in_step);
Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step);
blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
static_cast<T>(1.0), &filter_grad_slice,
static_cast<T>(1.0));
if (data_layout == framework::DataLayout::kNCHW) {
Tensor in_batch_slice =
in_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
static_cast<T>(1.0), &filter_grad_slice,
static_cast<T>(1.0));
} else {
Tensor in_batch_slice;
Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
g * in_step, (g + 1) * in_step, 1);
blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
static_cast<T>(1.0), &filter_grad_slice,
static_cast<T>(1.0));
}
}
}
}
......@@ -322,6 +604,10 @@ template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
......@@ -333,10 +619,27 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
for (auto v : dilations) {
PADDLE_ENFORCE_EQ(v, 1);
}
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
output->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
......@@ -344,8 +647,10 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
dilations, output);
depthwiseConvInputGrad(
dev_ctx, *output, filter, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, output, data_layout);
}
};
......@@ -353,6 +658,10 @@ template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
......@@ -368,11 +677,30 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
if (input_grad) {
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, dilations,
input_grad);
depthwiseConv(
dev_ctx, *output_grad, filter, strides, paddings,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
input_grad, data_layout);
}
if (filter_grad) {
......@@ -382,8 +710,10 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
dilations, filter_grad);
depthwiseConvFilterGrad(
dev_ctx, *output_grad, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, filter_grad, data_layout);
}
}
};
......
......@@ -39,7 +39,8 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
const int filter_multiplier, const int filter_height, \
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, T *const output_data
const int dilate_height, const int dilate_width, T *const output_data, \
const DataLayout data_layout = DataLayout::kNCHW
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
......@@ -58,8 +59,13 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width;
const int in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
int in_offset;
if (data_layout == DataLayout::kNCHW) {
in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
} else {
in_offset = batch * input_height * input_width * input_channels;
}
const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
......@@ -71,7 +77,13 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
w_in < w_end) {
const int offset = in_offset + h_in * input_width + w_in;
int offset;
if (data_layout == DataLayout::kNCHW) {
offset = in_offset + h_in * input_width + w_in;
} else {
offset = in_offset +
(h_in * input_width + w_in) * input_channels + c_in;
}
if (fuse_relu_before_conv) {
value += weight[weight_offset] * max(0.0f, input_data[offset]);
} else {
......@@ -81,9 +93,16 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
weight_offset++;
}
}
int index =
((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
w_out;
int index;
if (data_layout == DataLayout::kNCHW) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width +
w_out;
} else {
index = ((batch * output_height + h_out) * output_width + w_out) *
gridDim.x +
c_out;
}
output_data[index] = value;
}
}
......@@ -111,8 +130,13 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
const int h_in_end = h_in_start + c_filter * dilate_height;
const int w_in_end = w_in_start + c_filter * dilate_width;
const int in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
int in_offset;
if (data_layout == DataLayout::kNCHW) {
in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
} else {
in_offset = batch * input_height * input_width * input_channels;
}
const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
......@@ -125,7 +149,13 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
w_in += dilate_width, w_f++) {
if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
w_in < input_width) {
const int offset = in_offset + h_in * input_width + w_in;
int offset;
if (data_layout == DataLayout::kNCHW) {
offset = in_offset + h_in * input_width + w_in;
} else {
offset = in_offset +
(h_in * input_width + w_in) * input_channels + c_in;
}
if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] *
max(0.0f, input_data[offset]);
......@@ -135,9 +165,16 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
}
}
}
int index =
((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
w_out;
int index;
if (data_layout == DataLayout::kNCHW) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width +
w_out;
} else {
index = ((batch * output_height + h_out) * output_width + w_out) *
gridDim.x +
c_out;
}
output_data[index] = value;
}
}
......@@ -153,14 +190,14 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, output_data);
dilate_width, output_data, data_layout);
else
KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, output_data);
dilate_width, output_data, data_layout);
} else {
if (c_filter == -1)
KernelDepthwiseConv<T, fuse_relu_before_conv>(
......@@ -168,14 +205,14 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
output_data);
output_data, data_layout);
else
KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
output_data);
output_data, data_layout);
}
}
......@@ -190,7 +227,8 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, \
T *const input_grad_data
T *const input_grad_data, \
const DataLayout data_layout = DataLayout::kNCHW
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
......@@ -213,9 +251,17 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
int w_out_end = w_in + padding_width;
T value = 0;
int index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
int index;
if (data_layout == DataLayout::kNCHW) {
index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
} else {
index =
((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
c_in;
}
if (fuse_relu_before_conv) {
if (input_data[index] <= 0) {
input_grad_data[index] = 0;
......@@ -236,11 +282,20 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) {
const int output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
output_width +
s_w_out;
int output_grad_offset;
if (data_layout == DataLayout::kNCHW) {
output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
output_width +
s_w_out;
} else {
output_grad_offset =
((batch * output_height + s_h_out) * output_width +
s_w_out) *
output_channels +
c_out;
}
value += output_grad_data[output_grad_offset] *
filter_data[filter_offset];
}
......@@ -279,9 +334,16 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
T value = 0;
int index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
int index;
if (data_layout == DataLayout::kNCHW) {
index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
} else {
index =
((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
c_in;
}
if (fuse_relu_before_conv) {
if (input_data[index] <= 0) {
input_grad_data[index] = 0;
......@@ -300,11 +362,20 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) {
const int output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
output_width +
s_w_out;
int output_grad_offset;
if (data_layout == DataLayout::kNCHW) {
output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
output_width +
s_w_out;
} else {
output_grad_offset =
((batch * output_height + s_h_out) * output_width +
s_w_out) *
output_channels +
c_out;
}
value +=
output_grad_data[output_grad_offset] *
r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
......@@ -327,14 +398,14 @@ __global__ void KernelDepthwiseConvInputGradSp(
output_height, output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, input_grad_data);
dilate_width, input_grad_data, data_layout);
else if (c_filter == -1)
KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
input_data, output_grad_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
input_grad_data);
input_grad_data, data_layout);
else
KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
fuse_relu_before_conv>(
......@@ -342,7 +413,7 @@ __global__ void KernelDepthwiseConvInputGradSp(
output_height, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
input_grad_data);
input_grad_data, data_layout);
}
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
......@@ -354,7 +425,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* filter_grad_data) {
const int dilate_width, T* filter_grad_data,
const DataLayout data_layout = DataLayout::kNCHW) {
T s = 0;
int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
......@@ -374,18 +446,35 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
if (image_wk < 0 || image_wk >= input_width) continue;
#define gaid(N, C, H, W) \
((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
int input_id = ((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) *
input_height +
image_hk) *
input_width +
image_wk;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
max(0.0f, input_data[input_id]);
#define gaid_nhwc(N, H, W, C) \
((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
int input_id;
if (data_layout == DataLayout::kNCHW) {
input_id = ((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) *
input_height +
image_hk) *
input_width +
image_wk;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
max(0.0f, input_data[input_id]);
} else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id];
}
} else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id];
input_id =
((bid * input_height + image_hk) * input_width + image_wk) *
(gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
max(0.0f, input_data[input_id]);
} else {
s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
input_data[input_id];
}
}
#undef gaid
......@@ -403,21 +492,22 @@ __global__ void KernelDepthwiseConvFilterGradSp(
const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* filter_grad_data) {
const int dilate_width, T* filter_grad_data,
const DataLayout data_layout = DataLayout::kNCHW) {
if (c_filter_multiplier == 0)
KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
output_grad_data, input_data, num, output_channels, output_height,
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, filter_grad_data);
dilate_width, filter_grad_data, data_layout);
else
KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
output_grad_data, input_data, num, output_channels, output_height,
output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, filter_grad_data);
dilate_width, filter_grad_data, data_layout);
}
/*
......@@ -434,15 +524,24 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* output) {
const std::vector<int>& dilations, framework::Tensor* output,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output->dims()[1]
: output->dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output->dims()[2]
: output->dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output->dims()[3]
: output->dims()[2]);
const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3];
const int stride_height = strides[0];
......@@ -478,7 +577,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
output_width, input_channels, input_height, input_width, \
filter_multiplier, ksize_height, ksize_width, stride_height, \
stride_width, padding_height, padding_width, dilate_height, \
dilate_width, output_data); \
dilate_width, output_data, data_layout); \
return; \
}
check_case(1, 1, 3);
......@@ -511,14 +610,24 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* input_grad) {
framework::Tensor* input_grad,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
: output_grad.dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
: output_grad.dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
: output_grad.dims()[2]);
const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3];
const int stride_height = strides[0];
......@@ -556,7 +665,8 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
output_channels, output_height, output_width, input_channels, \
input_height, input_width, filter_multiplier, ksize_height, \
ksize_width, stride_height, stride_width, padding_height, \
padding_width, dilate_height, dilate_width, input_grad_data); \
padding_width, dilate_height, dilate_width, input_grad_data, \
data_layout); \
return; \
}
check_case(1, 1, 3);
......@@ -588,14 +698,24 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* filter_grad) {
framework::Tensor* filter_grad,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
: output_grad.dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
: output_grad.dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
: output_grad.dims()[2]);
const int ksize_height = filter_grad->dims()[2];
const int ksize_width = filter_grad->dims()[3];
const int stride_height = strides[0];
......@@ -629,7 +749,7 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
output_height, output_width, input_channels, input_height, \
input_width, filter_multiplier, ksize_height, ksize_width, \
stride_height, stride_width, padding_height, padding_width, \
dilate_height, dilate_width, filter_grad_data); \
dilate_height, dilate_width, filter_grad_data, data_layout); \
return; \
}
check_case(1);
......
......@@ -22,6 +22,8 @@ namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/*
* \brief Compute the depthwise convolution which include
* forward process and backpropagation process
......@@ -34,7 +36,8 @@ class DepthwiseConvFunctor {
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations, framework::Tensor* output);
const std::vector<int>& dilations, framework::Tensor* output,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
......@@ -47,7 +50,8 @@ class DepthwiseConvInputGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* input_grad);
framework::Tensor* input_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
......@@ -59,7 +63,8 @@ class DepthwiseConvFilterGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* filter_grad);
framework::Tensor* filter_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math
......
......@@ -32,7 +32,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
......@@ -41,16 +42,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
dilation[1] == 1) {
if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
padding[3] == 0) {
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
return;
} else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
padding[3] == 1) {
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
return;
}
// TODO(TJ): complete padding >=2
}
im2col_common<T>(im, dilation, stride, padding, col);
im2col_common<T>(im, dilation, stride, padding, col, data_layout);
}
};
......@@ -67,13 +68,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
......@@ -109,7 +114,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
int im_offset;
if (data_layout == DataLayout::kNCHW) {
im_offset =
(c_im * im_height + im_row_idx) * im_width + im_col_idx;
} else {
im_offset =
(im_row_idx * im_width + im_col_idx) * im_channels + c_im;
}
im_data[im_offset] +=
col_data[(c * col_height + h) * col_width + w];
}
}
......@@ -139,7 +152,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
......@@ -202,7 +216,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
......
......@@ -26,27 +26,41 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
int im_width, int dilation_h, int dilation_w,
int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width,
int col_height, int col_width, T* data_col) {
int col_height, int col_width, T* data_col,
const DataLayout data_layout) {
int input_channels = num_outs / col_height / col_width;
int channels_col = input_channels * filter_height * filter_width;
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) {
int w_out = index % col_width;
int h_out = (index / col_width) % col_height;
int channel_in = index / col_width / col_height;
int w_out = (data_layout == DataLayout::kNCHW
? index % col_width
: (index / input_channels) % col_width);
int h_out = (data_layout == DataLayout::kNCHW
? (index / col_width) % col_height
: (index / input_channels / col_width) % col_height);
int channel_in =
(data_layout == DataLayout::kNCHW ? index / col_width / col_height
: index % input_channels);
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height - padding_height;
int w_in = w_out * stride_width - padding_width;
data_col += (channel_out * col_height + h_out) * col_width + w_out;
data_im += (channel_in * im_height + h_in) * im_width + w_in;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int rIdx = h_in + i * dilation_h;
int cIdx = w_in + j * dilation_w;
int im_idx;
if (data_layout == DataLayout::kNCHW) {
im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
} else {
im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
}
*data_col =
(rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
? 0
: data_im[i * dilation_h * im_width + j * dilation_w];
: data_im[im_idx];
data_col += col_height * col_width;
}
}
......@@ -65,13 +79,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3);
PADDLE_ENFORCE_EQ(col->dims().size(), 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
......@@ -86,7 +105,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
im2col<T><<<grid, threads, 0, context.stream()>>>(
im.data<T>(), num_outputs, im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[1], col_height, col_width, col->data<T>());
padding[0], padding[1], col_height, col_width, col->data<T>(),
data_layout);
}
};
......@@ -95,18 +115,27 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
int dilation_h, int dilation_w, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int col_height,
int col_width, T* data_im) {
int col_width, T* data_im,
const DataLayout data_layout) {
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
int input_channels = n / im_height / im_width;
if (index < n) {
T val = 0;
int w = index % im_width + padding_width;
int h = (index / im_width) % im_height + padding_height;
int c = index / (im_width * im_height);
int w = (data_layout == DataLayout::kNCHW
? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width);
int h = (data_layout == DataLayout::kNCHW
? (index / im_width) % im_height + padding_height
: (index / input_channels / im_width) % im_height +
padding_height);
int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
: index % input_channels);
// compute the start and end of the output
int w_col_start =
......@@ -151,13 +180,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3);
PADDLE_ENFORCE_EQ(col.dims().size(), 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
......@@ -191,7 +225,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0, context.stream()>>>(
num_kernels, col.data<T>(), im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[2], col_height, col_width, im->data<T>());
padding[0], padding[1], col_height, col_width, im->data<T>(),
data_layout);
}
};
......@@ -248,9 +283,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3);
PADDLE_ENFORCE_EQ(col->dims().size(), 5);
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
......@@ -330,9 +368,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3);
PADDLE_ENFORCE_EQ(col.dims().size(), 5);
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......
......@@ -23,6 +23,8 @@ namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 };
......@@ -86,7 +88,8 @@ class Im2ColFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col);
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <ColFormat Format, typename DeviceContext, typename T>
......@@ -95,7 +98,8 @@ class Col2ImFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im);
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math
......
......@@ -30,10 +30,14 @@ inline void im2col_common(const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -50,8 +54,14 @@ inline void im2col_common(const framework::Tensor& im,
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_idx;
if (data_layout == DataLayout::kNCHW) {
im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
} else {
im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
}
int col_idx = (c * output_height + h) * output_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
......@@ -65,11 +75,15 @@ inline void im2col_common(const framework::Tensor& im,
* im2col algorithm with strides == 1, dilations == 1, paddings == 0
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
inline void im2col_sh1sw1dh1dw1ph0pw0(
const framework::Tensor& im, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -89,7 +103,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
const T* src_data = src_data_ic;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
std::memcpy(dst_data, src_data + kw, copy_size);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + kw, copy_size);
} else {
for (int kow = 0; kow < output_width; ++kow) {
dst_data[kow] =
im_data[((oh + kh) * im_width + kw + kow) * im_channels + ic];
}
}
dst_data = dst_data + col_matrix_width;
}
src_data = src_data + im_width;
......@@ -107,10 +128,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
framework::Tensor* col,
const DataLayout data_layout) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -180,7 +205,17 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width;
continue;
}
std::memcpy(dst_data + plw, src_data, copy_size);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data + plw, src_data, copy_size);
} else {
for (int kow = 0; kow < output_width - plw - prw; ++kow) {
dst_data[plw + kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
src_data = src_data + im_width;
}
......@@ -226,19 +261,49 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
// TODO(TJ): reuse plw-kw outside this for
// try to unify
for (int kw = 0; kw < plw; ++kw) {
std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw)));
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw)));
} else {
for (int kow = 0; kow < output_width - (plw - kw); ++kow) {
dst_data[plw - kw + kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
for (int kw = plw; kw < filter_width - prw; ++kw) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width);
} else {
for (int kow = 0; kow < output_width; ++kow) {
dst_data[kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kw - plw + kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
int i = 1;
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i));
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i));
} else {
for (int kow = 0; kow < output_width - i; ++kow) {
dst_data[kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kw - plw + kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
src_data = src_data + im_width;
......
......@@ -32,16 +32,21 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
......@@ -59,6 +64,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
......@@ -97,10 +103,16 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
c_in;
}
col_data[col_idx] =
(h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
......@@ -126,16 +138,21 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......@@ -191,11 +208,17 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx =
((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx =
((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
cIm;
}
int col_idx =
((c * output_depth + d) * output_height + h) * output_width +
w;
......
......@@ -28,7 +28,12 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int filter_width, int stride_depth, int stride_height,
int stride_width, int padding_depth, int padding_height,
int padding_width, int output_detph, int output_height,
int output_width, T* data_col) {
int output_width, T* data_col,
const DataLayout data_layout) {
int input_channels =
num_kernels / output_detph / output_height / output_width;
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % output_width;
......@@ -43,18 +48,22 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
data_col += ((channel_out * output_detph + d_out) * output_height + h_out) *
output_width +
w_out;
data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filter_depth; ++k) {
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int d = d_in + k * dilation_d;
int h = h_in + i * dilation_h;
int w = w_in + j * dilation_w;
int col_idx = (k * dilation_d * height + i * dilation_h) * width +
j * dilation_w;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((channel_in * depth + d) * height + h) * width + w;
} else {
vol_idx =
((d * height + h) * width + w) * input_channels + channel_in;
}
*data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? data_vol[col_idx]
? data_vol[vol_idx]
: 0;
data_col += output_detph * output_height * output_width;
}
......@@ -64,7 +73,10 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
}
/*
* im = [input_channels,intpu_depth, input_height, input_width]
* im = [input_channels,intpu_depth, input_height, input_width] for
* channels_first
* im = [input_depth, input_height, input_width, input_channels] for
* channels_last
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
......@@ -76,15 +88,21 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4);
PADDLE_ENFORCE_EQ(col->dims().size(), 7);
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
......@@ -130,7 +148,8 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, col->data<T>());
pad_w_left, output_depth, output_height, output_width, col->data<T>(),
data_layout);
}
};
......@@ -141,18 +160,27 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int filter_width, int stride_depth, int stride_height,
int stride_width, int padding_depth, int padding_height,
int padding_width, int output_detph, int output_height,
int output_width, T* data_vol) {
int output_width, T* data_vol,
const DataLayout data_layout) {
const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
int input_channels = num_kernels / depth / height / width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
T src_val = 0;
int w = index % width + padding_width;
int h = (index / width) % height + padding_height;
int d = (index / width / height) % depth + padding_depth;
int c = index / width / height / depth;
int w = (data_layout == DataLayout::kNCHW
? index % width + padding_width
: (index / input_channels) % width + padding_width);
int h = (data_layout == DataLayout::kNCHW
? (index / width) % height + padding_height
: (index / input_channels / width) % height + padding_height);
int d = (data_layout == DataLayout::kNCHW
? (index / width / height) % depth + padding_depth
: index / input_channels / width / height + padding_depth);
int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth
: index % input_channels);
// compute the start and end of the output
int w_col_start =
......@@ -196,7 +224,10 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
}
/*
* im = [input_channels, input_depth, input_height, input_width]
* im = [input_channels,intpu_depth, input_height, input_width] for
* channels_first
* im = [input_depth, input_height, input_width, input_channels] for
* channels_last
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
......@@ -208,15 +239,21 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4);
PADDLE_ENFORCE_EQ(col.dims().size(), 7);
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......@@ -263,7 +300,8 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, vol->data<T>());
pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
data_layout);
}
};
......
......@@ -22,6 +22,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/*
* \brief Converts the feature data of four dimensions(CDHW) into a colData of
* seven dimensions in the Vol2ColFunctor calculation,
......@@ -70,8 +73,8 @@ class Vol2ColFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const;
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) const;
};
template <typename DeviceContext, typename T>
......@@ -80,8 +83,8 @@ class Col2VolFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const;
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout = DataLayout::kNCHW) const;
};
} // namespace math
......
......@@ -4424,13 +4424,14 @@ def conv2d_transpose(input,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
name=None,
data_format='NCHW'):
"""
**Convlution2D transpose layer**
The convolution2D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input(Input) and output(Output)
are in NCHW format. Where N is batch size, C is the number of channels,
are in NCHW or NHWC format. Where N is batch size, C is the number of channels,
H is the height of the feature, and W is the width of the feature.
Parameters(dilations, strides, paddings) are two elements. These two elements
represent height and width, respectively. The details of convolution transpose
......@@ -4448,12 +4449,12 @@ def conv2d_transpose(input,
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`W`: Filter value, a tensor with MCHW format.
* :math:`X`: Input value, a 4-D Tensor with NCHW or NHWC format.
* :math:`W`: Filter value, a 4-D Tensor with MCHW format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`b`: Bias value, a 2-D Tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
* :math:`Out`: Output value, a 4-D Tensor with data format 'NCHW' or 'NHWC', the shape of :math:`Out` and :math:`X` may be different.
Example:
......@@ -4471,10 +4472,12 @@ def conv2d_transpose(input,
.. math::
H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
H^\prime_{out} &= (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1 \\\\
H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ] \\\\
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ]
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ]
padding mode is 'SAME' and 'VALID' can reference this link<https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleGAN/network/base_network.py#L181>`_
Note:
if output_size is None, :math:`H_{out} = H^\prime_{out}, W_{out} = W^\prime_{out}`;
......@@ -4484,51 +4487,63 @@ def conv2d_transpose(input,
conv2d_transpose can compute the kernel size automatically.
Args:
input(Variable): The input image with [N, C, H, W] format.
input(Variable): 4-D Tensor with [N, C, H, W] or [N, H, W, C] format,
its data type is float32 or float64.
num_filters(int): The number of the filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
output_size(int|tuple, optional): The output image size. If output size is a
tuple, it must contain two integers, (image_height, image_width). None if use
filter_size, padding, and stride to calculate output_size.
if output_size and filter_size are specified at the same time, They
should follow the formula above.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
should follow the formula above. Default: None.
filter_size(int|tuple, optional): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_height, filter_size_width).
Otherwise, filter_size_height = filter_size_width = filter_size. None if
use output size to calculate filter_size.
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_height, padding_width). Otherwise,
padding_height = padding_width = padding. Default: padding = 0.
stride(int|tuple): The stride size. If stride is a tuple, it must
use output size to calculate filter_size. Default: None.
padding(int|list|str|tuple, optional):The padding size. If `padding` is a
string, either 'VALID' or 'SAME' supported, which is the padding algorithm.
If `padding` is a tuple or list, it could be in three forms:
`[pad_height, pad_width]` or
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and
when `data_format` is `'NCHW'`,
`padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
when `data_format` is `'NHWC'`, `padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
stride(int|tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_height, stride_width). Otherwise,
stride_height = stride_width = stride. Default: stride = 1.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
dilation(int|tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_height, dilation_width). Otherwise,
dilation_height = dilation_width = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d transpose layer. Inspired by
groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str): Activation type, if it is set to None, activation is not appended.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
name(str, optional): A name for this layer(optional). If set None, the layer
will be named automatically. Default: True.
data_format(str, optional): The data format of the input and output data. An optional string
from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'.
Returns:
Variable: The tensor variable storing the convolution transpose result.
Variable: A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or
(num_batches, out_h, out_w, channels).
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
......@@ -4542,8 +4557,12 @@ def conv2d_transpose(input,
conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
"""
assert param_attr is not False, "param_attr should not be False in conv2d_transpose."
input_channel = input.shape[1]
if data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Attr(data_format) of Op(fluid.layers.conv2d_transpose) got wrong value: received "
+ data_format + " but only NCHW or NHWC supported.")
input_channel = input.shape[1] if data_format == 'NCHW' else input.shape[-1]
op_type = 'conv2d_transpose'
if (input_channel == groups and num_filters == input_channel and
not use_cudnn):
......@@ -4553,26 +4572,68 @@ def conv2d_transpose(input,
if not isinstance(input, Variable):
raise TypeError("Input of conv2d_transpose must be Variable")
padding = utils.convert_to_list(padding, 2, 'padding')
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
def _update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 4:
if is_list_or_tuple(padding[0]) and (data_format == "NCHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:4]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"):
if not (padding[0] == [0, 0] and padding[3] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:3]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding')
else:
padding = utils.convert_to_list(padding, 2, 'padding')
padding = [padding[0], padding[0], padding[1], padding[1]]
return padding
padding_algorithm = "EXPLICIT"
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." %
str(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0, 0, 0, 0]
elif padding == "SAME":
padding_algorithm = "SAME"
padding = [0, 0, 0, 0]
padding = _update_padding(padding, data_format)
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
if isinstance(output_size, int):
output_size = [output_size, output_size]
h_in = input.shape[2]
w_in = input.shape[3]
h_in = input.shape[2] if data_format == 'NCHW' else input.shape[1]
w_in = input.shape[3] if data_format == 'NCHW' else input.shape[2]
filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 *
padding[0] - 1) // dilation[0] + 1
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) // dilation[1] + 1
filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + padding[0] +
padding[1] - 1) // dilation[0] + 1
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + padding[2] +
padding[3] - 1) // dilation[1] + 1
filter_size = [filter_size_h, filter_size_w]
else:
filter_size = utils.convert_to_list(filter_size, 2,
......@@ -4584,7 +4645,6 @@ def conv2d_transpose(input,
output_size = utils.convert_to_list(output_size, 2, 'output_size')
else:
raise ValueError("output_size should be list or int")
padding = utils.convert_to_list(padding, 2, 'padding')
groups = 1 if groups is None else groups
filter_shape = [input_channel, num_filters // groups] + filter_size
......@@ -4601,9 +4661,11 @@ def conv2d_transpose(input,
'output_size': output_size,
'strides': stride,
'paddings': padding,
'padding_algorithm': padding_algorithm,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn
'use_cudnn': use_cudnn,
'data_format': data_format
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......@@ -4623,13 +4685,14 @@ def conv3d_transpose(input,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
name=None,
data_format='NCDHW'):
"""
**Convlution3D transpose layer**
The convolution3D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input(Input) and output(Output)
are in NCDHW format. Where N is batch size, C is the number of channels,
are in NCDHW or NDHWC format. Where N is batch size, C is the number of channels,
D is the depth of the feature, H is the height of the feature, and W
is the width of the feature. Parameters(dilations, strides, paddings) are
two elements. These two elements represent height and width, respectively.
......@@ -4647,10 +4710,10 @@ def conv3d_transpose(input,
In the above equation:
* :math:`X`: Input value, a tensor with NCDHW format.
* :math:`W`: Filter value, a tensor with MCDHW format.
* :math:`X`: Input value, a Tensor with NCDHW or NDHWC format.
* :math:`W`: Filter value, a Tensor with MCDHW format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`b`: Bias value, a 2-D Tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
......@@ -4670,55 +4733,68 @@ def conv3d_transpose(input,
.. math::
D_{out} &= (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\\\
H_{out} &= (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\\\
W_{out} &= (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1
D_{out} &= (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\\\
H_{out} &= (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\\\
W_{out} &= (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1
Padding mode is 'SAME' and 'VALID' can reference this
link<https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleGAN/network/base_network.py#L181>`_
Args:
input(Variable): The input image with [N, C, D, H, W] format.
input(Variable): A 5-D Tensor with [N, C, H, W] or [N, H, W, C] format. Its data type is float32 or float64.
num_filters(int): The number of the filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
output_size(int|tuple, optional): The output image size. If output size is a
tuple, it must contain three integers, (image_D, image_H, image_W). This
parameter only works when filter_size is None.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
filter_size(int|tuple, optional): The filter size. If filter_size is a tuple,
it must contain three integers, (filter_size_depth, filter_size_height, \
filter_size_width). Otherwise, filter_size_depth = filter_size_height = \
filter_size_width = filter_size. None if use output size to
calculate filter_size.
padding(int|tuple): The padding size. If padding is a tuple, it must
contain three integers, (padding_depth, padding_height, padding_width). Otherwise,
padding_depth = padding_height = padding_width = padding. Default: padding = 0.
stride(int|tuple): The stride size. If stride is a tuple, it must
padding(int|list|str|tuple, optional): The padding size. if `padding` is a string,
either 'VALID' or 'SAME' supported, which is the padding algorithm. If `padding`
is a tuple or list, it could be in three forms: `[pad_depth, pad_height, pad_width]` or
`[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
and when `data_format` is `'NCDHW'`, `padding` can be in the form
`[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
when `data_format` is `'NDHWC'`, `padding` can be in the form
`[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
stride(int|tuple, optional): The stride size. If stride is a tuple, it must
contain three integers, (stride_depth, stride_height, stride_width). Otherwise,
stride_depth = stride_height = stride_width = stride. Default: stride = 1.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
dilation(int|tuple, optional): The dilation size. If dilation is a tuple, it must
contain three integers, (dilation_depth, dilation_height, dilation_width). Otherwise,
dilation_depth = dilation_height = dilation_width = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv3d transpose layer. Inspired by
groups(int, optional): The groups number of the Conv3d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups=1
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv3d_transpose. If it is set to None or one attribute of ParamAttr, conv3d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv3d_transpose.
bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv3d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv3d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
name(str, optional): A name for this layer(optional). If set None, the layer
will be named automatically.
data_format(str, optional):The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`.
Default: 'NCDHW'.
Returns:
Variable: The tensor variable storing the convolution transpose result.
A 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or
(num_batches, out_d, out_h, out_w, channels).
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
......@@ -4732,35 +4808,89 @@ def conv3d_transpose(input,
conv3d_transpose = fluid.layers.conv3d_transpose(input=data, num_filters=2, filter_size=3)
"""
assert param_attr is not False, "param_attr should not be False in conv3d_transpose."
if data_format not in ['NCDHW', 'NDHWC']:
raise ValueError(
"Param(data_format) of Op(fluid.layers.conv3d_transpose) got wrong value: received "
+ data_format + " but only NCDHW or NDHWC supported.")
l_type = "conv3d_transpose"
helper = LayerHelper(l_type, **locals())
if not isinstance(input, Variable):
raise TypeError("Input of conv3d_transpose must be Variable")
input_channel = input.shape[1]
input_channel = input.shape[1] if data_format == 'NCDHW' else input.shape[
-1]
padding = utils.convert_to_list(padding, 3, 'padding')
stride = utils.convert_to_list(stride, 3, 'stride')
dilation = utils.convert_to_list(dilation, 3, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
def _update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 5:
if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:5]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"):
if not (padding[0] == [0, 0] and padding[4] == [0, 0]):
raise ValueError(
"Non-zero padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding')
elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding')
else:
padding = utils.convert_to_list(padding, 3, 'padding')
padding = [
padding[0], padding[0], padding[1], padding[1], padding[2],
padding[2]
]
return padding
padding_algorithm = "EXPLICIT"
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." %
str(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0, 0, 0, 0, 0, 0]
elif padding == "SAME":
padding_algorithm = "SAME"
padding = [0, 0, 0, 0, 0, 0]
padding = _update_padding(padding, data_format)
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
if isinstance(output_size, int):
output_size = [output_size, output_size]
d_in = input.shape[2]
h_in = input.shape[3]
w_in = input.shape[4]
d_in = input.shape[2] if data_format == 'NCDHW' else input.shape[1]
h_in = input.shape[3] if data_format == 'NCDHW' else input.shape[2]
w_in = input.shape[4] if data_format == 'NCDHW' else input.shape[3]
filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + 2 *
padding[0] - 1) // dilation[0] + 1
filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + 2 *
padding[1] - 1) // dilation[1] + 1
filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + 2 *
padding[2] - 1) // dilation[2] + 1
filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + padding[0] +
padding[1] - 1) // dilation[0] + 1
filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + padding[2] +
padding[3] - 1) // dilation[1] + 1
filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + padding[4] +
padding[5] - 1) // dilation[2] + 1
filter_size = [filter_size_d, filter_size_h, filter_size_w]
else:
filter_size = utils.convert_to_list(filter_size, 3,
......@@ -4771,6 +4901,11 @@ def conv3d_transpose(input,
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
if data_format == 'NCDHW':
data_format = 'NCHW'
if data_format == 'NDHWC':
data_format = 'NHWC'
pre_bias = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type=l_type,
......@@ -4780,9 +4915,11 @@ def conv3d_transpose(input,
attrs={
'strides': stride,
'paddings': padding,
'padding_algorithm': padding_algorithm,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn
'use_cudnn': use_cudnn,
'data_format': data_format
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......
......@@ -18,10 +18,19 @@ import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if attrs['data_format'] == 'NHWC':
input_ = np.transpose(input_, [0, 3, 1, 2])
in_n, in_c, in_h, in_w = input_.shape
f_c, f_out_c, f_h, f_w = filter_.shape
groups = attrs['groups']
......@@ -31,14 +40,47 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
# update pad and dilation
def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
padding = []
for input_size, filter_size, stride_size in zip(
input_shape, kernel_size, kernel_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
ksize = filter_.shape[2:4]
if padding_algorithm == "VALID":
pad = [0, 0, 0, 0]
elif padding_algorithm == "SAME":
dilation = [1, 1]
input_data_shape = []
if attrs['data_format'] == "NCHW":
input_data_shape = input_.shape[2:4]
elif attrs['data_format'] == "NHWC":
input_data_shape = input_.shape[1:3]
pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
pad_h_0, pad_h_1 = pad[0], pad[0]
pad_w_0, pad_w_1 = pad[1], pad[1]
if len(pad) == 4:
pad_h_0, pad_h_1 = pad[0], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[3]
d_bolck_h = dilations[0] * (f_h - 1) + 1
d_bolck_w = dilations[1] * (f_w - 1) + 1
out_h = (in_h - 1) * stride[0] + d_bolck_h
out_w = (in_w - 1) * stride[1] + d_bolck_w
if 'output_size' in attrs:
output_size = attrs['output_size']
out_h = output_size[0] + 2 * pad[0]
out_w = output_size[1] + 2 * pad[1]
out_h = output_size[0] + pad_h_0 + pad_h_1
out_w = output_size[1] + pad_w_0 + pad_w_1
out = np.zeros((in_n, out_c, out_h, out_w))
......@@ -61,7 +103,9 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
out[n, g * f_out_c + k, i1:i2:dilations[0], j1:j2:
dilations[1]] += tmp_out
out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]]
out = out[:, :, pad_h_0:out_h - pad_h_1, pad_w_0:out_w - pad_w_1]
if attrs['data_format'] == 'NHWC':
out = np.transpose(out, [0, 2, 3, 1])
return out
......@@ -72,7 +116,9 @@ class TestConv2dTransposeOp(OpTest):
self.use_cudnn = False
self.use_mkldnn = False
self.output_size = None
self.data_format = "AnyLayout"
self.data_format = "NCHW"
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_test_case()
......@@ -83,6 +129,7 @@ class TestConv2dTransposeOp(OpTest):
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
......@@ -160,7 +207,7 @@ class TestConv2dTransposeOp(OpTest):
self.op_type = "conv2d_transpose"
class TestWithPad(TestConv2dTransposeOp):
class TestWithSymmetricPad(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
......@@ -171,6 +218,39 @@ class TestWithPad(TestConv2dTransposeOp):
self.filter_size = [f_c, 6, 3, 3]
class TestWithAsymmetricPad(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
class TestWithSAMEPad(TestConv2dTransposeOp):
def init_test_case(self):
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.padding_algorithm = 'SAME'
class TestWithVALIDPad(TestConv2dTransposeOp):
def init_test_case(self):
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.padding_algorithm = 'VALID'
class TestWithGroups(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
......@@ -216,6 +296,91 @@ class TestWithEvenUpsample(TestConv2dTransposeOp):
self.filter_size = [f_c, 6, 5, 5]
class Test_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithSymmetricPad_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithAsymmetricPad_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithGroups_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithStride_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithDilation_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [2, 2]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
class TestWithEvenUpsample_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
# ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
......@@ -227,7 +392,7 @@ class TestCUDNN(TestConv2dTransposeOp):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad):
class TestCUDNNWithSymmetricPad(TestWithSymmetricPad):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
......@@ -242,6 +407,57 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithAsymmetricPad(TestWithAsymmetricPad):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithSAMEPad(TestWithSAMEPad):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithVALIDPad(TestWithVALIDPad):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride):
......@@ -276,19 +492,6 @@ class TestCUDNNWithGroups(TestWithGroups):
self.op_type = "conv2d_transpose"
class TestDepthwiseConvTranspose(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 8, 16, 16] # NCHW
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 4, 4]
self.op_type = "depthwise_conv2d_transpose"
# ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
......@@ -312,5 +515,334 @@ class TestCUDNNWithEvenUpsample(TestWithEvenUpsample):
# def init_op_type(self):
# self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithSymmetricPad_NHWC(TestWithSymmetricPad):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithAsymmetricPad_NHWC(TestWithSymmetricPad):
def init_test_case(self):
self.pad = [1, 0, 2, 3]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride_NHWC(TestWithStride):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.input_size = [2, 5, 5, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups_NHWC(TestWithGroups):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithEvenUpsample_NHWC(TestWithEvenUpsample):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [2, 2]
self.groups = 1
self.dilations = [1, 1]
self.output_size = [14, 14]
self.input_size = [2, 7, 7, 3] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 5, 5]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d_transpose"
class TestDepthwiseConvTranspose(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 8, 16, 16] # NCHW
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 4, 4]
self.op_type = "depthwise_conv2d_transpose"
class TestDepthwiseConvTransposeAsymmetricPad(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 8, 16, 16] # NCHW
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 3, 3]
self.op_type = "depthwise_conv2d_transpose"
self.data_format = 'NCHW'
class TestDepthwiseConvTransposeSAMEPad(TestConv2dTransposeOp):
def init_test_case(self):
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 8, 16, 16] # NHWC
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 3, 3]
self.op_type = "depthwise_conv2d_transpose"
self.padding_algorithm = 'SAME'
class TestDepthwiseConvTransposeVALIDPad(TestConv2dTransposeOp):
def init_test_case(self):
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 8, 16, 16] # NHWC
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 3, 3]
self.op_type = "depthwise_conv2d_transpose"
self.padding_algorithm = 'VALID'
class TestDepthwiseConvTranspose_NHWC_4x4kernel(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 16, 16, 8] # NHWC
self.groups = 8
assert np.mod(self.input_size[3], self.groups) == 0
f_c = self.input_size[3] // self.groups
self.filter_size = [self.input_size[3], f_c, 4, 4]
self.op_type = "depthwise_conv2d_transpose"
self.data_format = 'NHWC'
class TestDepthwiseConvTranspose_NHWC_3x3kernel(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 16, 16, 8] # NHWC
self.groups = 8
assert np.mod(self.input_size[3], self.groups) == 0
f_c = self.input_size[3] // self.groups
self.filter_size = [self.input_size[3], f_c, 3, 3]
self.op_type = "depthwise_conv2d_transpose"
self.data_format = 'NHWC'
class TestDepthwiseConvTransposeAsymmetricPad_NHWC(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 2]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 16, 16, 8] # NHWC
self.groups = 8
assert np.mod(self.input_size[3], self.groups) == 0
f_c = self.input_size[3] // self.groups
self.filter_size = [self.input_size[3], f_c, 3, 3]
self.op_type = "depthwise_conv2d_transpose"
self.data_format = 'NHWC'
class TestConv2dTransposeAPI(OpTest):
def test_case1(self):
data1 = fluid.layers.data(
name='data1', shape=[3, 5, 5], dtype='float32')
data2 = fluid.layers.data(
name='data2', shape=[5, 5, 3], dtype='float32')
out1 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
data_format='NCHW')
out2 = fluid.layers.conv2d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
data_format='NHWC')
out3 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
data_format='NHWC')
out4 = fluid.layers.conv2d_transpose(
input=data1,
groups=3,
num_filters=6,
filter_size=3,
padding=[[0, 0], [0, 0], [2, 1], [0, 0]],
data_format='NCHW')
out5 = fluid.layers.conv2d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
padding='SAME',
data_format='NCHW')
out6 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding='VALID',
data_format='NHWC')
out7 = fluid.layers.conv2d_transpose(
input=data1,
groups=1,
num_filters=6,
output_size=[7, 7],
padding=[0, 0],
data_format='NHWC')
data1_np = np.random.random((2, 3, 5, 5)).astype("float32")
data2_np = np.random.random((2, 5, 5, 3)).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(
fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2, out3, out4, out5, out6, out7],
return_numpy=True)
self.assertIsNotNone(results[0])
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[2])
self.assertIsNotNone(results[3])
self.assertIsNotNone(results[4])
self.assertIsNotNone(results[5])
self.assertIsNotNone(results[6])
class TestConv2dTransposeOpException(OpTest):
def test_exception(self):
data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32")
def attr_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
data_format="NCDHW")
self.assertRaises(ValueError, attr_data_format)
def attr_padding_str():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding='Vald')
self.assertRaises(ValueError, attr_padding_str)
def attr_padding_list():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [1, 1], [0, 0], [0, 0]])
self.assertRaises(ValueError, attr_padding_list)
def attr_padding_with_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [0, 0], [0, 0], [1, 1]],
data_format='NHWC')
self.assertRaises(ValueError, attr_padding_with_data_format)
if __name__ == '__main__':
unittest.main()
......@@ -18,10 +18,19 @@ import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
def conv3dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if attrs['data_format'] == 'NHWC':
input_ = np.transpose(input_, [0, 4, 1, 2, 3])
in_n, in_c, in_d, in_h, in_w = input_.shape
f_c, f_out_c, f_d, f_h, f_w = filter_.shape
groups = attrs['groups']
......@@ -32,6 +41,39 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
padding = []
for input_size, filter_size, stride_size in zip(
input_shape, kernel_size, kernel_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
ksize = filter_.shape[2:5]
if padding_algorithm == "VALID":
pad = [0, 0, 0, 0, 0, 0]
elif padding_algorithm == "SAME":
dilation = [1, 1, 1]
input_data_shape = []
if attrs['data_format'] == "NCHW":
input_data_shape = input_.shape[2:5]
elif attrs['data_format'] == "NHWC":
input_data_shape = input_.shape[1:4]
pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
pad_d_0, pad_d_1 = pad[0], pad[0]
pad_h_0, pad_h_1 = pad[1], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[2]
if len(pad) == 6:
pad_d_0, pad_d_1 = pad[0], pad[1]
pad_h_0, pad_h_1 = pad[2], pad[3]
pad_w_0, pad_w_1 = pad[4], pad[5]
d_bolck_d = dilations[0] * (f_d - 1) + 1
d_bolck_h = dilations[1] * (f_h - 1) + 1
d_bolck_w = dilations[2] * (f_w - 1) + 1
......@@ -62,8 +104,10 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
out[n, g * f_out_c + k, d1:d2:dilations[0], i1:i2:
dilations[1], j1:j2:dilations[2]] += tmp_out
out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w -
pad[2]]
out = out[:, :, pad_d_0:out_d - pad_d_1, pad_h_0:out_h - pad_h_1, pad_w_0:
out_w - pad_w_1]
if attrs['data_format'] == 'NHWC':
out = np.transpose(out, [0, 2, 3, 4, 1])
return out
......@@ -71,6 +115,9 @@ class TestConv3dTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.use_cudnn = False
self.data_format = 'NCHW'
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_test_case()
......@@ -81,10 +128,11 @@ class TestConv3dTransposeOp(OpTest):
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'dilations': self.dilations,
'groups': self.groups,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
'data_format': self.data_format
}
output = conv3dtranspose_forward_naive(input_, filter_,
......@@ -154,7 +202,7 @@ class TestConv3dTransposeOp(OpTest):
self.op_type = "conv3d_transpose"
class TestWithPad(TestConv3dTransposeOp):
class TestWithSymmetricPad(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
......@@ -165,6 +213,39 @@ class TestWithPad(TestConv3dTransposeOp):
self.filter_size = [f_c, 6, 3, 3, 3]
class TestWithAsymmetricPad(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 0, 1, 2]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
class TestWithSAMEPad(TestConv3dTransposeOp):
def init_test_case(self):
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.padding_algorithm = 'SAME'
class TestWithVALIDPad(TestConv3dTransposeOp):
def init_test_case(self):
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.padding_algorithm = 'VALID'
class TestWithGroups(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
......@@ -198,6 +279,78 @@ class TestWithDilation(TestConv3dTransposeOp):
self.filter_size = [f_c, 6, 3, 3, 3]
class Test_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithSymmetricPad_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithAsymmetricPad_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 0, 1, 0, 1, 2]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithGroups_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 5, 4] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithStride_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [2, 2, 2]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NCDHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
class TestWithDilation_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [2, 2, 2]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NCDHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
# ------------ test_cudnn ------------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
......@@ -209,7 +362,7 @@ class TestCUDNN(TestConv3dTransposeOp):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithPad(TestWithPad):
class TestCUDNNWithSymmetricPad(TestWithSymmetricPad):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
......@@ -224,6 +377,57 @@ class TestCUDNNWithPad(TestWithPad):
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithAsymmetricPad(TestWithAsymmetricPad):
def init_test_case(self):
self.pad = [1, 1, 1, 0, 0, 2]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 4, 4, 4] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithSAMEPad(TestWithSAMEPad):
def init_test_case(self):
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.padding_algorithm = 'SAME'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithVALIDPad(TestWithVALIDPad):
def init_test_case(self):
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.padding_algorithm = 'VALID'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride(TestWithStride):
......@@ -272,5 +476,222 @@ class TestCUDNNWithGroups(TestWithGroups):
# def init_op_type(self):
# self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNN_NHWC(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithSymmetricPad_NHWC(TestWithSymmetricPad):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithAsymmetricPad_NHWC(TestWithAsymmetricPad):
def init_test_case(self):
self.pad = [1, 0, 1, 0, 0, 2]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NDHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithStride_NHWC(TestWithStride):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [2, 2, 2]
self.dilations = [1, 1, 1]
self.groups = 1
self.input_size = [2, 5, 5, 5, 3] # NCDHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 6, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNWithGroups_NHWC(TestWithGroups):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 5, 4] # NCHW
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3, 3]
self.data_format = 'NHWC'
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d_transpose"
class TestConv3dTransposeAPI(OpTest):
def test_case1(self):
data1 = fluid.layers.data(
name='data1', shape=[3, 5, 5, 5], dtype='float32')
data2 = fluid.layers.data(
name='data2', shape=[5, 5, 5, 3], dtype='float32')
out1 = fluid.layers.conv3d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
data_format='NCDHW')
out2 = fluid.layers.conv3d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
data_format='NDHWC')
out3 = fluid.layers.conv3d_transpose(
input=data1,
groups=1,
num_filters=6,
filter_size=3,
padding=[[0, 0], [0, 0], [1, 1], [0, 0], [1, 1]],
data_format='NCDHW')
out4 = fluid.layers.conv3d_transpose(
input=data2,
groups=3,
num_filters=6,
filter_size=3,
padding=[[0, 0], [0, 0], [1, 1], [1, 2], [0, 0]],
data_format='NDHWC')
out5 = fluid.layers.conv3d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
padding='SAME',
data_format='NCDHW')
out6 = fluid.layers.conv3d_transpose(
input=data2,
groups=1,
num_filters=6,
filter_size=3,
padding='VALID',
data_format='NDHWC')
out7 = fluid.layers.conv3d_transpose(
input=data2,
groups=1,
num_filters=6,
output_size=[7, 7, 7],
padding=[0, 0, 0],
data_format='NDHWC')
data1_np = np.random.random((2, 3, 5, 5, 5)).astype("float32")
data2_np = np.random.random((2, 5, 5, 5, 3)).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(
fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2, out3, out4, out5, out6, out7],
return_numpy=True)
self.assertIsNotNone(results[0])
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[2])
self.assertIsNotNone(results[3])
self.assertIsNotNone(results[4])
self.assertIsNotNone(results[5])
self.assertIsNotNone(results[6])
class TestConv3dTransposeOpException(OpTest):
def test_exception(self):
data = fluid.layers.data(
name='data', shape=[3, 5, 5, 5], dtype="float32")
def attr_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
data_format="NCDW")
self.assertRaises(ValueError, attr_data_format)
def attr_padding_str():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding='Vald')
self.assertRaises(ValueError, attr_padding_str)
def attr_padding_list():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [1, 1], [0, 0], [0, 0], [1, 1]])
self.assertRaises(ValueError, attr_padding_list)
def attr_padding_with_data_format():
out = fluid.layers.conv2d_transpose(
input=data,
groups=1,
num_filters=6,
filter_size=3,
padding=[[1, 1], [0, 0], [0, 0], [1, 0], [1, 1]],
data_format='NDHWC')
self.assertRaises(ValueError, attr_padding_with_data_format)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册