提交 0c96c997 编写于 作者: Q qiaolongfei

change pybind and net_op_test

上级 2f74e608
...@@ -23,7 +23,7 @@ static void TransOpArg(const OperatorBase* src_op, ...@@ -23,7 +23,7 @@ static void TransOpArg(const OperatorBase* src_op,
OperatorBase::VarNameMap* vars, OperatorBase::VarNameMap* vars,
const OpArgType& src_type, bool is_grad) { const OpArgType& src_type, bool is_grad) {
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto& proto = OpProtos().at(src_op->Type()); const OpProto& proto = OpProtos().at(src_op->Type());
...@@ -39,13 +39,12 @@ static void TransOpArg(const OperatorBase* src_op, ...@@ -39,13 +39,12 @@ static void TransOpArg(const OperatorBase* src_op,
dst_inout[dst_name].emplace_back(s); dst_inout[dst_name].emplace_back(s);
} }
} }
return dst_inout;
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
auto gop_type_it = OpRegistry::grad_ops().find(op->type_); auto gop_type_it = OpRegistry::grad_ops().find(op->Type());
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(), PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
"Operator %s do not register gradient type", op->type_); "Operator %s do not register gradient type", op->Type());
auto& grad_op_type = gop_type_it->second; auto& grad_op_type = gop_type_it->second;
OperatorBase::VarNameMap inputs; OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs; OperatorBase::VarNameMap outputs;
...@@ -56,9 +55,9 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -56,9 +55,9 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
auto gop_it = OpRegistry::op_creators().find(grad_op_type); auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(), PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found", "Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type); op->Type(), grad_op_type);
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_); return gop_it->second(grad_op_type, inputs, outputs, op->Attrs());
} }
} // namespace framework } // namespace framework
......
...@@ -52,8 +52,8 @@ TEST(GradOpBuilder, AddTwo) { ...@@ -52,8 +52,8 @@ TEST(GradOpBuilder, AddTwo) {
"add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op = std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op); f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(grad_add_op->inputs_.size(), 4UL); EXPECT_EQ(grad_add_op->Inputs().size(), 4UL);
EXPECT_EQ(grad_add_op->outputs_.size(), 2UL); EXPECT_EQ(grad_add_op->Outputs().size(), 2UL);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y"); EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out"); EXPECT_EQ(grad_add_op->Input("Out"), "out");
...@@ -76,7 +76,7 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -76,7 +76,7 @@ TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> grad_test_op = std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"), EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"})); std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
...@@ -90,7 +90,7 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -90,7 +90,7 @@ TEST(GradOpBuilder, MutiInOut) {
std::vector<std::string>( std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")})); {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"), std::vector<std::string>({f::GradVarName("in2_1"),
...@@ -109,7 +109,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { ...@@ -109,7 +109,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
// 'In2' and 'Out2' are ignored in gradient calculating // 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL); ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In3_mult"), EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"})); std::vector<std::string>({"in3_1", "in3_2"}));
...@@ -121,7 +121,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { ...@@ -121,7 +121,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
f::GradVarName("out2")); f::GradVarName("out2"));
ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>( std::vector<std::string>(
......
...@@ -53,15 +53,15 @@ void ExposeOperator(ClassType &m) { ...@@ -53,15 +53,15 @@ void ExposeOperator(ClassType &m) {
.def("run", &ClassType::type::Run) .def("run", &ClassType::type::Run)
.def("type", .def("type",
[](const typename ClassType::type &op) -> std::string { [](const typename ClassType::type &op) -> std::string {
return op.type_; return op.Type();
}) })
.def("outputs", .def("outputs",
[](const typename ClassType::type &op) [](const typename ClassType::type &op)
-> std::map<std::string, std::vector<std::string>> { -> std::map<std::string, std::vector<std::string>> {
return op.outputs_; return op.Outputs();
}) })
.def("inputs", .def("inputs",
[](const typename ClassType::type &op) { return op.inputs_; }) [](const typename ClassType::type &op) { return op.Inputs(); })
.def("__str__", &ClassType::type::DebugString) .def("__str__", &ClassType::type::DebugString)
.def("no_intermediate_outputs", .def("no_intermediate_outputs",
[](const typename ClassType::type &op) { [](const typename ClassType::type &op) {
...@@ -229,7 +229,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -229,7 +229,7 @@ All parameter, weight, gradient are variables in Paddle.
net.def_static("create", net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> { []() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>(); auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net"; retv->SetType("plain_net");
return retv; return retv;
}) })
.def("add_op", &operators::NetOp::AddOp) .def("add_op", &operators::NetOp::AddOp)
......
...@@ -56,8 +56,8 @@ TEST(OpKernel, all) { ...@@ -56,8 +56,8 @@ TEST(OpKernel, all) {
net->CompleteAddOp(); net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
net->inputs_.at(NetOp::kAll)); net->Inputs(NetOp::kAll));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll)); AssertSameVectorWithoutOrder({"y", "z"}, net->Outputs(NetOp::kAll));
auto final_outs = net->OutputVars(false); auto final_outs = net->OutputVars(false);
......
...@@ -82,14 +82,14 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { ...@@ -82,14 +82,14 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope", PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net); arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>(); auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->outputs_.empty(), "net_op has no outputs"); PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) { if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope(); auto& step_scope = scope.NewScope();
// create step net's temp inputs // create step net's temp inputs
for (auto& input : net_op->inputs_) { for (auto& input : net_op->Inputs()) {
// the weight are located in parent scope // the weight are located in parent scope
for (auto& var_name : input.second) { for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) { if (!step_scope.FindVar(var_name)) {
...@@ -98,7 +98,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { ...@@ -98,7 +98,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
} }
} }
// create stepnet's outputs // create stepnet's outputs
for (const auto& output : net_op->outputs_) { for (const auto& output : net_op->Outputs()) {
for (auto& var_name : output.second) { for (auto& var_name : output.second) {
step_scope.NewVar(var_name); step_scope.NewVar(var_name);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册