/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"

#include <memory>
#include <string>

namespace paddle {
namespace operators {

class ShuffleChannelOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
25 26
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
S
shippingwang 已提交
27 28

    auto input_dims = ctx->GetInputDim("X");
29
    PADDLE_ENFORCE_EQ(
30 31
        input_dims.size(),
        4,
32
        platform::errors::InvalidArgument("The layout of input is NCHW."));
S
shippingwang 已提交
33 34 35

    ctx->SetOutputDim("Out", input_dims);
  }
S
shippingwang 已提交
36 37 38 39

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
40 41 42 43 44
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
45 46
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
47 48 49 50 51
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
S
shippingwang 已提交
52
  }
S
shippingwang 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66
};

class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the op's inputs, outputs, attributes, and user-facing docs.
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), "
             "the input feature data of ShuffleChannelOp, the layout is NCHW.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the output of "
              "ShuffleChannelOp. The layout is NCHW.");
    AddAttr<int>("group", "the number of groups.")
        .SetDefault(1)
        .AddCustomChecker([](const int& group) {
          // The channel dimension is split into `group` groups, so at least
          // one group is required.
          PADDLE_ENFORCE_GE(group,
                            1,
                            platform::errors::InvalidArgument(
                                "group should be larger than 0, but the "
                                "received group is %d.",
                                group));
        });
    AddComment(R"DOC(
		Shuffle Channel operator

		This operator shuffles the channels of input x.
		It divides the input channels in each group into several subgroups,
		and obtains a new order by selecting elements from every subgroup one by one.

		Shuffle channel operation makes it possible to build more powerful structures
		with multiple group convolutional layers.
		please get more information from the following paper:
		https://arxiv.org/pdf/1707.01083.pdf
        )DOC");
  }
};

class ShuffleChannelGradOp : public framework::OperatorWithKernel {
S
shippingwang 已提交
87 88 89 90
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
91
    auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
92
    PADDLE_ENFORCE_EQ(
93 94
        input_dims.size(),
        4,
95
        platform::errors::InvalidArgument("The layout of input is NCHW."));
S
shippingwang 已提交
96

S
shippingwang 已提交
97 98
    ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
  }
S
shippingwang 已提交
99 100 101 102

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
103 104 105
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
S
shippingwang 已提交
106
  }
S
shippingwang 已提交
107 108
};

// Describes how to build the grad op for shuffle_channel: it takes the
// gradient of Out, produces the gradient of X, and reuses the forward op's
// attributes (notably `group`). Templated on OpDesc/OpBase so the same
// maker serves both static-graph and imperative modes.
template <typename T>
class ShuffleChannelGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("shuffle_channel_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward op together with its grad-op makers for both the
// static-graph (OpDesc) and imperative (OpBase) modes.
REGISTER_OPERATOR(shuffle_channel,
                  ops::ShuffleChannelOp,
                  ops::ShuffleChannelOpMaker,
                  ops::ShuffleChannelGradMaker<paddle::framework::OpDesc>,
                  ops::ShuffleChannelGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);

// CPU kernels for float and double; kernels are defined in
// shuffle_channel_op.h.
REGISTER_OP_CPU_KERNEL(shuffle_channel,
                       ops::ShuffleChannelOpKernel<phi::CPUContext, float>,
                       ops::ShuffleChannelOpKernel<phi::CPUContext, double>);

REGISTER_OP_CPU_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpKernel<phi::CPUContext, float>,
    ops::ShuffleChannelGradOpKernel<phi::CPUContext, double>);