提交 aa770198 编写于 作者: C chengduoZH

add dilation in c++ code

上级 23e38216
......@@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
......@@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"In ConvTransposeOp, The input channel should be the same "
"as the number of filters.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
filter_dims[i + 2]);
filter_extent);
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......
......@@ -63,6 +63,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
// groups will alway be disabled in conv2dtranspose.
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -114,7 +115,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<Place, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input)
......@@ -134,8 +134,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
if (data_dim == 2U) {
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im(context.device_context(), col,
std::vector<int>{dilations[0], dilations[1]}, strides,
col2im(context.device_context(), col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&output_batch);
......@@ -168,6 +167,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -221,7 +221,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......@@ -242,10 +241,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (data_dim == 2U) {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
im2col(context.device_context(), output_grad_batch, dilations,
strides, std::vector<int>{paddings[0], paddings[1],
paddings[0], paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col: dy -> col_matrix
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册