multiplex_op.cc 3.7 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/multiplex_op.h"

namespace paddle {
namespace operators {

// Short alias for framework::Tensor, used by the operator classes below.
using Tensor = framework::Tensor;

/**
 * Forward operator for multiplex.
 *
 * Input "X" is a duplicable list: X[0] is a 1-D index vector and X[1..N-1]
 * are the candidate tensors. Output "Out" takes the shape shared by all the
 * candidate tensors.
 */
class MultiplexOp : public framework::OperatorWithKernel {
 public:
  MultiplexOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

 protected:
  /**
   * Validates the inputs and resizes "Out".
   *
   * Enforces that:
   *  - there are more than 2 inputs (the index vector plus at least two
   *    candidates);
   *  - X[0] is a 1-D vector;
   *  - all candidate tensors X[1..N-1] have the same dims;
   *  - the candidates have at least one dimension, and X[0] has exactly one
   *    entry per candidate row.
   * On success "Out" is resized to the candidates' dims.
   */
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto ins = ctx.MultiInput<Tensor>("X");
    auto *out = ctx.Output<Tensor>("Out");
    auto num_ins = ins.size();
    PADDLE_ENFORCE(num_ins > 2,
                   "multiplex operator should have more than 2 inputs.");
    PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
                      "The first input must be a index vector.");
    auto in_dim = ins[1]->dims();

    // Every candidate must match the first candidate's shape, since Out
    // takes that single shape.
    for (size_t i = 2; i < num_ins; i++) {
      auto dim = ins[i]->dims();
      PADDLE_ENFORCE(
          in_dim == dim,
          "All the input tensors except the first one must have the same size");
    }

    // Per the operator doc, row i of Out is copied from the candidate chosen
    // by index[i]. The candidates therefore need a row dimension, and the
    // index vector needs exactly one entry per row. Otherwise some output
    // rows would have no selector, or selectors would refer to rows that do
    // not exist.
    PADDLE_ENFORCE(in_dim.size() >= 1,
                   "The candidate tensors must have at least one dimension.");
    PADDLE_ENFORCE_EQ(ins[0]->dims()[0], in_dim[0],
                      "The size of the index vector must equal the number of "
                      "rows of the candidate tensors.");
    out->Resize(in_dim);
  }
};

/**
 * Declares the interface of the multiplex operator: a duplicable input "X"
 * (index vector followed by candidate tensors), a single output "Out", and
 * the operator's user-facing documentation.
 */
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MultiplexOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor of multiplex operator.").AsDuplicable();
    AddOutput("Out", "The output tensor of multiplex operator.");
    // The text below is the operator's documentation as exposed to users.
    AddComment(R"DOC(Multiplex operator

Multiplex multiple tensors according to the index provided by the first
input tensor.

ins[0]: the index vector, of size batchSize, selecting which candidate tensor
supplies each output row.
ins[1:N]: the candidate output tensors, all of the same shape.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
(index[i] + 1)-th tensor.

For each i-th row of output:

y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)

where y is the output tensor, `x_{k}` is the k-th input tensor
and `k = x_{0}[i] + 1`.

)DOC");
  }
};

/**
 * Backward operator for multiplex.
 *
 * Resizes each present gradient output X@GRAD[i] to the dims of the
 * corresponding forward input X[i].
 */
class MultiplexGradOp : public framework::OperatorWithKernel {
 public:
  MultiplexGradOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

 protected:
  /**
   * Requires Out@GRAD to exist and X@GRAD to have one entry per X. Each
   * non-null X@GRAD[i] is then resized to X[i]'s dims.
   */
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) shouldn't be null.");
    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    auto ins = ctx.MultiInput<Tensor>("X");
    // d_ins is indexed in lockstep with ins below; a shorter d_ins would
    // otherwise be read out of bounds.
    PADDLE_ENFORCE_EQ(d_ins.size(), ins.size(),
                      "Output(X@GRAD) must have one entry per Input(X).");
    for (size_t i = 0; i < ins.size(); i++) {
      // NOTE(review): an entry is presumably null when that gradient was not
      // requested (e.g. for the index vector X[0]); skip it rather than
      // dereferencing a null pointer.
      if (d_ins[i] != nullptr) {
        d_ins[i]->Resize(ins[i]->dims());
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;

// Register the forward op "multiplex" (shape inference + proto maker) paired
// with its gradient op "multiplex_grad".
REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
            ops::MultiplexGradOp);
// Bind float CPU kernels to both ops. The kernel templates are not defined in
// this file (presumably they come from multiplex_op.h); only float is
// registered here.
REGISTER_OP_CPU_KERNEL(multiplex, ops::MultiplexCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(multiplex_grad, ops::MultiplexGradCPUKernel<float>);