未验证 提交 20c4a4cb 编写于 作者: Q Qiao Longfei 提交者: GitHub

Impl scalar switch case op with condition op (#8184)

Impl scalar switch case op with condition op
上级 e5832019
......@@ -10,8 +10,7 @@ The following example shows the usage of `fluid.switch`.
a = fluid.Var(10)
b = fluid.Var(0)
switch = fluid.switch()
with switch.block():
with switch() as switch:
with switch.case(fluid.less_equal(a, 10)):
fluid.print("Case 1")
with switch.case(fluid.larger(a, 0)):
......
......@@ -41,6 +41,21 @@ class ConditionalOp : public framework::OperatorBase {
});
return retv;
}
bool ScalarCondition(
const std::vector<const framework::LoDTensor *> &ips) const {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() &&
ips[0]->numel() == 1)) {
PADDLE_THROW(
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
return ips[0]->data<bool>()[0];
}
};
class ConditionalBlockOp : public ConditionalOp {
......@@ -53,9 +68,15 @@ class ConditionalBlockOp : public ConditionalOp {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
need_run = ScalarCondition(xs);
} else {
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
}
if (need_run) {
auto *scope_var = scope.FindVar(Output("Scope"));
......@@ -88,6 +109,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"scope is std::vector<Scope*>");
AddAttr<framework::BlockDesc *>(
"sub_block", "The step block of conditional block operator");
AddAttr<bool>("is_scalar_condition",
"the input X is used as scalar "
"condition")
.SetDefault(false);
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
......@@ -106,9 +131,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
need_run = ScalarCondition(xs);
} else {
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
}
if (need_run) {
auto *scope_var = scope.FindVar(Input("Scope"));
......@@ -182,6 +213,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
......
......@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
from .. import core
from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
from ops import logical_and, logical_not, logical_or
__all__ = [
'split_lod_tensor',
......@@ -27,6 +28,7 @@ __all__ = [
'StaticRNNMemoryLink',
'WhileGuard',
'While',
'Switch',
'lod_rank_table',
'max_sequence_len',
'topk',
......@@ -1063,11 +1065,12 @@ class ConditionalBlockGuard(BlockGuard):
class ConditionalBlock(object):
def __init__(self, inputs, name=None):
def __init__(self, inputs, is_scalar_condition=False, name=None):
for each_input in inputs:
if not isinstance(each_input, Variable):
raise TypeError("Each input should be variable")
self.inputs = inputs
self.is_scalar_condition = is_scalar_condition
self.helper = LayerHelper('conditional_block', name=name)
def block(self):
......@@ -1112,7 +1115,66 @@ class ConditionalBlock(object):
},
outputs={'Out': out_list,
'Scope': [step_scope]},
attrs={'sub_block': inside_block})
attrs={
'sub_block': inside_block,
'is_scalar_condition': self.is_scalar_condition
})
class Switch(object):
def __init__(self, name=None):
self.helper = LayerHelper('switch', name=name)
self.inside_scope = False
self.pre_not_conditions = []
def case(self, condition):
"""create a new block for this condition
"""
if not self.inside_scope:
raise ValueError("case should be called inside with")
if len(self.pre_not_conditions) == 0:
cond_block = ConditionalBlock([condition], is_scalar_condition=True)
not_cond = logical_not(x=condition)
self.pre_not_conditions.append(not_cond)
else:
pre_cond_num = len(self.pre_not_conditions)
pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
new_not_cond = logical_and(
x=pre_not_cond, y=logical_not(x=condition))
self.pre_not_conditions.append(new_not_cond)
cond_block = ConditionalBlock(
[logical_and(
x=pre_not_cond, y=condition)],
is_scalar_condition=True)
return ConditionalBlockGuard(cond_block)
def default(self):
"""create a default case for this switch
"""
pre_cond_num = len(self.pre_not_conditions)
if pre_cond_num == 0:
raise ValueError("there should be at least one condition")
cond_block = ConditionalBlock(
[self.pre_not_conditions[pre_cond_num - 1]],
is_scalar_condition=True)
return ConditionalBlockGuard(cond_block)
def __enter__(self):
"""
set flag that now is inside switch.block {}
:return:
"""
self.inside_scope = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.inside_scope = False
if exc_type is not None:
return False # re-raise exception
return True
class IfElseBlockGuard(object):
......
......@@ -61,6 +61,10 @@ __all__ = [
'clip_by_norm',
'softmax',
'sequence_softmax',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
] + __activations__
for _OP in set(__all__):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import default_startup_program
class TestSwitch(unittest.TestCase):
def check_switch(self, value):
x = layers.fill_constant(shape=[1], dtype='float32', value=value)
zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
three_var = layers.fill_constant(shape=[1], dtype='float32', value=3.0)
result = layers.create_global_var(
shape=[1], value=-1.0, dtype='float32', persistable=True)
with layers.Switch() as switch:
with switch.case(layers.less_than(x, zero_var)):
layers.assign(zero_var, result)
with switch.case(layers.less_than(x, one_var)):
layers.assign(one_var, result)
with switch.case(layers.less_than(x, two_var)):
layers.assign(two_var, result)
with switch.default():
layers.assign(three_var, result)
cpu = core.CPUPlace()
exe = Executor(cpu)
exe.run(default_startup_program())
out = exe.run(feed={}, fetch_list=[result])[0][0]
return out
def test_switch(self):
test_data = {(-0.1, 0), (0.1, 1), (1.1, 2), (2.1, 3)}
for x, expected_result in test_data:
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
result = self.check_switch(x)
self.assertEqual(result, expected_result)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册