/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"

#include <memory>
#include <string>

namespace paddle {
namespace operators {

class ShuffleChannelOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
25 26
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
S
shippingwang 已提交
27 28

    auto input_dims = ctx->GetInputDim("X");
29
    PADDLE_ENFORCE_EQ(
30 31
        input_dims.size(),
        4,
32
        platform::errors::InvalidArgument("The layout of input is NCHW."));
S
shippingwang 已提交
33 34 35

    ctx->SetOutputDim("Out", input_dims);
  }
S
shippingwang 已提交
36 37 38 39

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
40 41 42 43 44
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
45 46
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
47 48 49 50 51
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
S
shippingwang 已提交
52
  }
S
shippingwang 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66
};

// Proto maker for shuffle_channel: declares inputs, outputs, attributes
// and the operator documentation shown to Python users.
class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), "
             "the input feature data of ShuffleChannelOp, the layout is NCHW.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the output of "
              "ShuffleChannelOp. The layout is NCHW.");
    AddAttr<int>("group", "the number of groups.")
        .SetDefault(1)
        .AddCustomChecker([](const int& group) {
          // A non-positive group count cannot partition the channels.
          PADDLE_ENFORCE_GE(group,
                            1,
                            platform::errors::InvalidArgument(
                                "group should be larger than 0."));
        });
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();

    // NOTE: fixed typos/grammar in the user-facing doc string
    // ("opearator" -> "operator", "It  divide" -> "It divides", ...).
    AddComment(R"DOC(
		Shuffle Channel operator
		This operator shuffles the channels of input x.
		It divides the input channels in each group into several subgroups,
		and obtains a new order by selecting elements from every subgroup one by one.

		Shuffle channel operation makes it possible to build more powerful structures
		with multiple group convolutional layers.
		Please get more information from the following paper:
		https://arxiv.org/pdf/1707.01083.pdf
        )DOC");
  }
};

class ShuffleChannelGradOp : public framework::OperatorWithKernel {
S
shippingwang 已提交
92 93 94 95
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
96
    auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
97
    PADDLE_ENFORCE_EQ(
98 99
        input_dims.size(),
        4,
100
        platform::errors::InvalidArgument("The layout of input is NCHW."));
S
shippingwang 已提交
101

S
shippingwang 已提交
102 103
    ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
  }
S
shippingwang 已提交
104 105 106 107

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
108 109 110
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
S
shippingwang 已提交
111
  }
S
shippingwang 已提交
112 113
};

// Builds the shuffle_channel_grad op description; templated on T so the
// same maker serves both static graph (OpDesc) and imperative (OpBase).
template <typename T>
class ShuffleChannelGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("shuffle_channel_grad");
    // Backward only needs dOut; forward attrs (e.g. "group") are reused.
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward op together with its grad-op makers so autograd can
// synthesize shuffle_channel_grad in both static and imperative modes.
REGISTER_OPERATOR(shuffle_channel,
                  ops::ShuffleChannelOp,
                  ops::ShuffleChannelOpMaker,
                  ops::ShuffleChannelGradMaker<paddle::framework::OpDesc>,
                  ops::ShuffleChannelGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);

// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
                                    double>);