/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
C
chengduoZH 已提交
26 27
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
28
  void Make() override;
C
chengduoZH 已提交
29 30
};

C
chengduoZH 已提交
31 32
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
33
  void Make() override;
C
chengduoZH 已提交
34 35
};

C
chengduoZH 已提交
36
class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
37 38
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
39 40 41 42

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
43 44

  framework::OpKernelType GetKernelTypeForVar(
F
From00 已提交
45
      const std::string& var_name, const framework::Tensor& tensor,
46
      const framework::OpKernelType& expected_kernel_type) const override;
C
chengduoZH 已提交
47 48
};

C
chengduoZH 已提交
49
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
50 51
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
52 53 54 55

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
56 57
};

58 59 60 61 62 63 64 65 66
class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

C
chengduoZH 已提交
67 68
}  // namespace operators
}  // namespace paddle