backward_test.cc 34.3 KB
Newer Older
D
dzhwinter 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Y
Yu Yang 已提交
14

Y
Yu Yang 已提交
15
#include "paddle/framework/backward.h"
D
dongzhihong 已提交
16

Y
Yu Yang 已提交
17
#include <gtest/gtest.h>
18 19
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
Y
Yu Yang 已提交
20
#include "paddle/framework/op_registry.h"
21
#include "paddle/framework/var_desc.h"
Y
Yan Chunwei 已提交
22
#include "paddle/operators/net_op.h"
Y
Yu Yang 已提交
23

Y
Yang Yu 已提交
24
USE_NO_KERNEL_OP(fill_constant);
Q
QI JUN 已提交
25

Y
Yu Yang 已提交
26 27 28
namespace paddle {
namespace framework {

D
dongzhihong 已提交
29 30
using DeviceContext = platform::DeviceContext;

Q
Qiao Longfei 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44
// Stub operator used by the backward tests. It reuses the
// OperatorWithKernel constructors and does nothing during shape inference.
class NoneOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Intentionally empty: this stub infers no shapes.
  void InferShape(framework::InferShapeContext *ctx) const override {}
};

// Stub kernel paired with the test operators; Compute() is a no-op.
template <typename Place, typename T>
class NoneKernel : public framework::OpKernel<T> {
 public:
  // Intentionally empty: the kernel performs no computation.
  void Compute(const framework::ExecutionContext &context) const override {}
};

Y
Yu Yang 已提交
45
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
Y
Yu Yang 已提交
46
 public:
Y
Yu Yang 已提交
47
  RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
Y
Yu Yang 已提交
48
      : OpProtoAndCheckerMaker(proto, op_checker) {
49 50 51
    AddInput("X", "Input X of Add");
    AddInput("b", "Bias of Add");
    AddOutput("Out", "Out of Add");
Y
Yu Yang 已提交
52 53 54 55
    AddComment("Add Op");
  }
};

56 57 58 59 60
class RowWiseAddGradMaker : public SingleGradOpDescMaker {
 public:
  using SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
Y
Yu Yang 已提交
61 62
  std::unique_ptr<OpDesc> Apply() const override {
    auto grad_op = new OpDesc();
Y
Yu Yang 已提交
63 64 65 66
    grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
    grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
    grad_op->SetType("rowwise_add_grad");
Y
Yu Yang 已提交
67
    return std::unique_ptr<OpDesc>(grad_op);
68 69 70
  }
};

Y
Yu Yang 已提交
71 72 73 74
// Proto maker for the test-only "mul" operator.
// Declares inputs "X" and "Y", output "Out", and two int attributes that
// default to 1 and are checked with EqualGreaterThan(1).
class MulOpMaker : public OpProtoAndCheckerMaker {
 public:
  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs and output.
    AddInput("X", "A");
    AddInput("Y", "B");
    AddOutput("Out", "Out");

    // Attributes.
    AddAttr<int>("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
    AddAttr<int>("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);

    AddComment("Mul");
  }
};

// Proto maker for the test-only "sigmoid" operator: one input "X", one
// output "Out".
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
 public:
  SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "X");
    AddOutput("Out", "Y");

    AddComment("Sigmoid");
  }
};

D
dongzhihong 已提交
94 95 96 97 98
// Proto maker for the "nograd" operator, which this file registers without a
// gradient. Declares one input "X" and one output "Out".
class NoGradOpMaker : public OpProtoAndCheckerMaker {
 public:
  NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "X input");
    AddOutput("Out", "Y output");

    AddComment("NoGradOp, same input output. no Grad");
  }
};

D
dongzhihong 已提交
104
class FcOp : public operators::NetOp {
Y
Yu Yang 已提交
105
 public:
Y
Yu Yang 已提交
106 107
  FcOp(const std::string &type, const VariableNameMap &inputs,
       const VariableNameMap &outputs, const AttributeMap &attrs)
Y
Yu Yang 已提交
108
      : NetOp(type, inputs, outputs, attrs) {
Y
Yiqun Liu 已提交
109 110 111
    AppendOp(OpRegistry::CreateOp(
        "mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}},
        {{"Out", {Output("mul_result")}}}, AttributeMap{}));
112
    auto input_b = Inputs("b");
Y
Yu Yang 已提交
113
    std::string before_act = "mul_result";
114
    if (input_b.size() != 0) {
Y
Yu Yang 已提交
115
      AppendOp(OpRegistry::CreateOp(
116
          "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
Y
Yiqun Liu 已提交
117
          {{"Out", {Output("add_result")}}}, AttributeMap{}));
Y
Yu Yang 已提交
118 119 120
      before_act = "add_result";
    } else {
      auto out_varname = Output("add_result");
121 122
      if (out_varname != kEmptyVarName) {
        this->Rename(out_varname, kEmptyVarName);
Y
Yu Yang 已提交
123
      }
Y
Yu Yang 已提交
124
    }
Y
Yu Yang 已提交
125

Y
Yu Yang 已提交
126
    AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
Y
Yiqun Liu 已提交
127
                                  {{"Out", {Output("Out")}}}, AttributeMap{}));
Y
Yu Yang 已提交
128 129 130 131 132 133 134 135 136 137 138
    CompleteAddOp(false);
  }
};

// Proto maker for the test-only "fc" operator.
// "mul_result" and "add_result" are flagged as intermediate outputs; "Out"
// is the regular output.
class FcOpMaker : public OpProtoAndCheckerMaker {
 public:
  FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Inputs.
    AddInput("X", "x");
    AddInput("W", "w");
    AddInput("b", "b");

    // Outputs.
    AddOutput("mul_result", "").AsIntermediate();
    AddOutput("add_result", "").AsIntermediate();
    AddOutput("Out", "");

    AddComment("");
  }
};

// Proto maker for the test-only "many_output_op": one input "x" and two
// outputs "y" and "z".
class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
 public:
  ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("x", "x");
    AddOutput("y", "y");
    AddOutput("z", "z");
    AddComment("");
  }
};

// Proto maker used for the test registration of "fill_zeros_like": one input
// "X" and one output "Out".
class FillZeroOpMaker : public OpProtoAndCheckerMaker {
 public:
  FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "x");
    AddOutput("Out", "out");

    AddComment("");
  }
};
Y
Yu Yang 已提交
166

D
dongzhihong 已提交
167
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yu Yang 已提交
168
 public:
169
  SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
Y
Yu Yang 已提交
170
      : OpProtoAndCheckerMaker(proto, op_checker) {
Y
Yu Yang 已提交
171 172
    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
    AddOutput("Out", "the output tensor of sum operator.");
Y
Yu Yang 已提交
173 174 175
    AddComment("");
  }
};
D
dongzhihong 已提交
176

F
fengjiayi 已提交
177 178 179 180 181 182 183 184 185 186 187 188
// Proto maker for the test-only "mult_in_out" operator: two inputs ("X",
// "H") and two outputs ("Y", "Z").
class MultInOutOpMaker : public OpProtoAndCheckerMaker {
 public:
  MultInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "x");
    AddInput("H", "h");
    AddOutput("Y", "y");
    AddOutput("Z", "z");
    AddComment("");
  }
};

189 190 191 192
class MinusGradOpDescMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
193 194
  std::vector<std::unique_ptr<OpDesc>> operator()() const override {
    std::vector<std::unique_ptr<OpDesc>> retv;
195 196
    auto x_g = InputGrad("X");
    if (!x_g.empty()) {
Y
Yu Yang 已提交
197
      auto *op_desc = new OpDesc();
198 199 200 201 202 203 204 205 206
      op_desc->SetType("scale");
      op_desc->SetInput("X", OutputGrad("Out"));
      op_desc->SetOutput("Out", x_g);
      op_desc->SetAttr("scale", 1.0f);
      retv.emplace_back(op_desc);
    }

    auto y_g = InputGrad("Y");
    if (!y_g.empty()) {
Y
Yu Yang 已提交
207
      auto *op_desc = new OpDesc();
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
      op_desc->SetType("scale");
      op_desc->SetInput("X", OutputGrad("Out"));
      op_desc->SetOutput("Out", y_g);
      op_desc->SetAttr("scale", -1.0f);
      retv.emplace_back(op_desc);
    }
    return retv;
  }
};

// Proto maker for the test-only "minus" operator: inputs "X" and "Y",
// output "Out".
class MinusOpMaker : public OpProtoAndCheckerMaker {
 public:
  MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "");
    AddInput("Y", "");
    AddOutput("Out", "");
    AddComment("minus for unittest");
  }
};
Y
Yu Yang 已提交
228 229 230 231
}  // namespace framework
}  // namespace paddle

namespace f = paddle::framework;
D
dongzhihong 已提交
232
namespace ops = paddle::operators;
Y
Yu Yang 已提交
233
using EnforceNotMet = paddle::platform::EnforceNotMet;
Q
Qiao Longfei 已提交
234 235
// rowwise_add
REGISTER_OPERATOR(rowwise_add, f::NoneOp, f::RowWiseAddOpMaker,
236
                  f::RowWiseAddGradMaker);
Q
Qiao Longfei 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
REGISTER_OP_CPU_KERNEL(rowwise_add,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
REGISTER_OPERATOR(rowwise_add_grad, f::NoneOp);
REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
// mul
REGISTER_OP(mul, f::NoneOp, f::MulOpMaker, mul_grad, f::NoneOp);
REGISTER_OP_CPU_KERNEL(mul, f::NoneKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
// sigmoid
REGISTER_OP(sigmoid, f::NoneOp, f::SigmoidOpMaker, sigmoid_grad, f::NoneOp);
REGISTER_OP_CPU_KERNEL(sigmoid,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NoneOp, f::NoGradOpMaker);
// fill_zeros_like
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NoneOp, f::FillZeroOpMaker);
REGISTER_OP_CPU_KERNEL(fill_zeros_like,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
// sum
REGISTER_OP(sum, f::NoneOp, f::SumOpMaker, sum_grad, f::NoneOp);
REGISTER_OP_CPU_KERNEL(sum, f::NoneKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sum_grad,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
// fc
F
fengjiayi 已提交
262
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
Q
Qiao Longfei 已提交
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
// many_output_op
REGISTER_OP(many_output_op, f::NoneOp, f::ManyOutputOpMaker,
            many_output_op_grad, f::NoneOp);
// mult_in_out
REGISTER_OP(mult_in_out, f::NoneOp, f::MultInOutOpMaker, mult_in_out_grad,
            f::NoneOp);
REGISTER_OP_CPU_KERNEL(mult_in_out,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mult_in_out_grad,
                       f::NoneKernel<paddle::platform::CPUPlace, float>);
// minus
REGISTER_OPERATOR(minus, f::NoneOp, f::MinusOpMaker, f::MinusGradOpDescMaker);
REGISTER_OP_CPU_KERNEL(minus, f::NoneKernel<paddle::platform::CPUPlace, float>);
// scale
REGISTER_OPERATOR(scale, f::NoneOp);
REGISTER_OP_CPU_KERNEL(scale, f::NoneKernel<paddle::platform::CPUPlace, float>);
Y
Yu Yang 已提交
279

280
// Single rowwise_add op. Expects that excluding "x" leaves the X-gradient
// output slot empty, and that excluding both inputs yields a NetOp with no
// ops in it.
TEST(Backward, simple_op_not_need_grad) {
  auto forward_op =
      f::OpRegistry::CreateOp("rowwise_add", {{"X", {"x"}}, {"b", {"b"}}},
                              {{"Out", {"out"}}}, f::AttributeMap{});
  ASSERT_NE(forward_op, nullptr);

  // Case 1: only "x" is excluded from differentiation.
  auto grad_without_x = f::Backward(*forward_op, {"x"});
  ASSERT_EQ(grad_without_x->Output(f::GradVarName("X")), f::kEmptyVarName);

  // Case 2: every input is excluded.
  auto grad_without_inputs = f::Backward(*forward_op, {"x", "b"});
  ASSERT_NE(grad_without_inputs, nullptr);
  ASSERT_TRUE(grad_without_inputs->IsNetOp());
  auto *empty_net = static_cast<ops::NetOp *>(grad_without_inputs.get());
  ASSERT_EQ(0UL, empty_net->ops_.size());
}

// fc with a bias. Expects the backward net to hold three ops in reverse
// order of the forward stages: sigmoid_grad, rowwise_add_grad, mul_grad.
TEST(Backward, net_fc_backward_normal) {
  std::shared_ptr<f::OperatorBase> forward_fc = f::OpRegistry::CreateOp(
      "fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
      {{"mul_result", {"mul_res"}},
       {"add_result", {"add_re"}},
       {"Out", {"out"}}},
      f::AttributeMap{});
  ASSERT_NE(forward_fc, nullptr);

  std::shared_ptr<f::OperatorBase> backward_fc =
      f::Backward(*forward_fc, std::unordered_set<std::string>{});
  ASSERT_TRUE(backward_fc->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_fc.get());

  ASSERT_NO_THROW(grad_net->DebugString());

  ASSERT_EQ(3UL, grad_net->ops_.size());
  ASSERT_EQ("sigmoid_grad", grad_net->ops_[0]->Type());
  ASSERT_EQ("rowwise_add_grad", grad_net->ops_[1]->Type());
  ASSERT_EQ("mul_grad", grad_net->ops_[2]->Type());
}

// fc without a bias (empty "b" slot). Expects the backward net to hold only
// sigmoid_grad followed by mul_grad.
TEST(Backward, net_fc_backward_not_have_b) {
  std::shared_ptr<f::OperatorBase> forward_fc = f::OpRegistry::CreateOp(
      "fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {}}},
      {{"mul_result", {"mul_res"}},
       {"add_result", {"add_res"}},
       {"Out", {"tmp"}}},
      f::AttributeMap{});
  ASSERT_NE(forward_fc, nullptr);

  std::shared_ptr<f::OperatorBase> backward_fc =
      f::Backward(*forward_fc, std::unordered_set<std::string>{});
  ASSERT_TRUE(backward_fc->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_fc.get());

  ASSERT_NO_THROW(grad_net->DebugString());

  ASSERT_EQ(2UL, grad_net->ops_.size());
  ASSERT_EQ("sigmoid_grad", grad_net->ops_[0]->Type());
  ASSERT_EQ("mul_grad", grad_net->ops_[1]->Type());
}

// Two stacked fc ops where the network input "x" is excluded from
// differentiation. Expects gradients for every weight, bias and the hidden
// activation, no gradient variable for "x", and an empty X-gradient slot on
// the mul_grad of the first fc.
TEST(Backward, net_input_of_network_not_need_grad) {
  ops::NetOp net;
  net.AppendOp(f::OpRegistry::CreateOp(
      "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
      {{"mul_result", {"mul_tmp_0"}},
       {"add_result", {"add_tmp_0"}},
       {"Out", {"hidden0"}}},
      f::AttributeMap{}));
  net.AppendOp(f::OpRegistry::CreateOp(
      "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
      {{"mul_result", {"mul_tmp_1"}},
       {"add_result", {"add_tmp_1"}},
       {"Out", {"hidden1"}}},
      f::AttributeMap{}));
  net.CompleteAddOp();

  auto bwd = f::Backward(net, {"x"});  // x@GRAD is not needed.
  ASSERT_TRUE(bwd->IsNetOp());
  auto bwd_net = static_cast<ops::NetOp *>(bwd.get());

  // Collect every output variable of the backward net, ignoring the
  // placeholder name used for suppressed gradients.
  auto output_vars = bwd_net->OutputVars(true);
  std::unordered_set<std::string> all_outputs =
      std::unordered_set<std::string>(output_vars.begin(), output_vars.end());
  all_outputs.erase(f::kEmptyVarName);

  for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
    ASSERT_NE(all_outputs.find(f::GradVarName(out)), all_outputs.end());
  }

  // The gradient of the network input must not be generated. The variable is
  // named "x" (lowercase); looking up "X" here would pass vacuously because
  // no variable of that name exists in the network.
  ASSERT_EQ(all_outputs.find(f::GradVarName("x")), all_outputs.end());

  ASSERT_EQ(2UL, bwd_net->ops_.size());
  // Backward ops come in reverse order, so index 1 is the grad of the first
  // fc; its last op is the mul_grad whose "X" slot receives "x".
  ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
  auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
  ASSERT_EQ(f::kEmptyVarName,
            first_fc_grad->ops_[2]->Output(f::GradVarName("X")));
}

// Two mul ops sharing the weight "w". Expects the backward net to contain
// three ops, the last of which is a "sum".
TEST(Backward, net_shared_weight) {
  ops::NetOp forward_net;
  forward_net.AppendOp(
      f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
                              {{"Out", {"out"}}}, f::AttributeMap{}));
  forward_net.AppendOp(
      f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
                              {{"Out", {"FinalOut"}}}, f::AttributeMap{}));
  forward_net.CompleteAddOp();

  auto backward_op =
      f::Backward(forward_net, std::unordered_set<std::string>{});
  ASSERT_TRUE(backward_op->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_op.get());
  ASSERT_EQ(3UL, grad_net->ops_.size());
  ASSERT_EQ("sum", grad_net->ops_[2]->Type());
}

// Excluding every input of a rowwise_add op is expected to produce an empty
// backward net.
TEST(Backward, op_all_input_are_not_need) {
  auto forward_op =
      f::OpRegistry::CreateOp("rowwise_add", {{"X", {"x"}}, {"b", {"b"}}},
                              {{"Out", {"out"}}}, f::AttributeMap{});
  auto backward_op = f::Backward(*forward_op, {"x", "b"});
  ASSERT_TRUE(backward_op->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_op.get());
  ASSERT_TRUE(grad_net->ops_.empty());
}

// Excluding the only output of a rowwise_add op is expected to produce an
// empty backward net.
TEST(Backward, op_all_output_are_not_need) {
  auto forward_op =
      f::OpRegistry::CreateOp("rowwise_add", {{"X", {"x"}}, {"b", {"b"}}},
                              {{"Out", {"out"}}}, f::AttributeMap{});
  auto backward_op = f::Backward(*forward_op, {"out"});
  ASSERT_TRUE(backward_op->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_op.get());
  ASSERT_TRUE(grad_net->ops_.empty());
}

// many_output_op with output "Z" excluded. Expects a fill_zeros_like op that
// produces a zero stand-in for Z's gradient, followed by the grad op that
// consumes that stand-in.
TEST(Backward, op_part_of_output_are_not_need) {
  auto forward_op =
      f::OpRegistry::CreateOp("many_output_op", {{"x", {"X"}}},
                              {{"y", {"Y"}}, {"z", {"Z"}}}, f::AttributeMap{});
  auto backward_op = f::Backward(*forward_op, {"Z"});
  ASSERT_TRUE(backward_op->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_op.get());
  ASSERT_EQ(grad_net->ops_.size(), 2UL);

  // First op: zero filler for the excluded output.
  auto &zero_filler = *grad_net->ops_[0];
  ASSERT_EQ("fill_zeros_like", zero_filler.Type());
  ASSERT_EQ(1UL, zero_filler.Inputs("X").size());
  ASSERT_EQ("Z", zero_filler.Input("X"));
  ASSERT_EQ(1UL, zero_filler.Outputs("Out").size());
  ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, zero_filler.Output("Out"));

  // Second op: the actual gradient op.
  auto &many_out_grad = *grad_net->ops_[1];
  ASSERT_EQ("many_output_op_grad", many_out_grad.Type());
  ASSERT_EQ(1UL + 2UL + 2UL, many_out_grad.Inputs().size());  // I/O/OG
  ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
            many_out_grad.Input(f::GradVarName("z")));
  ASSERT_EQ(f::GradVarName("Y"), many_out_grad.Input(f::GradVarName("y")));
  ASSERT_EQ(f::GradVarName("X"), many_out_grad.Output(f::GradVarName("x")));
}

// mul with input "a" excluded. Expects a single mul_grad op whose X-gradient
// slot is empty while the Y-gradient and all forward inputs/outputs are
// wired through.
TEST(Backward, op_part_of_input_are_not_need) {
  auto forward_op =
      f::OpRegistry::CreateOp("mul", {{"X", {"a"}}, {"Y", {"b"}}},
                              {{"Out", {"out"}}}, f::AttributeMap{});
  auto backward_op = f::Backward(*forward_op, {"a"});
  auto &mul_grad = *backward_op;

  ASSERT_EQ(mul_grad.Type(), "mul_grad");
  ASSERT_EQ(mul_grad.Inputs().size(), 2UL + 1UL + 1UL);
  ASSERT_EQ(mul_grad.Outputs().size(), 2UL);

  // Gradient slots.
  ASSERT_EQ(mul_grad.Output(f::GradVarName("X")), f::kEmptyVarName);
  ASSERT_EQ(mul_grad.Output(f::GradVarName("Y")), f::GradVarName("b"));
  ASSERT_EQ(mul_grad.Input(f::GradVarName("Out")), f::GradVarName("out"));

  // Forward variables passed through to the grad op.
  ASSERT_EQ(mul_grad.Input("X"), "a");
  ASSERT_EQ(mul_grad.Input("Y"), "b");
  ASSERT_EQ(mul_grad.Input("Out"), "out");
}

// Three chained fc ops where every output of the second fc is excluded from
// differentiation. Expects three backward entries: a populated gradient for
// the third fc followed by two entries with no inputs and no outputs.
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
  ops::NetOp forward_net;
  forward_net.AppendOp(f::OpRegistry::CreateOp(
      "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
      {{"mul_result", {"mul_out1"}},
       {"add_result", {"add_out1"}},
       {"Out", {"out1"}}},
      f::AttributeMap{}));
  forward_net.AppendOp(f::OpRegistry::CreateOp(
      "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
      {{"mul_result", {"mul_out2"}},
       {"add_result", {"tmp_out2"}},
       {"Out", {"out2"}}},
      f::AttributeMap{}));
  forward_net.AppendOp(f::OpRegistry::CreateOp(
      "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
      {{"mul_result", {"mul_out3"}},
       {"add_result", {"tmp_out3"}},
       {"Out", {"out3"}}},
      f::AttributeMap{}));
  forward_net.CompleteAddOp();

  auto backward_op =
      f::Backward(forward_net, {"mul_out2", "tmp_out2", "out2"});
  ASSERT_TRUE(backward_op->IsNetOp());
  auto *grad_net = static_cast<ops::NetOp *>(backward_op.get());
  ASSERT_EQ(grad_net->ops_.size(), 3UL);

  const char *all_slots = ops::NetOp::kAll;

  // Gradient of the third fc (first entry in backward order).
  auto &third_fc_grad = *grad_net->ops_[0];
  EXPECT_EQ(third_fc_grad.Inputs(all_slots).size(),
            2UL       /* external input number */
                + 1UL /* external output number */
                + 1UL /* number of gradient of external output */
                + 2UL /* internal variable number */
            );
  EXPECT_EQ(third_fc_grad.Outputs(all_slots).size(),
            2UL       /* input number of mul */
                + 2UL /* input number of rowwise_add */
                + 1UL /* input number of sigmoid */
                - 1UL /* out2 is not needed */);

  // The remaining two entries are expected to be empty.
  for (size_t idx = 1; idx <= 2; ++idx) {
    EXPECT_EQ(grad_net->ops_[idx]->Inputs(all_slots).size(), 0UL);
    EXPECT_EQ(grad_net->ops_[idx]->Outputs(all_slots).size(), 0UL);
  }
}
506 507

// Program with a single rowwise_add op. After AppendBackward the block is
// expected to hold: forward op, fill_constant, rowwise_add_grad.
TEST(Backward, simple_single_op) {
  using Names = std::vector<std::string>;

  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // Forward op: out = rowwise_add(x, b).
  f::OpDesc *forward_op = block->AppendOp();
  forward_op->SetType("rowwise_add");
  forward_op->SetInput("X", {"x"});
  forward_op->SetInput("b", {"b"});
  forward_op->SetOutput("Out", {"out"});

  auto target = f::VarDesc("out");
  target.SetShape({1});
  auto var_to_grad =
      AppendBackward(program, target, std::unordered_set<std::string>{});

  ASSERT_EQ(block->AllOps().size(), 3UL);

  f::OpDesc *fill_op = block->AllOps()[1];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  f::OpDesc *grad_op = block->AllOps()[2];
  EXPECT_EQ(grad_op->Type(), "rowwise_add_grad");
  ASSERT_EQ(grad_op->InputNames().size(), 1UL);
  ASSERT_EQ(grad_op->OutputNames().size(), 2UL);
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("X")),
            Names({f::GradVarName("x")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("b")),
            Names({f::GradVarName("b")}));

  // Forward-variable -> gradient-variable bookkeeping.
  EXPECT_EQ(var_to_grad.size(), 3UL);
  EXPECT_EQ(var_to_grad.at("b"), f::GradVarInfo(f::GradVarName("b"), 0, 2));
  EXPECT_EQ(var_to_grad.at("x"), f::GradVarInfo(f::GradVarName("x"), 0, 2));

  // Gradient variables must have been created in the block.
  EXPECT_TRUE(block->HasVar(f::GradVarName("b")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("x")));
}

F
fengjiayi 已提交
545
// Checks that the default values of mul's attributes (filled in by
// CheckAttrs) are visible on both the forward op and the generated mul_grad.
TEST(Backward, default_attribute) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // Forward op: out = mul(x, y), attributes left at their defaults.
  f::OpDesc *forward_op = block->AppendOp();
  forward_op->SetType("mul");
  forward_op->SetInput("X", {"x"});
  forward_op->SetInput("Y", {"y"});
  forward_op->SetOutput("Out", {"out"});
  forward_op->CheckAttrs();

  auto target = f::VarDesc("out");
  target.SetShape({1});
  AppendBackward(program, target, std::unordered_set<std::string>{});

  ASSERT_EQ(block->AllOps().size(), 3UL);
  EXPECT_EQ(boost::get<int>(forward_op->GetAttr("x_num_col_dims")), 1);
  EXPECT_EQ(boost::get<int>(forward_op->GetAttr("y_num_col_dims")), 1);

  f::OpDesc *fill_op = block->AllOps()[1];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  f::OpDesc *mul_grad = block->AllOps()[2];
  ASSERT_EQ(mul_grad->Type(), "mul_grad");
  EXPECT_EQ(boost::get<int>(mul_grad->GetAttr("x_num_col_dims")), 1);
  EXPECT_EQ(boost::get<int>(mul_grad->GetAttr("y_num_col_dims")), 1);
}

572
// Linear chain rowwise_add -> mul -> rowwise_add. After AppendBackward the
// block is expected to hold the three forward ops, one fill_constant, and
// the three grad ops in reverse order of the forward ops.
TEST(Backward, simple_mult_op) {
  using Names = std::vector<std::string>;

  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // out1 = rowwise_add(x1, b1)
  f::OpDesc *add1 = block->AppendOp();
  add1->SetType("rowwise_add");
  add1->SetInput("X", {"x1"});
  add1->SetInput("b", {"b1"});
  add1->SetOutput("Out", {"out1"});

  // out2 = mul(out1, y2)
  f::OpDesc *mul2 = block->AppendOp();
  mul2->SetType("mul");
  mul2->SetInput("X", {"out1"});
  mul2->SetInput("Y", {"y2"});
  mul2->SetOutput("Out", {"out2"});

  // out3 = rowwise_add(out2, b3)
  f::OpDesc *add3 = block->AppendOp();
  add3->SetType("rowwise_add");
  add3->SetInput("X", {"out2"});
  add3->SetInput("b", {"b3"});
  add3->SetOutput("Out", {"out3"});

  auto target = f::VarDesc("out3");
  target.SetShape({1});
  size_t forward_len = block->AllOps().size();
  auto var_to_grad =
      AppendBackward(program, target, std::unordered_set<std::string>{});

  ASSERT_EQ(block->AllOps().size(), 6UL + 1);

  // The op right after the forward ops seeds the target gradient.
  f::OpDesc *fill_op = block->AllOps()[forward_len];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  // Gradient of the first rowwise_add (last in the block).
  f::OpDesc *add1_grad = block->AllOps()[6];
  EXPECT_EQ(add1_grad->Type(), "rowwise_add_grad");
  ASSERT_EQ(add1_grad->InputNames().size(), 1UL);
  ASSERT_EQ(add1_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(add1_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out1")}));
  EXPECT_EQ(add1_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("x1")}));
  EXPECT_EQ(add1_grad->Output(f::GradVarName("b")),
            Names({f::GradVarName("b1")}));

  // Gradient of the mul.
  f::OpDesc *mul2_grad = block->AllOps()[5];
  EXPECT_EQ(mul2_grad->Type(), "mul_grad");
  ASSERT_EQ(mul2_grad->InputNames().size(), 4UL);
  ASSERT_EQ(mul2_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(mul2_grad->Input("X"), Names({"out1"}));
  EXPECT_EQ(mul2_grad->Input("Y"), Names({"y2"}));
  EXPECT_EQ(mul2_grad->Input("Out"), Names({"out2"}));
  EXPECT_EQ(mul2_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out2")}));
  EXPECT_EQ(mul2_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("out1")}));
  EXPECT_EQ(mul2_grad->Output(f::GradVarName("Y")),
            Names({f::GradVarName("y2")}));

  // Gradient of the second rowwise_add (first grad op in the block).
  f::OpDesc *add3_grad = block->AllOps()[4];
  EXPECT_EQ(add3_grad->Type(), "rowwise_add_grad");
  ASSERT_EQ(add3_grad->InputNames().size(), 1UL);
  ASSERT_EQ(add3_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(add3_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out3")}));
  EXPECT_EQ(add3_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("out2")}));
  EXPECT_EQ(add3_grad->Output(f::GradVarName("b")),
            Names({f::GradVarName("b3")}));

  // Forward-variable -> gradient-variable bookkeeping.
  EXPECT_EQ(var_to_grad.size(), 7UL);
  EXPECT_EQ(var_to_grad.at("x1"), f::GradVarInfo(f::GradVarName("x1"), 0, 6));
  EXPECT_EQ(var_to_grad.at("b1"), f::GradVarInfo(f::GradVarName("b1"), 0, 6));
  EXPECT_EQ(var_to_grad.at("out1"),
            f::GradVarInfo(f::GradVarName("out1"), 0, 5));
  EXPECT_EQ(var_to_grad.at("y2"), f::GradVarInfo(f::GradVarName("y2"), 0, 5));
  EXPECT_EQ(var_to_grad.at("out2"),
            f::GradVarInfo(f::GradVarName("out2"), 0, 4));
  EXPECT_EQ(var_to_grad.at("b3"), f::GradVarInfo(f::GradVarName("b3"), 0, 4));

  // Gradient variables must have been created in the block.
  EXPECT_TRUE(block->HasVar(f::GradVarName("x1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("b1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("out1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("y2")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("out2")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("b3")));
}

// Two branches (rowwise_add; mul -> rowwise_add) joined by a final mul, with
// the intermediate "out3" excluded from differentiation. Expects only the
// final mul and the first rowwise_add to receive grad ops, and the Y-gradient
// slot of the final mul_grad to be empty.
TEST(Backward, intermedia_var_no_grad) {
  using Names = std::vector<std::string>;

  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // out1 = rowwise_add(x1, b1)
  f::OpDesc *add1 = block->AppendOp();
  add1->SetType("rowwise_add");
  add1->SetInput("X", {"x1"});
  add1->SetInput("b", {"b1"});
  add1->SetOutput("Out", {"out1"});

  // out2 = mul(x2, y2)
  f::OpDesc *mul2 = block->AppendOp();
  mul2->SetType("mul");
  mul2->SetInput("X", {"x2"});
  mul2->SetInput("Y", {"y2"});
  mul2->SetOutput("Out", {"out2"});

  // out3 = rowwise_add(out2, b3)
  f::OpDesc *add3 = block->AppendOp();
  add3->SetType("rowwise_add");
  add3->SetInput("X", {"out2"});
  add3->SetInput("b", {"b3"});
  add3->SetOutput("Out", {"out3"});

  // out4 = mul(out1, out3)
  f::OpDesc *mul4 = block->AppendOp();
  mul4->SetType("mul");
  mul4->SetInput("X", {"out1"});
  mul4->SetInput("Y", {"out3"});
  mul4->SetOutput("Out", {"out4"});

  auto target = f::VarDesc("out4");
  target.SetShape({1});
  size_t forward_len = block->AllOps().size();
  auto var_to_grad = AppendBackward(program, target, {"out3"});

  ASSERT_EQ(block->AllOps().size(), 7UL);

  // The op right after the forward ops seeds the target gradient.
  f::OpDesc *fill_op = block->AllOps()[forward_len];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  // Gradient of the first rowwise_add (last in the block).
  f::OpDesc *add1_grad = block->AllOps()[6];
  EXPECT_EQ(add1_grad->Type(), "rowwise_add_grad");
  ASSERT_EQ(add1_grad->InputNames().size(), 1UL);
  ASSERT_EQ(add1_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(add1_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out1")}));
  EXPECT_EQ(add1_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("x1")}));
  EXPECT_EQ(add1_grad->Output(f::GradVarName("b")),
            Names({f::GradVarName("b1")}));

  // Gradient of the final mul; its Y input ("out3") is excluded.
  f::OpDesc *mul4_grad = block->AllOps()[5];
  EXPECT_EQ(mul4_grad->Type(), "mul_grad");
  ASSERT_EQ(mul4_grad->InputNames().size(), 4UL);
  ASSERT_EQ(mul4_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(mul4_grad->Input("X"), Names({"out1"}));
  EXPECT_EQ(mul4_grad->Input("Y"), Names({"out3"}));
  EXPECT_EQ(mul4_grad->Input("Out"), Names({"out4"}));
  EXPECT_EQ(mul4_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out4")}));
  EXPECT_EQ(mul4_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("out1")}));
  EXPECT_EQ(mul4_grad->Output(f::GradVarName("Y")), Names());

  // Forward-variable -> gradient-variable bookkeeping.
  EXPECT_EQ(var_to_grad.size(), 4UL);
  EXPECT_EQ(var_to_grad.at("x1"), f::GradVarInfo(f::GradVarName("x1"), 0, 6));
  EXPECT_EQ(var_to_grad.at("b1"), f::GradVarInfo(f::GradVarName("b1"), 0, 6));
  EXPECT_EQ(var_to_grad.at("out1"),
            f::GradVarInfo(f::GradVarName("out1"), 0, 5));

  // Gradient variables must have been created in the block.
  EXPECT_TRUE(block->HasVar(f::GradVarName("x1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("b1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("out1")));
}

// Chains two "mult_in_out" ops and runs AppendBackward on target "z2" with
// "z1" listed in the no-grad set. The expectations below pin the resulting
// layout: a fill_constant op right after the forward ops, the grad op of the
// second forward op (whose H@GRAD output is empty), a fill_zeros_like op for
// "z1", and finally the grad op of the first forward op, which consumes the
// zero-filled variable in place of the gradient of "z1".
TEST(Backward, var_no_grad) {
  using Names = std::vector<std::string>;

  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // Forward op: (x1, h1) -> (y1, z1).
  f::OpDesc *first_op = block->AppendOp();
  first_op->SetType("mult_in_out");
  first_op->SetInput("X", {"x1"});
  first_op->SetInput("H", {"h1"});
  first_op->SetOutput("Y", {"y1"});
  first_op->SetOutput("Z", {"z1"});

  // Forward op: (y1, z1) -> (y2, z2).
  f::OpDesc *second_op = block->AppendOp();
  second_op->SetType("mult_in_out");
  second_op->SetInput("X", {"y1"});
  second_op->SetInput("H", {"z1"});
  second_op->SetOutput("Y", {"y2"});
  second_op->SetOutput("Z", {"z2"});

  f::VarDesc target("z2");
  target.SetShape({1});
  const size_t num_forward_ops = block->AllOps().size();
  auto grad_info = AppendBackward(program, target, {"z1"});

  // Snapshot the op list once; nothing below mutates the block.
  const auto ops = block->AllOps();
  ASSERT_EQ(ops.size(), 6UL);

  // The first appended op is expected to be the fill_constant op.
  f::OpDesc *fill_op = ops[num_forward_ops];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  // Grad op of the second forward op.
  f::OpDesc *second_grad = ops[3];
  ASSERT_EQ(second_grad->Type(), "mult_in_out_grad");
  ASSERT_EQ(second_grad->InputNames().size(), 6UL);
  ASSERT_EQ(second_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(second_grad->Input("X"), Names({"y1"}));
  EXPECT_EQ(second_grad->Input("H"), Names({"z1"}));
  EXPECT_EQ(second_grad->Input("Y"), Names({"y2"}));
  EXPECT_EQ(second_grad->Input("Z"), Names({"z2"}));
  EXPECT_EQ(second_grad->Input(f::GradVarName("Y")),
            Names({f::GradVarName("y2")}));
  EXPECT_EQ(second_grad->Input(f::GradVarName("Z")),
            Names({f::GradVarName("z2")}));
  EXPECT_EQ(second_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("y1")}));
  // "z1" is in the no-grad set, so no gradient name is expected for H.
  EXPECT_EQ(second_grad->Output(f::GradVarName("H")), Names());

  // Name expected for the zero-filled stand-in of the gradient of "z1".
  const std::string z1_zero_name = std::string("z1") + f::kZeroVarSuffix;

  f::OpDesc *zero_filler = ops[4];
  ASSERT_EQ(zero_filler->Type(), "fill_zeros_like");
  ASSERT_EQ(zero_filler->InputNames().size(), 1UL);
  ASSERT_EQ(zero_filler->OutputNames().size(), 1UL);
  EXPECT_EQ(zero_filler->Input("X"), Names({"z1"}));
  EXPECT_EQ(zero_filler->Output("Out"), Names({z1_zero_name}));

  // Grad op of the first forward op.
  f::OpDesc *first_grad = ops[5];
  ASSERT_EQ(first_grad->Type(), "mult_in_out_grad");
  ASSERT_EQ(first_grad->InputNames().size(), 6UL);
  ASSERT_EQ(first_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(first_grad->Input("X"), Names({"x1"}));
  EXPECT_EQ(first_grad->Input("H"), Names({"h1"}));
  EXPECT_EQ(first_grad->Input("Y"), Names({"y1"}));
  EXPECT_EQ(first_grad->Input("Z"), Names({"z1"}));
  EXPECT_EQ(first_grad->Input(f::GradVarName("Y")),
            Names({f::GradVarName("y1")}));
  EXPECT_EQ(first_grad->Input(f::GradVarName("Z")), Names({z1_zero_name}));
  EXPECT_EQ(first_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("x1")}));
  EXPECT_EQ(first_grad->Output(f::GradVarName("H")),
            Names({f::GradVarName("h1")}));

  // Returned map: four entries are expected, three of which are checked here.
  EXPECT_EQ(grad_info.size(), 4UL);
  EXPECT_EQ(grad_info.at("y1"), f::GradVarInfo(f::GradVarName("y1"), 0, 3));
  EXPECT_EQ(grad_info.at("x1"), f::GradVarInfo(f::GradVarName("x1"), 0, 5));
  EXPECT_EQ(grad_info.at("h1"), f::GradVarInfo(f::GradVarName("h1"), 0, 5));

  // The gradient variables must also have been created in the block.
  EXPECT_TRUE(block->HasVar(f::GradVarName("y1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("x1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("h1")));
}

// "out1" feeds two forward ops (a "mul" and a "rowwise_add"). With an empty
// no-grad set, the expectations below require the two grad ops that write the
// gradient of "out1" to use distinct "@RENAME@<n>" names, followed by a "sum"
// op that combines them into the gradient of "out1" before the grad op of the
// first forward op reads it.
TEST(Backward, shared_var) {
  using Names = std::vector<std::string>;

  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  // Forward op: rowwise_add(x1, b1) -> out1.
  f::OpDesc *add_op = block->AppendOp();
  add_op->SetType("rowwise_add");
  add_op->SetInput("X", {"x1"});
  add_op->SetInput("b", {"b1"});
  add_op->SetOutput("Out", {"out1"});

  // Forward op: mul(out1, y2) -> out2.
  f::OpDesc *mul_op = block->AppendOp();
  mul_op->SetType("mul");
  mul_op->SetInput("X", {"out1"});
  mul_op->SetInput("Y", {"y2"});
  mul_op->SetOutput("Out", {"out2"});

  // Forward op: rowwise_add(out1, b3) -> out3.
  f::OpDesc *last_add_op = block->AppendOp();
  last_add_op->SetType("rowwise_add");
  last_add_op->SetInput("X", {"out1"});
  last_add_op->SetInput("b", {"b3"});
  last_add_op->SetOutput("Out", {"out3"});

  f::VarDesc target("out3");
  target.SetShape({1});
  const size_t num_forward_ops = block->AllOps().size();
  auto grad_info =
      AppendBackward(program, target, std::unordered_set<std::string>{});

  // Snapshot the op list once; nothing below mutates the block.
  const auto ops = block->AllOps();
  ASSERT_EQ(ops.size(), 8UL);

  // The first appended op is expected to be the fill_constant op.
  f::OpDesc *fill_op = ops[num_forward_ops];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  // Names expected for the two partial gradients of "out1".
  const std::string out1_grad_part0 = f::GradVarName("out1") + "@RENAME@0";
  const std::string out1_grad_part1 = f::GradVarName("out1") + "@RENAME@1";

  // Grad op of the last rowwise_add.
  f::OpDesc *last_add_grad = ops[4];
  ASSERT_EQ(last_add_grad->Type(), "rowwise_add_grad");
  ASSERT_EQ(last_add_grad->InputNames().size(), 1UL);
  ASSERT_EQ(last_add_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(last_add_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out3")}));
  EXPECT_EQ(last_add_grad->Output(f::GradVarName("X")),
            Names({out1_grad_part0}));
  EXPECT_EQ(last_add_grad->Output(f::GradVarName("b")),
            Names({f::GradVarName("b3")}));

  // Grad op of the mul.
  f::OpDesc *mul_grad = ops[5];
  ASSERT_EQ(mul_grad->Type(), "mul_grad");
  ASSERT_EQ(mul_grad->InputNames().size(), 4UL);
  ASSERT_EQ(mul_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(mul_grad->Input("X"), Names({"out1"}));
  EXPECT_EQ(mul_grad->Input("Y"), Names({"y2"}));
  EXPECT_EQ(mul_grad->Input("Out"), Names({"out2"}));
  EXPECT_EQ(mul_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out2")}));
  EXPECT_EQ(mul_grad->Output(f::GradVarName("X")), Names({out1_grad_part1}));
  EXPECT_EQ(mul_grad->Output(f::GradVarName("Y")),
            Names({f::GradVarName("y2")}));

  // Sum op that merges both partial gradients into the gradient of "out1".
  f::OpDesc *merge_op = ops[6];
  ASSERT_EQ(merge_op->Type(), "sum");
  ASSERT_EQ(merge_op->InputNames().size(), 1UL);
  ASSERT_EQ(merge_op->OutputNames().size(), 1UL);
  EXPECT_EQ(merge_op->Input("X"), Names({out1_grad_part0, out1_grad_part1}));
  EXPECT_EQ(merge_op->Output("Out"), Names({f::GradVarName("out1")}));

  // Grad op of the first rowwise_add, consuming the merged gradient.
  f::OpDesc *first_add_grad = ops[7];
  ASSERT_EQ(first_add_grad->Type(), "rowwise_add_grad");
  ASSERT_EQ(first_add_grad->InputNames().size(), 1UL);
  ASSERT_EQ(first_add_grad->OutputNames().size(), 2UL);
  EXPECT_EQ(first_add_grad->Input(f::GradVarName("Out")),
            Names({f::GradVarName("out1")}));
  EXPECT_EQ(first_add_grad->Output(f::GradVarName("X")),
            Names({f::GradVarName("x1")}));
  EXPECT_EQ(first_add_grad->Output(f::GradVarName("b")),
            Names({f::GradVarName("b1")}));

  // Returned map: six entries are expected, five of which are checked here.
  EXPECT_EQ(grad_info.size(), 6UL);
  EXPECT_EQ(grad_info.at("b3"), f::GradVarInfo(f::GradVarName("b3"), 0, 4));
  EXPECT_EQ(grad_info.at("y2"), f::GradVarInfo(f::GradVarName("y2"), 0, 5));
  EXPECT_EQ(grad_info.at("out1"),
            f::GradVarInfo(f::GradVarName("out1"), 0, 6));
  EXPECT_EQ(grad_info.at("x1"), f::GradVarInfo(f::GradVarName("x1"), 0, 7));
  EXPECT_EQ(grad_info.at("b1"), f::GradVarInfo(f::GradVarName("b1"), 0, 7));

  // The gradient variables must also have been created in the block.
  EXPECT_TRUE(block->HasVar(f::GradVarName("b3")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("y2")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("out1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("x1")));
  EXPECT_TRUE(block->HasVar(f::GradVarName("b1")));
}

// Builds a single "minus" op (a, b) -> out and runs AppendBackward on target
// "out" with "b" in the no-grad set. The test expects three ops in total,
// a fill_constant op right after the forward ops, and a two-entry gradient
// map in which "a" is produced by the op at index forward_len + 1.
TEST(Backward, half_backward) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);
  auto *op1 = block->AppendOp();
  op1->SetType("minus");
  op1->SetInput("X", {"a"});
  op1->SetInput("Y", {"b"});
  op1->SetOutput("Out", {"out"});

  auto target = f::VarDesc("out");
  target.SetShape({1});
  size_t forward_len = block->AllOps().size();
  auto var_to_grad = AppendBackward(program, target, {"b"});

  // Verify the op count with a fatal assertion BEFORE indexing the op list:
  // if AppendBackward appended nothing, ops[forward_len] would be an
  // out-of-bounds read rather than a clean test failure.
  const auto ops = block->AllOps();
  ASSERT_EQ(ops.size(), 3UL);

  f::OpDesc *fill_op = ops[forward_len];
  EXPECT_EQ(fill_op->Type(), "fill_constant");

  EXPECT_EQ(var_to_grad.size(), 2UL);
  // NOTE(review): GradVarInfo arguments look like (grad name, block index,
  // op index) — confirm against its declaration.
  EXPECT_EQ(var_to_grad.at("a"),
            f::GradVarInfo(f::GradVarName("a"), 0, forward_len + 1));
}