From 71acaff1bdbe67a5cf412a5c5e5dbc1399c01e45 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 10 Aug 2017 18:30:22 +0800 Subject: [PATCH] Tiny fix --- paddle/framework/grad_op_builder.cc | 9 +++++---- paddle/framework/pybind.cc | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 27f37d992..c51a563a6 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -30,19 +30,20 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, bool is_grad) { const auto& src_inout = src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; - auto& dst_inout = dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_; - const OpProto& proto = OpRegistry::protos().at(src_op->type_); + + const OpProto& proto = OpProtos().at(src_op->type_); const auto& src_arg_list = src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); for (const auto& arg : src_arg_list) { + if (arg.no_gradient() && !is_grad) continue; std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; + dst_inout[dst_name].reserve(src_inout.at(src_name).size()); for (auto& var_name : src_inout.at(src_name)) { - std::string s = is_grad ? GradVarName(var_name) - : (arg.no_gradient() ? kEmptyVarName : var_name); + std::string s = is_grad ? GradVarName(var_name) : var_name; dst_inout[dst_name].emplace_back(s); } } diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 94d2a4c68..d6ddd5dea 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -57,8 +57,8 @@ void ExposeOperator(ClassType &m) { .def("outputs", [](const typename ClassType::type &op) -> std::unordered_map> { - return op.outputs_; - }) + return op.outputs_; + }) .def("__str__", &ClassType::type::DebugString); } @@ -152,7 +152,7 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &protos = OpRegistry::protos(); + auto &protos = OpProtos(); std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), -- GitLab