grad_op_builder_test.cc 5.5 KB
Newer Older
F
fengjiayi 已提交
1
#include "paddle/framework/grad_op_builder.h"
2 3 4 5 6 7 8 9 10
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

// Make the "add_two" operator available to this test binary; the AddTwo test
// instantiates it by name through OpRegistry::CreateOp.
// NOTE(review): presumably USE_OP exists to keep add_two's static registrar
// from being dropped at link time — confirm against op_registry.h.
USE_OP(add_two);

namespace paddle {
namespace framework {

Y
Yi Wang 已提交
11
class NOP : public OperatorBase {
F
fengjiayi 已提交
12
 public:
Y
Yu Yang 已提交
13
  using OperatorBase::OperatorBase;
F
fengjiayi 已提交
14 15 16 17 18 19 20 21 22 23
  void InferShape(const Scope &scope) const override {}
  void Run(const Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {}
};

// Proto maker for a test op that mixes single and duplicable slots.
//   inputs:  In1 (single), In2_mult (duplicable), In3 (single)
//   outputs: Out1 (single), Out2_mult (duplicable)
// No slot is marked AsNoGradient().
// NOTE(review): "Muti" is a typo for "Multi"; the name is kept because the
// REGISTER_OP(mult_io, ...) line refers to it.
class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
 public:
  MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs, in declaration order.
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input")
        .AsDuplicable();
    AddInput("In3", "another single input");

    // Outputs, in declaration order.
    AddOutput("Out1", "a single output");
    AddOutput("Out2_mult", "a multiple output")
        .AsDuplicable();

    AddComment("test op with multiple inputs and outputs");
  }
};

// Proto maker for a test op in which some slots are excluded from gradient
// computation via AsNoGradient().
//   inputs:  In1 (single), In2_mult (duplicable, no-gradient),
//            In3_mult (duplicable)
//   outputs: Out1_mult (duplicable), Out2 (single, no-gradient)
// Duplicable slots can bind several variable names.
class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
 public:
  IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs, in declaration order.
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input")
        .AsDuplicable()
        .AsNoGradient();
    AddInput("In3_mult", "another multiple input")
        .AsDuplicable();

    // Outputs, in declaration order.
    AddOutput("Out1_mult", "a multiple output")
        .AsDuplicable();
    AddOutput("Out2", "a single output")
        .AsNoGradient();

    AddComment("op with inputs and outputs ignored in gradient calculating");
  }
};

}  // namespace framework
}  // namespace paddle

namespace f = paddle::framework;

F
fengjiayi 已提交
50
// Builds the gradient op of the registered "add_two" op (inputs X, Y; output
// Out) and checks how CreateGradOp wires it. The grad op reads the forward
// inputs, the forward output and the gradient of Out, and it writes the
// gradients of X and Y.
TEST(GradOpBuilder, AddTwo) {
  std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
      "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_add_op =
      f::OpRegistry::CreateGradOp(*add_op);

  // ASSERT rather than EXPECT, matching the other GradOpBuilder tests. If the
  // grad op has the wrong number of slots, the test stops here instead of
  // running the name lookups below against it.
  // Inputs: X, Y, Out, GradVarName(Out).
  ASSERT_EQ(grad_add_op->inputs_.size(), 4UL);
  // Outputs: GradVarName(X), GradVarName(Y).
  ASSERT_EQ(grad_add_op->outputs_.size(), 2UL);

  // Forward inputs and output are passed through under their own slot names.
  EXPECT_EQ(grad_add_op->Input("X"), "x");
  EXPECT_EQ(grad_add_op->Input("Y"), "y");
  EXPECT_EQ(grad_add_op->Input("Out"), "out");

  // The gradient of the forward output comes in; the gradients of the
  // forward inputs go out.
  EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
  EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x"));
  EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
}

Y
Yi Wang 已提交
65 66 67 68
// "mult_io": forward op described by MutiInOutOpMaker, plus "mult_io_grad"
// registered as its gradient op. Both use the do-nothing NOP kernel.
// Exercised by the MutiInOut test below.
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker);
REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP);

// "io_ignored": forward op described by IOIgnoredOpMaker (some slots marked
// AsNoGradient), plus "io_ignored_grad" as its gradient op. Both use NOP.
// Exercised by the IOIgnoredInGradient test below.
// NOTE(review): each forward op is registered before its gradient op; keep
// that order unless op_registry.h documents that it does not matter.
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
F
fengjiayi 已提交
69 70 71

// Gradient op of "mult_io" (MutiInOutOpMaker, NOP kernels), with no
// AsNoGradient slots:
// - every forward input and output is passed through by name;
// - each forward output contributes its gradient as an extra input;
// - each forward input gets a gradient output.
TEST(GradOpBuilder, MutiInOut) {
  using NameList = std::vector<std::string>;

  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "mult_io",
      {{"In1", {"in1"}},
       {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
       {"In3", {"in3"}}},
      {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}},
      {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // Grad-op inputs: 3 forward inputs + 2 forward outputs + 2 output grads.
  ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL);

  // Forward inputs, passed through.
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
            NameList({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_test_op->Input("In3"), "in3");

  // Forward outputs, passed through.
  EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
  EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
            NameList({"out2_1", "out2_2"}));

  // Gradients of the forward outputs.
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
            f::GradVarName("out1"));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
            NameList({f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  // Grad-op outputs: one gradient slot per forward input.
  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            NameList({f::GradVarName("in2_1"), f::GradVarName("in2_2"),
                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3"));
}

// Gradient op of "io_ignored" (IOIgnoredOpMaker, NOP kernels). In2_mult and
// Out2 are declared AsNoGradient(). The expectations below pin down what that
// means for CreateGradOp:
// - both slots are dropped from the grad op's pass-through inputs;
// - the gradient of Out2 is still an input;
// - the gradient of In2_mult is still an output.
TEST(GradOpBuilder, IOIgnoredInGradient) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "io_ignored",
      {{"In1", {"in1"}},
       {"In2_mult", {"in2_1", "in2_2"}},
       {"In3_mult", {"in3_1", "in3_2"}}},
      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}},
      {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // 'In2_mult' and 'Out2' are not passed through to the gradient op. Its
  // inputs are therefore:
  //   2 forward inputs (In1, In3_mult)
  //   + 1 forward output (Out1_mult)
  //   + 2 output gradients (of Out1_mult and Out2).
  ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL);
  // Check the exclusion by slot name as well, not only through the count.
  EXPECT_EQ(grad_test_op->inputs_.count("In2_mult"), 0UL);
  EXPECT_EQ(grad_test_op->inputs_.count("Out2"), 0UL);

  // Pass-through forward inputs and output.
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));

  // Gradients of all forward outputs, including the AsNoGradient Out2.
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
            f::GradVarName("out2"));

  // One gradient output per forward input, including the AsNoGradient
  // In2_mult.
  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}