Commit 5338417b authored by minqiyang

Polish code style

Parent ae39709e
@@ -209,7 +209,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   if (attr_type == proto::AttrType::INTS &&
       boost::get<std::vector<int>>(v).size() == 0u) {
     // Find current attr via attr name and set the correct attribute value
-    const proto::OpProto::Attr& attr = GetProtoAttr(name);
+    const proto::OpProto::Attr &attr = GetProtoAttr(name);
     switch (attr.type()) {
       case proto::AttrType::BOOLEANS: {
         VLOG(11) << "SetAttr: " << Type() << ", " << name
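The hunk above sits inside the branch that handles an empty integer list arriving from the Python frontend: an empty list carries no element type, so SetAttr looks up the attribute's declared type in the op proto and rebuilds the value as an empty container of the right type. Below is a minimal, self-contained sketch of that pattern, using std::variant in place of boost::variant; AttrType, Attribute, and DisambiguateEmptyList are simplified stand-ins for illustration, not the Paddle API.

// Sketch only (C++17): an empty list from the frontend defaults to
// std::vector<int>; if the op declares a different element type, rebuild
// the attribute so later get<> calls on the variant succeed.
#include <iostream>
#include <variant>
#include <vector>

enum class AttrType { INTS, BOOLEANS, FLOATS };

using Attribute =
    std::variant<std::vector<int>, std::vector<bool>, std::vector<float>>;

Attribute DisambiguateEmptyList(const Attribute &v, AttrType declared) {
  auto ints = std::get_if<std::vector<int>>(&v);
  if (ints != nullptr && ints->empty()) {
    switch (declared) {
      case AttrType::BOOLEANS:
        return std::vector<bool>{};
      case AttrType::FLOATS:
        return std::vector<float>{};
      default:
        break;
    }
  }
  return v;
}

int main() {
  Attribute raw = std::vector<int>{};  // empty list from the frontend
  Attribute fixed = DisambiguateEmptyList(raw, AttrType::BOOLEANS);
  std::cout << std::boolalpha
            << std::holds_alternative<std::vector<bool>>(fixed) << "\n";
  return 0;
}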
@@ -275,8 +275,8 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
   return it->second;
 }
 
-const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) {
-  proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto();
+const proto::OpProto::Attr &OpDesc::GetProtoAttr(const std::string &name) {
+  proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
   for (int i = 0; i != proto.attrs_size(); ++i) {
     const proto::OpProto::Attr &attr = proto.attrs(i);
     if (attr.name() == name) {
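GetProtoAttr itself is a linear scan over the attributes declared in the op's proto, returning a const reference to the matching entry. A rough sketch of that lookup shape, with an illustrative AttrDef struct standing in for the real protobuf type:

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-in for proto::OpProto::Attr, not the protobuf type.
struct AttrDef {
  std::string name;
  int type;
};

// Mirrors the loop in GetProtoAttr: scan declared attrs, return a const
// reference to the match, and fail loudly if the name is unknown.
const AttrDef &FindAttr(const std::vector<AttrDef> &attrs,
                        const std::string &name) {
  for (const AttrDef &attr : attrs) {
    if (attr.name == name) return attr;
  }
  throw std::out_of_range("attribute not found: " + name);
}

int main() {
  std::vector<AttrDef> attrs = {{"axis", 0}, {"keep_dim", 1}};
  std::cout << FindAttr(attrs, "axis").name << "\n";
  return 0;
}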
@@ -81,7 +81,7 @@ class OpDesc {
   Attribute GetAttr(const std::string &name) const;
-  const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const;
+  const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
   Attribute GetNullableAttr(const std::string &name) const;
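The C++ hunks all make the same one-character move: the & migrates from the type to the variable name. If clang-format enforces this convention (the repository's configuration is not part of this commit, so the fragment below is hypothetical), the corresponding settings would be:

# Hypothetical .clang-format fragment; not shown in this commit.
PointerAlignment: Right
DerivePointerAlignment: false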
@@ -364,8 +364,7 @@ def _append_backward_ops_(block,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
-            op.desc,
-            cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
+            op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
         grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
@@ -159,8 +159,7 @@ class ParallelExecutor(object):
                 for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(cpt.to_text(var)
-                for var in self.persistable_vars), main.desc,
+            set(cpt.to_text(var) for var in self.persistable_vars), main.desc,
             cpt.to_text(loss_name)
             if loss_name else six.u(''), scope, local_scopes, exec_strategy,
             build_strategy, num_trainers, trainer_id)
@@ -274,8 +273,7 @@ class ParallelExecutor(object):
             self.executor.feed_tensors_into_local_scopes(res)
         fetch_var_name = '@FETCHED_VAR_NAME@'
-        self.executor.run(
-            cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
+        self.executor.run(cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         if self.is_dist:
@@ -259,9 +259,9 @@ class ControlFlowGraph(object):
                     # Rename the var to the cache var already with
                     # memory allocated in order to reuse the memory.
                     _rename_arg_(self._ops, x, cache_var, begin_idx=i)
-                    self._program.block(block_desc.id).var(
-                        cpt.to_text(x)).desc = self._find_var(
-                            block_desc, cache_var, is_forward)
+                    self._program.block(block_desc.id).var(cpt.to_text(
+                        x)).desc = self._find_var(block_desc, cache_var,
+                                                  is_forward)
                     self._update_graph(x, cache_var, begin_idx=i)
                     break
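The comment in the last hunk describes the memory-optimization trick: once x can reuse cache_var's already-allocated buffer, every later reference to x in the op list is renamed to cache_var. The actual pass is Python; below is a rough C++ sketch of what _rename_arg_ does, where Op and RenameArg are illustrative names, not Paddle's IR.

#include <iostream>
#include <string>
#include <vector>

// Illustrative op node; the real pass walks Paddle's op descs.
struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// From position begin_idx onward, point every use of `from` at `to`,
// so `from` never needs its own allocation.
void RenameArg(std::vector<Op> *ops, const std::string &from,
               const std::string &to, std::size_t begin_idx) {
  for (std::size_t i = begin_idx; i < ops->size(); ++i) {
    for (auto *args : {&(*ops)[i].inputs, &(*ops)[i].outputs}) {
      for (std::string &arg : *args) {
        if (arg == from) arg = to;
      }
    }
  }
}

int main() {
  std::vector<Op> ops = {{{"x"}, {"y"}}, {{"y"}, {"z"}}};
  RenameArg(&ops, "y", "cache_var", 0);
  std::cout << ops[1].inputs[0] << "\n";  // prints "cache_var"
  return 0;
}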