Commit 5338417b authored by minqiyang

Polish code style
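
This is mechanical formatting work. In the C++ files the "&" of reference
declarations moves to bind to the variable name (const T &x rather than
const T& x), matching the neighboring declarations in the same functions;
in the Python files, call arguments that fit on one line are re-wrapped onto
a single line. A minimal sketch of the C++ convention being applied, reusing
the declaration from the first hunk below:

    // Old placement: the '&' sits with the type.
    // const proto::OpProto::Attr& attr = GetProtoAttr(name);

    // Polished placement: the '&' sits with the variable name, which is
    // the style the surrounding code already uses.
    const proto::OpProto::Attr &attr = GetProtoAttr(name);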

Parent ae39709e
@@ -209,7 +209,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   if (attr_type == proto::AttrType::INTS &&
       boost::get<std::vector<int>>(v).size() == 0u) {
     // Find current attr via attr name and set the correct attribute value
-    const proto::OpProto::Attr& attr = GetProtoAttr(name);
+    const proto::OpProto::Attr &attr = GetProtoAttr(name);
     switch (attr.type()) {
       case proto::AttrType::BOOLEANS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
@@ -275,8 +275,8 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
   return it->second;
 }
 
-const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) {
-  proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto();
+const proto::OpProto::Attr &OpDesc::GetProtoAttr(const std::string &name) {
+  proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
   for (int i = 0; i != proto.attrs_size(); ++i) {
     const proto::OpProto::Attr &attr = proto.attrs(i);
     if (attr.name() == name) {
...
@@ -81,7 +81,7 @@ class OpDesc {
   Attribute GetAttr(const std::string &name) const;
 
-  const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const;
+  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
 
   Attribute GetNullableAttr(const std::string &name) const;
...
@@ -364,8 +364,7 @@ def _append_backward_ops_(block,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
-            op.desc,
-            cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
+            op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
         grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
...
@@ -159,8 +159,7 @@ class ParallelExecutor(object):
                 for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(cpt.to_text(var)
-                for var in self.persistable_vars), main.desc,
+            set(cpt.to_text(var) for var in self.persistable_vars), main.desc,
             cpt.to_text(loss_name)
             if loss_name else six.u(''), scope, local_scopes, exec_strategy,
             build_strategy, num_trainers, trainer_id)
@@ -274,8 +273,7 @@ class ParallelExecutor(object):
             self.executor.feed_tensors_into_local_scopes(res)
 
         fetch_var_name = '@FETCHED_VAR_NAME@'
-        self.executor.run(
-            cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
+        self.executor.run(cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         if self.is_dist:
...
@@ -259,9 +259,9 @@ class ControlFlowGraph(object):
                     # Rename the var to the cache var already with
                     # memory allocated in order to reuse the memory.
                     _rename_arg_(self._ops, x, cache_var, begin_idx=i)
-                    self._program.block(block_desc.id).var(
-                        cpt.to_text(x)).desc = self._find_var(
-                            block_desc, cache_var, is_forward)
+                    self._program.block(block_desc.id).var(cpt.to_text(
+                        x)).desc = self._find_var(block_desc, cache_var,
+                                                  is_forward)
                     self._update_graph(x, cache_var, begin_idx=i)
                     break
...