grad_op_builder_test.cc 5.2 KB
Newer Older
F
fengjiayi 已提交
1
#include "paddle/framework/grad_op_builder.h"
2 3 4 5 6 7 8 9 10
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

USE_OP(add_two);

namespace paddle {
namespace framework {

F
fengjiayi 已提交
11 12 13 14 15
// Proto maker for the "mult_io" test op: three inputs, where In2_mult is
// duplicable (it can bind several variables), and two outputs, where
// Out2_mult is duplicable. Nothing is marked AsNoGradient().
// NOTE: "Muti" is a misspelling of "Multi"; the name is kept because the
// REGISTER_OP call for mult_io refers to it.
class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
 public:
  MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable();
    AddInput("In3", "another single input");
    AddOutput("Out1", "a single output");
    AddOutput("Out2_mult", "a multiple output").AsDuplicable();
    AddComment("test op with multiple inputs and outputs");
  }
};

// Proto maker for the "io_ignored" test op. The input In2_mult and the output
// Out2 are marked AsNoGradient(); the IOIgnoredInGradient test checks how the
// generated gradient op treats those slots. In2_mult, In3_mult and Out1_mult
// are duplicable (each can bind several variables).
class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
 public:
  IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
    AddInput("In3_mult", "another multiple input").AsDuplicable();
    AddOutput("Out1_mult", "a multiple output").AsDuplicable();
    AddOutput("Out2", "a single output").AsNoGradient();
    AddComment("op with inputs and outputs ignored in gradient calculating");
  }
};

}  // namespace framework
}  // namespace paddle

namespace f = paddle::framework;

F
fengjiayi 已提交
42
// Builds the gradient op of add_two and checks its slots: the forward inputs
// (X, Y), the forward output (Out) and the gradient of Out are inputs; the
// gradients of X and Y are outputs.
TEST(GradOpBuilder, AddTwo) {
  std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
      "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_add_op =
      f::OpRegistry::CreateGradOp(*add_op);

  // ASSERT (not EXPECT) on the slot counts, as the other tests in this file
  // do: the per-slot lookups below are only meaningful when the slot set
  // has the expected size.
  ASSERT_EQ(grad_add_op->inputs_.size(), 4UL);
  ASSERT_EQ(grad_add_op->outputs_.size(), 2UL);

  // Forward inputs and output are forwarded unchanged.
  EXPECT_EQ(grad_add_op->Input("X"), "x");
  EXPECT_EQ(grad_add_op->Input("Y"), "y");
  EXPECT_EQ(grad_add_op->Input("Out"), "out");

  // Gradient slots map to the gradient names of the bound variables.
  EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
  EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x"));
  EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
}

57 58
// Register the two test ops so the tests can create them by type name
// ("mult_io", "io_ignored") and derive their gradient ops ("mult_io_grad",
// "io_ignored_grad"). f::NOP is passed as both the forward and gradient op
// class.
// NOTE(review): argument order assumed to be (type, op class, maker,
// grad type, grad op class) — confirm against op_registry.h.
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
F
fengjiayi 已提交
59 60 61

// Gradient op of an op with single and duplicable (multi-variable) slots,
// none of them marked AsNoGradient(): every forward input and output, and the
// gradient of every forward output, must reach the gradient op; it must
// produce the gradient of every forward input.
TEST(GradOpBuilder, MutiInOut) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "mult_io", {{"In1", {"in1"}},
                  {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
                  {"In3", {"in3"}}},
      {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // Input slots: 3 forward inputs + 2 forward outputs + 2 output gradients.
  ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_test_op->Input("In3"), "in3");
  EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
  EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
            std::vector<std::string>({"out2_1", "out2_2"}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
            f::GradVarName("out1"));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
            std::vector<std::string>(
                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  // Output slots: one gradient per forward input slot; duplicable slots keep
  // one gradient variable per bound variable, in the same order.
  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>({f::GradVarName("in2_1"),
                                      f::GradVarName("in2_2"),
                                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3"));
}

// Gradient op of an op whose In2_mult input and Out2 output are marked
// AsNoGradient(): their forward variables must not be passed to the gradient
// op. Their gradients are still wired (Out2's gradient is an input and
// In2_mult's gradient is an output).
TEST(GradOpBuilder, IOIgnoredInGradient) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "io_ignored", {{"In1", {"in1"}},
                     {"In2_mult", {"in2_1", "in2_2"}},
                     {"In3_mult", {"in3_1", "in3_2"}}},
      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // 'In2_mult' and 'Out2' are ignored as forward inputs of the gradient op.
  // Input slots: 2 forward inputs (In1, In3_mult) + 1 forward output
  // (Out1_mult) + 2 output gradients (of Out1_mult and Out2).
  ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL);
  // Pin the "ignored" slots explicitly rather than only through the count.
  EXPECT_EQ(grad_test_op->inputs_.count("In2_mult"), 0UL);
  EXPECT_EQ(grad_test_op->inputs_.count("Out2"), 0UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
            f::GradVarName("out2"));

  // Output slots: a gradient for every forward input slot, including the
  // AsNoGradient() input In2_mult.
  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}