From 4736b239d978f5def9ef2dc3e13a7c8dea12f35d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 21 Jul 2017 11:25:11 +0800 Subject: [PATCH] Add a simple test for grad_op_creator --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/grad_op_creator_test.cc | 25 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 paddle/framework/grad_op_creator_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index a43861f4cd..36da6f649b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -22,6 +22,7 @@ cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_creator SRCS grad_op_creator.cc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc grad_op_creator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) +cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry operator add_op) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/grad_op_creator_test.cc b/paddle/framework/grad_op_creator_test.cc new file mode 100644 index 0000000000..ad836727c3 --- /dev/null +++ b/paddle/framework/grad_op_creator_test.cc @@ -0,0 +1,25 @@ +#include "paddle/framework/grad_op_creator.h" +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +USE_OP(add_two); + +namespace paddle { +namespace framework { + +TEST(GradOpCreator, AddTwo) { + OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); + OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op); + EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); + EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); + EXPECT_EQ(grad_add_op->Input("X"), "x"); + EXPECT_EQ(grad_add_op->Input("Y"), "y"); + EXPECT_EQ(grad_add_op->Input("Out"), "out"); + EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD"); + EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD"); + EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file -- GitLab