Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5ab56871
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5ab56871
编写于
3月 28, 2019
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove no necessary doc changes. test=develop
上级
6b854f3e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
198 addition
and
2 deletion
+198
-2
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+198
-2
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
5ab56871
...
...
@@ -627,6 +627,183 @@ class Variable(object):
"""
self
.
error_clip
=
error_clip
def
_slice_indices
(
self
,
slice
,
length
):
"""
Reference implementation for the slice.indices method.
"""
# Compute step and length as integers.
step
=
1
if
slice
.
step
is
None
else
slice
.
step
# Raise ValueError for negative length or zero step.
if
length
<
0
:
raise
ValueError
(
"length should not be negative"
)
if
step
==
0
:
raise
ValueError
(
"slice step cannot be zero"
)
# Find lower and upper bounds for start and stop.
lower
=
-
1
if
step
<
0
else
0
upper
=
length
-
1
if
step
<
0
else
length
# Compute start.
if
slice
.
start
is
None
:
start
=
upper
if
step
<
0
else
lower
else
:
start
=
slice
.
start
start
=
max
(
start
+
length
,
lower
)
if
start
<
0
else
min
(
start
,
upper
)
# Compute stop.
if
slice
.
stop
is
None
:
stop
=
lower
if
step
<
0
else
upper
else
:
stop
=
slice
.
stop
stop
=
max
(
stop
+
length
,
lower
)
if
stop
<
0
else
min
(
stop
,
upper
)
return
start
,
stop
,
step
def
_detectEllipsis
(
self
,
item
):
has_ellipsis
=
False
start
=
0
end
=
len
(
self
.
shape
)
for
index
,
o
in
enumerate
(
item
):
if
o
is
Ellipsis
:
if
has_ellipsis
:
raise
ValueError
(
"Index can have one ellipsis only."
)
has_ellipsis
=
True
start
=
index
else
:
if
has_ellipsis
:
end
=
index
return
has_ellipsis
,
start
,
end
def
_reconstructSliceinfo
(
self
,
item
):
has_ellipsis
,
start
,
end
=
self
.
_detectEllipsis
(
item
)
if
has_ellipsis
:
newitem
=
[]
for
i
in
range
(
start
):
newitem
.
append
(
item
[
i
])
for
i
in
range
(
start
,
end
):
newitem
.
append
(
slice
(
None
,
None
,
None
))
for
i
in
range
(
end
,
len
(
item
)):
newitem
.
append
(
item
[
i
])
return
newitem
else
:
return
None
def
_detectContinuesSlice
(
self
,
item
):
starts
=
[]
ends
=
[]
for
index
,
o
in
enumerate
(
item
):
if
isinstance
(
o
,
int
):
start
=
int
(
o
)
if
(
index
>
0
and
index
>=
self
.
shape
[
index
])
\
or
(
index
<
0
and
(
index
+
self
.
shape
[
index
])
<
0
):
raise
IndexError
(
"invalid index"
)
start
=
max
(
start
+
self
.
shape
[
index
],
0
)
if
start
<
0
else
min
(
start
,
self
.
shape
[
index
])
starts
.
append
(
start
)
ends
.
append
(
start
+
1
)
elif
isinstance
(
o
,
slice
):
start
,
stop
,
step
=
self
.
_slice_indices
(
o
,
self
.
shape
[
index
])
if
step
==
1
or
step
==
-
1
:
starts
.
append
(
start
)
ends
.
append
(
stop
)
else
:
return
False
,
None
else
:
raise
IndexError
(
"Valid index accept int or slice or ellipsis"
)
return
True
,
[
starts
,
ends
]
def
_cloneVar
(
self
,
copy
=
False
):
if
not
copy
:
return
self
.
block
.
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
(
self
.
name
)),
dtype
=
self
.
dtype
,
persistable
=
self
.
persistable
,
stop_gradient
=
self
.
_stop_gradient
,
)
else
:
return
self
def
_sliceVar
(
self
,
axes
,
starts
,
ends
):
new_var
=
self
.
_cloneVar
()
self
.
block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
[
self
]},
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
'axes'
:
axes
,
'starts'
:
starts
,
'ends'
:
ends
})
return
new_var
def
_concatVar
(
self
,
inputs
,
axis
):
new_var
=
self
.
_cloneVar
()
self
.
block
.
append_op
(
type
=
"concat"
,
inputs
=
{
'X'
:
inputs
},
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
'axis'
:
axis
,
})
return
new_var
def
_sliceAndConcatVar
(
self
,
item
,
axis
):
if
isinstance
(
item
,
slice
):
if
self
.
shape
[
axis
]
<
0
:
return
self
.
_cloneVar
(
True
)
start
,
stop
,
step
=
self
.
_slice_indices
(
item
,
self
.
shape
[
axis
])
if
step
==
1
:
return
self
.
_sliceVar
([
axis
],
[
start
],
[
stop
])
else
:
vars
=
[]
if
step
>
0
:
while
start
<
stop
:
vars
.
append
(
self
.
_sliceVar
([
axis
],
[
start
],
[
start
+
1
]))
start
+=
step
else
:
while
start
>
stop
:
vars
.
append
(
self
.
_sliceVar
([
axis
],
[
start
],
[
start
+
1
]))
start
+=
step
return
self
.
_concatVar
(
vars
,
axis
)
elif
isinstance
(
item
,
int
):
if
self
.
shape
[
axis
]
<
0
:
return
self
.
_cloneVar
(
True
)
index
=
int
(
item
)
if
(
index
>
0
and
index
>=
self
.
shape
[
axis
])
\
or
(
index
<
0
and
(
index
+
self
.
shape
[
axis
])
<
0
):
raise
IndexError
(
"invalid index"
)
return
self
.
_sliceVar
([
axis
],
[
index
],
[
index
+
1
])
else
:
raise
IndexError
(
"Valid index accept int or slice or tuple"
)
def
__getitem__
(
self
,
item
):
"""
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
new_var
=
None
if
isinstance
(
item
,
tuple
):
if
len
(
item
)
>
len
(
self
.
shape
):
raise
IndexError
(
"Too many indexes"
)
newitem
=
self
.
_reconstructSliceinfo
(
item
)
or
item
check
,
info
=
self
.
_detectContinuesSlice
(
newitem
)
if
check
:
starts
=
info
[
0
]
ends
=
info
[
1
]
axes
=
[
i
for
i
in
range
(
len
(
starts
))]
return
self
.
_sliceVar
(
axes
,
starts
,
ends
)
else
:
new_var
=
self
for
index
,
o
in
enumerate
(
newitem
):
new_var
=
new_var
.
_sliceAndConcatVar
(
o
,
index
)
else
:
new_var
=
self
.
_sliceAndConcatVar
(
item
,
0
)
return
new_var
def
get_all_op_protos
():
"""
...
...
@@ -744,7 +921,7 @@ class Operator(object):
if
_in_imperative_mode
():
if
type
is
None
:
raise
ValueError
(
"`type` to initilized an Operator can not be None."
)
"`type` to initi
a
lized an Operator can not be None."
)
self
.
iop
=
core
.
OpBase
(
type
)
# TODO(minqiyang): remove these lines after we take apart all
...
...
@@ -906,7 +1083,10 @@ class Operator(object):
@
property
def
type
(
self
):
return
self
.
desc
.
type
()
if
_in_imperative_mode
():
return
self
.
iop
.
type
else
:
return
self
.
desc
.
type
()
def
input
(
self
,
name
):
"""
...
...
@@ -1022,6 +1202,9 @@ class Operator(object):
"""
self
.
_update_desc_attr
(
name
,
val
)
def
_remove_attr
(
self
,
name
):
self
.
desc
.
remove_attr
(
name
)
def
_update_desc_attr
(
self
,
name
,
val
):
"""
Update the value of desc's attribute by attribute's name.
...
...
@@ -2515,6 +2698,10 @@ class Program(object):
self
.
_trainers_endpoints
=
[]
# the distributed lookup table names
self
.
_distributed_lookup_table
=
None
# use Deep gradient comrepssion or not
self
.
_enable_dgc
=
False
# @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler
self
.
__is_mem_optimized
=
False
...
...
@@ -2565,6 +2752,15 @@ class Program(object):
def
set_op_role_var
(
self
,
var_name
):
self
.
_op_role_var
=
[
var_name
]
@
contextlib
.
contextmanager
def
_backward_role_guard
(
self
):
tmp_role
=
self
.
_current_role
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
Backward
yield
self
.
_current_role
=
tmp_role
@
signature_safe_contextmanager
def
_optimized_guard
(
self
,
param_and_grads
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录