unpool_op.cc 6.3 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
S
sweetsky0901 已提交
14

15
#include <memory>
16 17
#include <string>
#include <vector>
X
xiaoting 已提交
18 19 20 21 22 23

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"

S
sweetsky0901 已提交
24 25 26
namespace paddle {
namespace operators {

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Declares the proto (inputs, outputs, attributes and documentation) of the
// 3-D unpooling operator.
//
// Inputs:  "X", "Indices"
// Output:  "Out"
// Attrs:   "ksize", "strides", "paddings", "unpooling_type", "output_size",
//          "data_format"
//
// Adjacent string literals below are concatenated by the compiler, so every
// fragment that ends a sentence carries its own trailing space.
class Unpool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor of unpool operator. "
        "The format of input tensor is NCDHW. Where N is batch size, C is the "
        "number of channels, D, H and W is the depth, height and width of "
        "feature.");
    AddInput(
        "Indices",
        "(Tensor) The input tensor of the indices given out by MaxPool3d. "
        "The format of input tensor is NCDHW. Where N is batch size, C is the "
        "number of channels, D, H and W is the depth, height and width of "
        "feature.");
    AddOutput("Out",
              "(Tensor) The output tensor of unpool operator. "
              "The format of output tensor is also NCDHW. "
              "Where N is batch size, C is "
              "the number of channels, D, H and W is the depth, height and "
              "width of feature.");
    // No default is set for ksize.
    AddAttr<std::vector<int>>(
        "ksize",
        "(vector), the unpooling window size(depth, height, width) "
        "of unpooling operator.");
    AddAttr<std::vector<int>>(
        "strides",
        "(vector, default:{1, 1, 1}), "
        "strides (depth, height, width) of unpooling operator.")
        .SetDefault({1, 1, 1});
    AddAttr<std::vector<int>>(
        "paddings",
        "(vector default:{0, 0,0}), "
        "paddings (depth, height, width) of unpooling operator.")
        .SetDefault({0, 0, 0});
    // "max" is the only value accepted by the enum check.
    AddAttr<std::string>(
        "unpooling_type",
        "(string), unpooling type, can be \"max\" for max-unpooling ")
        .InEnum({"max"});
    AddAttr<std::vector<int>>("output_size",
                              "(vector, optional). The shape of output.")
        .SetDefault({0, 0, 0});
    AddAttr<std::string>(
        "data_format",
        "(string, default NCDHW) "
        "Defaults to \"NCDHW\". Specify the data format of the output data.")
        .SetDefault("NCDHW");
    AddComment(R"DOC(
Input shape is: $(N, C_{in}, D_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, D_{out}, H_{out}, W_{out})$, where
$$
D_{out} = (D_{in}-1) * strides[0] - 2 * paddings[0] + ksize[0] \\
H_{out} = (H_{in}-1) * strides[1] - 2 * paddings[1] + ksize[1] \\
W_{out} = (W_{in}-1) * strides[2] - 2 * paddings[2] + ksize[2]
$$
)DOC");
  }
};

Y
Yang Yang 已提交
86
// Computes the output extent of an unpooling operation along one spatial
// dimension:
//
//   (input_size - 1) * stride - 2 * padding + ksize
//
// This is the same per-dimension expression given in the unpool3d operator
// comment for D_out / H_out / W_out.
//
// @param input_size  extent of the pooled input along the dimension
// @param ksize       unpooling window size along the dimension
// @param padding     padding along the dimension
// @param stride      stride along the dimension
// @return            the computed output extent (no range checking is done)
int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) {
  return (input_size - 1) * stride - 2 * padding + ksize;
}

class UnpoolOp : public framework::OperatorWithKernel {
S
sweetsky0901 已提交
92
 protected:
93
  phi::KernelKey GetExpectedKernelType(
S
sweetsky0901 已提交
94
      const framework::ExecutionContext& ctx) const override {
95 96
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
S
sweetsky0901 已提交
97
  }
S
sweetsky0901 已提交
98

S
sweetsky0901 已提交
99 100
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
S
sweetsky0901 已提交
101 102
};

103 104
// Forward operator class registered as `unpool3d`. It inherits the
// OperatorWithKernel constructors and customises only kernel-key selection.
class Unpool3dOp : public framework::OperatorWithKernel {
 protected:
  // Builds the kernel key from the data type of input "X" and the place of
  // the execution context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
  }

 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};

115 116 117 118
// Describes the backward op for an unpool forward op.
//
// The grad op type is the forward op type with "_grad" appended. It receives
// the forward inputs "X" and "Indices", the forward output "Out" and the
// gradient of "Out"; it produces the gradient of "X". All forward attributes
// are copied over unchanged.
template <typename T>
class UnpoolOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Indices", this->Input("Indices"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
// Grad-op maker used when registering `unpool3d`.
//
// Apply() fills in the backward op description: its type is the forward type
// plus "_grad"; it is wired to the forward "X", "Indices" and "Out" variables
// and to the gradient of "Out", and it emits the gradient of "X". Forward
// attributes are forwarded as-is.
template <typename T>
class Unpool3dOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    // Variable names of the gradients flowing in (Out) and out (X).
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    op->SetType(this->ForwardOpType() + "_grad");

    // Forward tensors needed by the backward computation.
    op->SetInput("X", this->Input("X"));
    op->SetInput("Indices", this->Input("Indices"));
    op->SetInput("Out", this->Output("Out"));

    // Incoming and outgoing gradients.
    op->SetInput(out_grad_name, this->OutputGrad("Out"));
    op->SetOutput(x_grad_name, this->InputGrad("X"));

    op->SetAttrMap(this->Attrs());
  }
};

// Backward operator class registered as `unpool3d_grad`. It inherits the
// OperatorWithKernel constructors and customises only kernel-key selection.
class Unpool3dOpGrad : public framework::OperatorWithKernel {
 protected:
  // Builds the kernel key from the data type of input "X" and the place of
  // the execution context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
  }

 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};

S
sweetsky0901 已提交
157 158
}  // namespace operators
}  // namespace paddle
S
sweetsky0901 已提交
159 160

// Shorthand for paddle::operators, used by the registration macros below.
namespace ops = paddle::operators;
X
xiaoting 已提交
161 162 163 164

// Declares Unpool3dInferShapeFunctor, the shape-inference functor passed to
// REGISTER_OPERATOR for `unpool3d`; it wraps phi::Unpool3dInferMeta.
// The op-type argument names the operator the functor is registered with
// (`unpool3d`), in the same way the grad declaration names `unpool3d_grad`.
DECLARE_INFER_SHAPE_FUNCTOR(unpool3d,
                            Unpool3dInferShapeFunctor,
                            PD_INFER_META(phi::Unpool3dInferMeta));
165

166 167 168
// Registers the forward `unpool3d` operator with its operator class, proto
// maker, grad-op makers (one instantiated for framework::OpDesc and one for
// imperative::OpBase) and shape-inference functor.
REGISTER_OPERATOR(unpool3d,
                  ops::Unpool3dOp,
                  ops::Unpool3dOpMaker,
                  ops::Unpool3dOpGradMaker<paddle::framework::OpDesc>,
                  ops::Unpool3dOpGradMaker<paddle::imperative::OpBase>,
                  Unpool3dInferShapeFunctor);

// Declares Unpool3dGradInferShapeFunctor for op type `unpool3d_grad`; it
// wraps phi::UnchangedInferMeta.
// NOTE(review): presumably this gives the gradient of X the same meta as X;
// confirm against phi::UnchangedInferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(unpool3d_grad,
                            Unpool3dGradInferShapeFunctor,
                            PD_INFER_META(phi::UnchangedInferMeta));
176

X
xiaoting 已提交
177 178 179
// Registers the backward `unpool3d_grad` operator with its operator class and
// shape-inference functor. No proto maker or grad-op maker is supplied here.
REGISTER_OPERATOR(unpool3d_grad,
                  ops::Unpool3dOpGrad,
                  Unpool3dGradInferShapeFunctor);