multiplex_op.cc 4.2 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/multiplex_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class MultiplexOp : public framework::OperatorWithKernel {
 public:
24
  using framework::OperatorWithKernel::OperatorWithKernel;
Y
Yibing Liu 已提交
25 26 27

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
28 29
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
                            "Input(Ids) shouldn't be null.");
30
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
31
                   "MultiInput(X) shouldn't be empty.");
32 33
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) shouldn't be null.");
34 35 36 37 38
    auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
    PADDLE_ENFORCE(
        ids_dim.size() == 2 && ids_dim[1] == 1,
        "The index tensor must be a vector with size batchSize x 1.");

Y
Yibing Liu 已提交
39
    auto ins = ctx.MultiInput<Tensor>("X");
Y
Yibing Liu 已提交
40
    auto *out = ctx.Output<Tensor>("Out");
Y
Yibing Liu 已提交
41
    auto num_ins = ins.size();
42 43 44
    PADDLE_ENFORCE(num_ins > 1,
                   "multiplex operator should have more than "
                   "one candidate input tensors.");
Y
Yibing Liu 已提交
45

46
    auto in_dim = ins[0]->dims();
47 48
    PADDLE_ENFORCE(in_dim.size() >= 2,
                   "The rank of candidate tensors must be not less than 2.");
49
    for (size_t i = 1; i < num_ins; i++) {
Y
Yibing Liu 已提交
50
      auto dim = ins[i]->dims();
Y
Yibing Liu 已提交
51
      PADDLE_ENFORCE(in_dim == dim,
52
                     "All the candidate tensors must have the same size.");
Y
Yibing Liu 已提交
53 54 55 56 57 58 59 60 61 62
    }
    out->Resize(in_dim);
  }
};

// Declares the inputs, output and user-facing documentation of the
// `multiplex` operator.
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Forwards `proto` and `op_checker` to the base maker, then registers the
  // operator's variables and its documentation string.
  MultiplexOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Ids", "The index tensor of multiplex operator.");
    // X is marked duplicable: it carries the list of candidate tensors.
    AddInput("X", "The candidate tensors of multiplex operator.")
        .AsDuplicable();
    AddOutput("Out", "The output tensor of multiplex operator.");
    AddComment(R"DOC(Multiplex operator

Multiplex multiple tensors according to the index provided by the index tensor.

Ids: the index tensor.
X[0 : N - 1]: the candidate tensors for output (N >= 2).
For each index i from 0 to batchSize - 1, the output is the i-th row of the
(Ids[i])-th tensor.

For i-th row of the output tensor:

y[i] = x_{k}[i]

where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = Ids[i]`.
)DOC");
  }
};

class MultiplexGradOp : public framework::OperatorWithKernel {
 public:
88
  using framework::OperatorWithKernel::OperatorWithKernel;
Y
Yibing Liu 已提交
89 90 91

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
92
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
Y
Yibing Liu 已提交
93
                   "Input(X) should not be null.");
94
    PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
Y
Yibing Liu 已提交
95
                   "Output(X@Grad) should not be null.");
Y
Yibing Liu 已提交
96 97
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) shouldn't be null.");
Y
Yibing Liu 已提交
98
    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
Y
Yibing Liu 已提交
99
    auto ins = ctx.MultiInput<Tensor>("X");
100 101
    // No need to compute gradient for Input(Ids)
    for (size_t i = 0; i < ins.size(); i++) {
102 103 104
      if (d_ins[i]) {
        d_ins[i]->Resize(ins[i]->dims());
      }
Y
Yibing Liu 已提交
105 106 107 108 109 110 111 112 113 114
    }
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

// Registers `multiplex` with its operator class and proto maker, and
// `multiplex_grad` with its gradient operator class.
REGISTER_OP(multiplex,
            ops::MultiplexOp,
            ops::MultiplexOpMaker,
            multiplex_grad,
            ops::MultiplexGradOp);

// CPU kernels for the forward and gradient operators, instantiated for float.
REGISTER_OP_CPU_KERNEL(
    multiplex,
    ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    multiplex_grad,
    ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);