提交 4736b239 编写于 作者: F fengjiayi

Add a simple test for grad_op_creator

上级 9418717f
......@@ -22,6 +22,7 @@ cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_creator SRCS grad_op_creator.cc)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc grad_op_creator)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry operator add_op)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(add_two);
namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册