add_op.cc 1.5 KB
Newer Older
Q
qijun 已提交
1 2 3
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
Y
Yu Yang 已提交
4 5

namespace paddle {
Y
Yu Yang 已提交
6
namespace operators {
Y
Yu Yang 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38

// Operator class for the two-input add (registered below as `add_two`).
// Only shape/arity validation lives here; the computation itself is supplied
// by the kernel registered for this operator.
class AddOp : public framework::OperatorWithKernel {
protected:
  // Validates the operands before the kernel runs.
  //
  // Enforces, in order:
  //   1. exactly two input tensors,
  //   2. exactly one output tensor,
  //   3. none of the three tensor pointers is null,
  //   4. both inputs report identical dims().
  //
  // The arity checks must precede the element accesses so that indexing
  // into `inputs` / `outputs` is always in range.
  void InferShape(
      const std::vector<const framework::Tensor *> &inputs,
      const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
    PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");

    const framework::Tensor *lhs = inputs[0];
    const framework::Tensor *rhs = inputs[1];
    framework::Tensor *out = outputs[0];

    const bool all_set = lhs != nullptr && rhs != nullptr && out != nullptr;
    PADDLE_ENFORCE(all_set, "Inputs/Outputs of AddOp must all be set");

    PADDLE_ENFORCE(lhs->dims() == rhs->dims(),
                   "Two input of Add Op's dimension must be same.");

    // The output's dims are not propagated yet; this still needs a way to
    // set dims on Tensor, e.g.:
    // out->set_dims(lhs->dims())
  }
};

// Describes the add operator's interface: two inputs ("X" and "Y"), one
// output ("Out"), and the operator's documentation comment.
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
  // Forwards `proto` and `op_checker` to the base maker, then declares the
  // operator's inputs, output and doc string on it.
  AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs are declared X first, then Y.
    AddInput("X", "The first input of add op");
    AddInput("Y", "The second input of add op");
    AddOutput("Out", "The output of add op");
    // Operator-level documentation passed to AddComment as a raw string.
    AddComment(R"DOC(
Two Element Add Operator.

The equation is: Out = X + Y
)DOC");
  }
};
Q
qijun 已提交
39
}  // namespace operators
Y
Yu Yang 已提交
40 41
}  // namespace paddle

Y
Yu Yang 已提交
42 43
// Register the operator under the name `add_two`, pairing the operator class
// with the maker that declares its inputs, output and documentation.
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);

// Register the CPU kernel for `add_two`: AddKernel instantiated for CPUPlace
// with float elements. The names are fully qualified because this statement
// sits outside the paddle namespaces.
REGISTER_OP_CPU_KERNEL(
    add_two,
    ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>);