Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
20c4a4cb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
20c4a4cb
编写于
2月 07, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
2月 07, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Impl scalar switch case op with condition op (#8184)
Impl scalar switch case op with condition op
上级
e5832019
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
171 addition
and
10 deletion
+171
-10
doc/design/switch.md
doc/design/switch.md
+1
-2
paddle/operators/conditional_block_op.cc
paddle/operators/conditional_block_op.cc
+38
-6
python/paddle/v2/fluid/layers/control_flow.py
python/paddle/v2/fluid/layers/control_flow.py
+64
-2
python/paddle/v2/fluid/layers/ops.py
python/paddle/v2/fluid/layers/ops.py
+4
-0
python/paddle/v2/fluid/tests/test_switch.py
python/paddle/v2/fluid/tests/test_switch.py
+64
-0
未找到文件。
doc/design/switch.md
浏览文件 @
20c4a4cb
...
...
@@ -10,8 +10,7 @@ The following example shows the usage of `fluid.switch`.
a
=
fluid
.
Var
(
10
)
b
=
fluid
.
Var
(
0
)
switch
=
fluid
.
switch
()
with
switch
.
block
():
with
switch
()
as
switch
:
with
switch
.
case
(
fluid
.
less_equal
(
a
,
10
)):
fluid
.
print
(
"Case 1"
)
with
switch
.
case
(
fluid
.
larger
(
a
,
0
)):
...
...
paddle/operators/conditional_block_op.cc
浏览文件 @
20c4a4cb
...
...
@@ -41,6 +41,21 @@ class ConditionalOp : public framework::OperatorBase {
});
return
retv
;
}
bool
ScalarCondition
(
const
std
::
vector
<
const
framework
::
LoDTensor
*>
&
ips
)
const
{
if
(
!
(
ips
.
size
()
==
1UL
&&
ips
[
0
]
->
IsInitialized
()))
{
PADDLE_THROW
(
"should have one initialized input as condition"
);
}
if
(
!
(
ips
[
0
]
->
type
().
hash_code
()
==
typeid
(
bool
).
hash_code
()
&&
ips
[
0
]
->
numel
()
==
1
))
{
PADDLE_THROW
(
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d"
,
ips
[
0
]
->
numel
());
}
return
ips
[
0
]
->
data
<
bool
>
()[
0
];
}
};
class
ConditionalBlockOp
:
public
ConditionalOp
{
...
...
@@ -53,9 +68,15 @@ class ConditionalBlockOp : public ConditionalOp {
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
InputTensors
(
scope
);
bool
need_run
=
std
::
all_of
(
xs
.
begin
(),
xs
.
end
(),
[](
const
framework
::
LoDTensor
*
t
)
{
return
t
->
numel
()
!=
0
;
});
bool
need_run
;
if
(
Attr
<
bool
>
(
"is_scalar_condition"
))
{
need_run
=
ScalarCondition
(
xs
);
}
else
{
need_run
=
std
::
all_of
(
xs
.
begin
(),
xs
.
end
(),
[](
const
framework
::
LoDTensor
*
t
)
{
return
t
->
numel
()
!=
0
;
});
}
if
(
need_run
)
{
auto
*
scope_var
=
scope
.
FindVar
(
Output
(
"Scope"
));
...
...
@@ -88,6 +109,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"scope is std::vector<Scope*>"
);
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
"The step block of conditional block operator"
);
AddAttr
<
bool
>
(
"is_scalar_condition"
,
"the input X is used as scalar "
"condition"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
...
...
@@ -106,9 +131,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
xs
=
this
->
InputTensors
(
scope
);
bool
need_run
=
std
::
all_of
(
xs
.
begin
(),
xs
.
end
(),
[](
const
framework
::
LoDTensor
*
t
)
{
return
t
->
numel
()
!=
0
;
});
bool
need_run
;
if
(
Attr
<
bool
>
(
"is_scalar_condition"
))
{
need_run
=
ScalarCondition
(
xs
);
}
else
{
need_run
=
std
::
all_of
(
xs
.
begin
(),
xs
.
end
(),
[](
const
framework
::
LoDTensor
*
t
)
{
return
t
->
numel
()
!=
0
;
});
}
if
(
need_run
)
{
auto
*
scope_var
=
scope
.
FindVar
(
Input
(
"Scope"
));
...
...
@@ -182,6 +213,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Params"
),
InputGrad
(
"Params"
,
false
));
grad_op
->
SetBlockAttr
(
"sub_block"
,
*
this
->
grad_block_
[
0
]);
grad_op
->
SetAttr
(
"is_scalar_condition"
,
GetAttr
(
"is_scalar_condition"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
...
...
python/paddle/v2/fluid/layers/control_flow.py
浏览文件 @
20c4a4cb
...
...
@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
from
..
import
core
from
..framework
import
Program
,
Variable
,
Operator
from
..layer_helper
import
LayerHelper
,
unique_name
from
ops
import
logical_and
,
logical_not
,
logical_or
__all__
=
[
'split_lod_tensor'
,
...
...
@@ -27,6 +28,7 @@ __all__ = [
'StaticRNNMemoryLink'
,
'WhileGuard'
,
'While'
,
'Switch'
,
'lod_rank_table'
,
'max_sequence_len'
,
'topk'
,
...
...
@@ -1063,11 +1065,12 @@ class ConditionalBlockGuard(BlockGuard):
class
ConditionalBlock
(
object
):
def
__init__
(
self
,
inputs
,
name
=
None
):
def
__init__
(
self
,
inputs
,
is_scalar_condition
=
False
,
name
=
None
):
for
each_input
in
inputs
:
if
not
isinstance
(
each_input
,
Variable
):
raise
TypeError
(
"Each input should be variable"
)
self
.
inputs
=
inputs
self
.
is_scalar_condition
=
is_scalar_condition
self
.
helper
=
LayerHelper
(
'conditional_block'
,
name
=
name
)
def
block
(
self
):
...
...
@@ -1112,7 +1115,66 @@ class ConditionalBlock(object):
},
outputs
=
{
'Out'
:
out_list
,
'Scope'
:
[
step_scope
]},
attrs
=
{
'sub_block'
:
inside_block
})
attrs
=
{
'sub_block'
:
inside_block
,
'is_scalar_condition'
:
self
.
is_scalar_condition
})
class
Switch
(
object
):
def
__init__
(
self
,
name
=
None
):
self
.
helper
=
LayerHelper
(
'switch'
,
name
=
name
)
self
.
inside_scope
=
False
self
.
pre_not_conditions
=
[]
def
case
(
self
,
condition
):
"""create a new block for this condition
"""
if
not
self
.
inside_scope
:
raise
ValueError
(
"case should be called inside with"
)
if
len
(
self
.
pre_not_conditions
)
==
0
:
cond_block
=
ConditionalBlock
([
condition
],
is_scalar_condition
=
True
)
not_cond
=
logical_not
(
x
=
condition
)
self
.
pre_not_conditions
.
append
(
not_cond
)
else
:
pre_cond_num
=
len
(
self
.
pre_not_conditions
)
pre_not_cond
=
self
.
pre_not_conditions
[
pre_cond_num
-
1
]
new_not_cond
=
logical_and
(
x
=
pre_not_cond
,
y
=
logical_not
(
x
=
condition
))
self
.
pre_not_conditions
.
append
(
new_not_cond
)
cond_block
=
ConditionalBlock
(
[
logical_and
(
x
=
pre_not_cond
,
y
=
condition
)],
is_scalar_condition
=
True
)
return
ConditionalBlockGuard
(
cond_block
)
def
default
(
self
):
"""create a default case for this switch
"""
pre_cond_num
=
len
(
self
.
pre_not_conditions
)
if
pre_cond_num
==
0
:
raise
ValueError
(
"there should be at least one condition"
)
cond_block
=
ConditionalBlock
(
[
self
.
pre_not_conditions
[
pre_cond_num
-
1
]],
is_scalar_condition
=
True
)
return
ConditionalBlockGuard
(
cond_block
)
def
__enter__
(
self
):
"""
set flag that now is inside switch.block {}
:return:
"""
self
.
inside_scope
=
True
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
inside_scope
=
False
if
exc_type
is
not
None
:
return
False
# re-raise exception
return
True
class
IfElseBlockGuard
(
object
):
...
...
python/paddle/v2/fluid/layers/ops.py
浏览文件 @
20c4a4cb
...
...
@@ -61,6 +61,10 @@ __all__ = [
'clip_by_norm'
,
'softmax'
,
'sequence_softmax'
,
'logical_and'
,
'logical_or'
,
'logical_xor'
,
'logical_not'
,
]
+
__activations__
for
_OP
in
set
(
__all__
):
...
...
python/paddle/v2/fluid/tests/test_switch.py
0 → 100644
浏览文件 @
20c4a4cb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.v2.fluid.core
as
core
import
paddle.v2.fluid.layers
as
layers
import
paddle.v2.fluid.framework
as
framework
from
paddle.v2.fluid.executor
import
Executor
from
paddle.v2.fluid.framework
import
default_startup_program
class
TestSwitch
(
unittest
.
TestCase
):
def
check_switch
(
self
,
value
):
x
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
value
)
zero_var
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
0.0
)
one_var
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
1.0
)
two_var
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
2.0
)
three_var
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
3.0
)
result
=
layers
.
create_global_var
(
shape
=
[
1
],
value
=-
1.0
,
dtype
=
'float32'
,
persistable
=
True
)
with
layers
.
Switch
()
as
switch
:
with
switch
.
case
(
layers
.
less_than
(
x
,
zero_var
)):
layers
.
assign
(
zero_var
,
result
)
with
switch
.
case
(
layers
.
less_than
(
x
,
one_var
)):
layers
.
assign
(
one_var
,
result
)
with
switch
.
case
(
layers
.
less_than
(
x
,
two_var
)):
layers
.
assign
(
two_var
,
result
)
with
switch
.
default
():
layers
.
assign
(
three_var
,
result
)
cpu
=
core
.
CPUPlace
()
exe
=
Executor
(
cpu
)
exe
.
run
(
default_startup_program
())
out
=
exe
.
run
(
feed
=
{},
fetch_list
=
[
result
])[
0
][
0
]
return
out
def
test_switch
(
self
):
test_data
=
{(
-
0.1
,
0
),
(
0.1
,
1
),
(
1.1
,
2
),
(
2.1
,
3
)}
for
x
,
expected_result
in
test_data
:
main_program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
with
framework
.
program_guard
(
main_program
,
startup_program
):
result
=
self
.
check_switch
(
x
)
self
.
assertEqual
(
result
,
expected_result
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录