Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7ccbdb1b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ccbdb1b
编写于
2月 06, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
for test
上级
c32040c3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
11 deletion
+37
-11
paddle/framework/block_desc.cc
paddle/framework/block_desc.cc
+24
-0
paddle/framework/block_desc.h
paddle/framework/block_desc.h
+2
-0
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+6
-0
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+1
-1
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+4
-10
未找到文件。
paddle/framework/block_desc.cc
浏览文件 @
7ccbdb1b
...
...
@@ -42,6 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const {
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
}
void
BlockDesc
::
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
if
(
this
->
HasVar
(
old_name
))
{
auto
*
var
=
this
->
Var
(
old_name
);
var
->
SetName
(
new_name
);
vars_
[
new_name
].
reset
(
var
);
vars_
.
erase
(
old_name
);
// rename inputs and outputs
for
(
const
auto
&
op
:
ops_
)
{
auto
*
it
=
op
.
get
();
for
(
auto
in_name
:
it
->
InputArgumentNames
())
{
if
(
in_name
==
old_name
)
{
it
->
RenameInput
(
old_name
,
new_name
);
}
}
for
(
auto
out_name
:
it
->
OutputArgumentNames
())
{
if
(
out_name
==
old_name
)
{
it
->
RenameOutput
(
old_name
,
new_name
);
}
}
}
}
}
VarDesc
*
BlockDesc
::
FindVarRecursive
(
const
std
::
string
&
name
)
const
{
if
(
name
==
kEmptyVarName
)
return
nullptr
;
...
...
paddle/framework/block_desc.h
浏览文件 @
7ccbdb1b
...
...
@@ -55,6 +55,8 @@ class BlockDesc {
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
void
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
&
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
...
...
paddle/pybind/protobuf.cc
浏览文件 @
7ccbdb1b
...
...
@@ -171,6 +171,12 @@ void BindBlockDesc(py::module &m) {
std
::
string
name
=
byte_name
;
return
self
.
HasVar
(
name
);
})
.
def
(
"rename_var"
,
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
,
py
::
bytes
byte_name_new
)
{
std
::
string
name
=
byte_name
;
std
::
string
new_name
=
byte_name_new
;
return
self
.
RenameVar
(
name
,
new_name
);
})
.
def
(
"has_var_recursive"
,
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
std
::
string
name
=
byte_name
;
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
7ccbdb1b
...
...
@@ -203,7 +203,7 @@ class DistributeTranspiler:
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
long
(
offset
),
long
(
size
)))
for
varname
,
splited
in
block_map
.
iteritems
():
orig_var
=
program
.
global_block
().
var
s
[
varname
]
orig_var
=
program
.
global_block
().
var
(
varname
)
if
len
(
splited
)
==
1
:
# rename var to the trainer_id var
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
7ccbdb1b
...
...
@@ -740,15 +740,9 @@ class Block(object):
"""
if
not
self
.
has_var
(
name
):
raise
ValueError
(
"var %s is not in current"
%
name
)
orig_var
=
self
.
var
(
name
)
del
self
.
vars
[
name
]
orig_var
.
name
=
new_name
self
.
vars
[
new_name
]
=
orig_var
for
op
in
self
.
ops
:
if
name
in
op
.
input_arg_names
:
op
.
rename_input
(
name
,
new_name
)
if
name
in
op
.
output_arg_names
:
op
.
rename_output
(
name
,
new_name
)
self
.
desc
.
rename_var
(
name
,
new_name
)
self
.
sync_with_cpp
()
print
(
"renamed var: "
,
self
.
var
(
new_name
))
def
create_parameter
(
self
,
*
args
,
**
kwargs
):
global_block
=
self
.
program
.
global_block
()
...
...
@@ -837,7 +831,7 @@ class Block(object):
for
p
in
other
.
iter_parameters
():
assert
isinstance
(
p
,
Parameter
)
v
=
self
.
vars
.
get
(
p
.
name
,
None
)
print
(
"var shape to copy"
,
v
)
print
(
"var shape to copy"
,
v
,
p
)
if
v
is
None
:
raise
ValueError
(
"copy_param_info_from should be invoked with "
"same topology"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录