未验证 提交 ef4895df 编写于 作者: Q qingqing01 提交者: GitHub

Make IfElse operator works and fix unit testing. (#11972)

1. Fix bug when only true or false branch works.
2. Fix bug in unit testing.
上级 1bd9fe87
...@@ -205,9 +205,10 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase { ...@@ -205,9 +205,10 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
context->SetOutputsDim(framework::GradVarName("Params"), context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params")); context->GetInputsDim("Params"));
} }
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X"))); if (context->HasOutputs(framework::GradVarName("X"))) {
context->SetOutputsDim(framework::GradVarName("X"), context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X")); context->GetInputsDim("X"));
}
} }
}; };
......
...@@ -44,8 +44,10 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -44,8 +44,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>(); scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto level = static_cast<size_t>(Attr<int>("level")); auto level = static_cast<size_t>(Attr<int>("level"));
auto &mask_dim = mask.dims(); PADDLE_ENFORCE(in_true.numel() || in_false.numel(),
"Input(InTrue) or Input(InFalse) should be initialized.");
auto &mask_dim = mask.dims();
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()}; std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) { if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask); cpu_mask->ShareDataWith(mask);
...@@ -59,19 +61,27 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -59,19 +61,27 @@ class MergeLoDTensorOp : public framework::OperatorBase {
} }
auto *mask_data = cpu_mask->data<bool>(); auto *mask_data = cpu_mask->data<bool>();
int rank = in_true.dims().size(); platform::Place place = dev_place;
platform::Place place = in_true.place();
std::type_index data_type = in_true.type();
framework::DDim in_true_dims =
framework::slice_ddim(in_true.dims(), 1, rank);
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
auto in_true_dim_vec = framework::vectorize(in_true_dims); std::type_index data_type =
in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size); in_true.IsInitialized() ? in_true.type() : in_false.type();
int rank;
framework::DDim in_dims;
if (in_true.IsInitialized()) {
rank = in_true.dims().size();
in_dims = framework::slice_ddim(in_true.dims(), 1, rank);
} else {
rank = in_false.dims().size();
in_dims = framework::slice_ddim(in_false.dims(), 1, rank);
}
auto in_dim_vec = framework::vectorize(in_dims);
in_dim_vec.insert(in_dim_vec.begin(), batch_size);
framework::DDim out_dims = framework::make_ddim(in_true_dim_vec); framework::DDim out_dims = framework::make_ddim(in_dim_vec);
out->Resize(out_dims); out->Resize(out_dims);
out->mutable_data(place, data_type); out->mutable_data(place, data_type);
auto *out_lod = out->mutable_lod(); auto *out_lod = out->mutable_lod();
......
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
import paddle import paddle
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard, default_main_program, default_startup_program from paddle.fluid.framework import Program, program_guard
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import MomentumOptimizer from paddle.fluid.optimizer import MomentumOptimizer
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid
import unittest import unittest
import numpy as np import numpy as np
...@@ -31,14 +32,13 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -31,14 +32,13 @@ class TestMNISTIfElseOp(unittest.TestCase):
label = layers.data(name='y', shape=[1], dtype='int64') label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant_batch_size_like( limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
input=label, dtype='int64', shape=[1], value=5.0)
cond = layers.less_than(x=label, y=limit) cond = layers.less_than(x=label, y=limit)
true_image, false_image = layers.split_lod_tensor( true_image, false_image = layers.split_lod_tensor(
input=image, mask=cond) input=image, mask=cond)
true_out = layers.create_tensor(dtype='float32') true_out = layers.create_tensor(dtype='float32')
true_cond = layers.ConditionalBlock([true_image]) true_cond = layers.ConditionalBlock([cond])
with true_cond.block(): with true_cond.block():
hidden = layers.fc(input=true_image, size=100, act='tanh') hidden = layers.fc(input=true_image, size=100, act='tanh')
...@@ -46,7 +46,7 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -46,7 +46,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
layers.assign(input=prob, output=true_out) layers.assign(input=prob, output=true_out)
false_out = layers.create_tensor(dtype='float32') false_out = layers.create_tensor(dtype='float32')
false_cond = layers.ConditionalBlock([false_image]) false_cond = layers.ConditionalBlock([cond])
with false_cond.block(): with false_cond.block():
hidden = layers.fc(input=false_image, size=200, act='tanh') hidden = layers.fc(input=false_image, size=200, act='tanh')
...@@ -64,7 +64,7 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192), paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200) batch_size=10)
place = core.CPUPlace() place = core.CPUPlace()
exe = Executor(place) exe = Executor(place)
...@@ -94,8 +94,7 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -94,8 +94,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
label = layers.data(name='y', shape=[1], dtype='int64') label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant_batch_size_like( limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
input=label, dtype='int64', shape=[1], value=5.0)
cond = layers.less_than(x=label, y=limit) cond = layers.less_than(x=label, y=limit)
ie = layers.IfElse(cond) ie = layers.IfElse(cond)
...@@ -125,7 +124,7 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -125,7 +124,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
place = core.CPUPlace() place = core.CPUPlace()
exe = Executor(place) exe = Executor(place)
exe.run(kwargs['startup_program']) exe.run(startup_prog)
PASS_NUM = 100 PASS_NUM = 100
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
for data in train_reader(): for data in train_reader():
...@@ -133,7 +132,7 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -133,7 +132,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1)) y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(kwargs['main_program'], outs = exe.run(prog,
feed={'x': x_data, feed={'x': x_data,
'y': y_data}, 'y': y_data},
fetch_list=[avg_loss]) fetch_list=[avg_loss])
...@@ -143,6 +142,67 @@ class TestMNISTIfElseOp(unittest.TestCase): ...@@ -143,6 +142,67 @@ class TestMNISTIfElseOp(unittest.TestCase):
self.assertFalse(True) self.assertFalse(True)
class TestIfElse(unittest.TestCase):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
src = layers.data(name='data', shape=[1], dtype='float32')
cond = layers.fill_constant(
[1], dtype='float32', value=self.cond_value)
ifcond = layers.less_than(x=src, y=cond)
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
ie.output(false_target)
if_out = ie()
out = layers.reduce_sum(if_out)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fetch_list = [out]
o1, = exe.run(fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out])
o2 = np.sum(self.data)
self.assertTrue(
np.allclose(
o1, o2, atol=1e-8),
"IfElse result : " + str(o1) + "\n Numpy result :" + str(o2))
def test_cpu(self):
self.compare_ifelse_op_and_numpy(fluid.CPUPlace())
def test_cuda(self):
if not core.is_compiled_with_cuda():
return
self.compare_ifelse_op_and_numpy(fluid.CUDAPlace(0))
class TestIfElseTrueBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 10.
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseFalseBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = -10.
self.data = np.random.rand(25, 1).astype(np.float32)
if __name__ == '__main__': if __name__ == '__main__':
# temp disable if else unittest since it could be buggy. unittest.main()
exit(0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册