提交 aa770198 编写于 作者: C chengduoZH

add dilation in c++ code

上级 23e38216
...@@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor."); "ConvTransposeOp intput should be 4-D or 5-D tensor.");
...@@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"ConvTransposeOp paddings dimension and strides " "ConvTransposeOp paddings dimension and strides "
"dimension should be the same."); "dimension should be the same.");
PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"In ConvTransposeOp, The input channel should be the same " "In ConvTransposeOp, The input channel should be the same "
"as the number of filters."); "as the number of filters.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
filter_dims[i + 2]); filter_extent);
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
......
...@@ -63,6 +63,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -63,6 +63,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
// groups will alway be disabled in conv2dtranspose. // groups will alway be disabled in conv2dtranspose.
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -114,7 +115,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -114,7 +115,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<Place, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input) // on input)
...@@ -134,8 +134,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -134,8 +134,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
if (data_dim == 2U) { if (data_dim == 2U) {
// col2im: col_matrix -> dy // col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w) // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im(context.device_context(), col, col2im(context.device_context(), col, dilations, strides,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]}, paddings[1]},
&output_batch); &output_batch);
...@@ -168,6 +167,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -168,6 +167,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -221,7 +221,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -221,7 +221,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
...@@ -242,10 +241,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -242,10 +241,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (data_dim == 2U) { if (data_dim == 2U) {
// im2col: dy -> col matrix // im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w) // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch, im2col(context.device_context(), output_grad_batch, dilations,
std::vector<int>{dilations[0], dilations[1]}, strides, strides, std::vector<int>{paddings[0], paddings[1],
std::vector<int>{paddings[0], paddings[1], paddings[0], paddings[0], paddings[1]},
paddings[1]},
&col); &col);
} else if (data_dim == 3U) { } else if (data_dim == 3U) {
// vol2col: dy -> col_matrix // vol2col: dy -> col_matrix
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册