// grad_op_builder_test.cc
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

// Pull in the registration of the "add_two" operator so that OpRegistry can
// create it by name in the test below.
USE_OP(add_two);

namespace paddle {
namespace framework {

// Builds the forward "add_two" op (inputs x, y -> output out) and checks the
// gradient op that OpRegistry::CreateGradOp derives from it:
//   inputs : the forward inputs (X, Y), the forward output (Out), and the
//            gradient of the forward output (Out@GRAD) -- 4 in total;
//   outputs: the gradients of the forward inputs (X@GRAD, Y@GRAD) -- 2 in total.
// Gradient variable names are the forward variable names suffixed "@GRAD".
TEST(GradOpBuilder, AddTwo) {
  std::shared_ptr<OperatorBase> add_op(
      OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
  std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
  // Fail the test cleanly (instead of crashing on a null dereference below)
  // if no gradient op was produced.
  ASSERT_TRUE(grad_add_op != nullptr);

  EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
  EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
  EXPECT_EQ(grad_add_op->Input("X"), "x");
  EXPECT_EQ(grad_add_op->Input("Y"), "y");
  EXPECT_EQ(grad_add_op->Input("Out"), "out");
  EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
  EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
  EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}

}  // namespace framework
}  // namespace paddle