未验证 提交 6b0d9590 编写于 作者: G Ghost Screaming 提交者: GitHub

Clean and migrate fluid APIs of paddle.fluid.layers.control_flow (#48233)

* Merge branch 'reduce_sum' of https://github.com/GhostScreaming/Paddle into mine_fluid_clean_common.

* Fix some bugs.

* Clean APIs in python/paddle/fluid/layers/control_flow.py

* Polish code style.

* Change API.

* Fix some bugs.

* Fix some bugs.
上级 3a387df6
......@@ -1594,7 +1594,7 @@ def _dynamic_decode_declarative(
max_step_num = tensor.fill_constant(
shape=[1], dtype="int64", value=max_step_num
)
while_op = control_flow.While(cond, is_test=is_test)
while_op = paddle.static.nn.control_flow.While(cond, is_test=is_test)
sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
sequence_lengths.stop_gradient = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.layers.control_flow import (
ConditionalBlock,
merge_lod_tensor,
split_lod_tensor,
)
from paddle.fluid.optimizer import MomentumOptimizer
paddle.enable_static()
class TestMNISTIfElseOp(unittest.TestCase):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_raw_api(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
true_image, false_image = split_lod_tensor(input=image, mask=cond)
true_out = paddle.tensor.create_tensor(dtype='float32')
true_cond = ConditionalBlock([cond])
with true_cond.block():
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
layers.assign(input=prob, output=true_out)
false_out = paddle.tensor.create_tensor(dtype='float32')
false_cond = ConditionalBlock([cond])
with false_cond.block():
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
layers.assign(input=prob, output=false_out)
prob = merge_lod_tensor(
in_true=true_out, in_false=false_out, mask=cond, x=image
)
loss = layers.cross_entropy(input=prob, label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
batch_size=10,
)
place = core.CPUPlace()
exe = Executor(place)
exe.run(startup_prog)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = np.expand_dims(y_data, axis=1)
outs = exe.run(
prog, feed={'x': x_data, 'y': y_data}, fetch_list=[avg_loss]
)
print(outs[0])
if outs[0] < 1.0:
return
self.assertFalse(True)
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_ifelse(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
ie = layers.IfElse(cond)
with ie.true_block():
true_image = ie.input(image)
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
with ie.false_block():
false_image = ie.input(image)
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
prob = ie()
loss = layers.cross_entropy(input=prob[0], label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200,
)
place = core.CPUPlace()
exe = Executor(place)
exe.run(startup_prog)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(
prog, feed={'x': x_data, 'y': y_data}, fetch_list=[avg_loss]
)
print(outs[0])
if outs[0] < 1.0:
return
self.assertFalse(True)
class TestIfElse(unittest.TestCase):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def numpy_cal(self):
s1 = self.data[np.where(self.data < self.cond_value)]
res = np.sum(np.exp(s1))
s2 = self.data[np.where(self.data >= self.cond_value)]
res += np.sum(np.tanh(s2))
return res
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
src = layers.data(name='data', shape=[1], dtype='float32')
cond = layers.fill_constant(
[1], dtype='float32', value=self.cond_value
)
ifcond = paddle.less_than(x=src, y=cond)
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = paddle.exp(true_target)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
false_target = paddle.tanh(false_target)
ie.output(false_target)
if_out = ie()
out = paddle.sum(if_out[0])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fetch_list = [out]
(o1,) = exe.run(
fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out],
)
o2 = self.numpy_cal()
np.testing.assert_allclose(
o1,
o2,
rtol=1e-05,
atol=1e-08,
)
def test_cpu(self):
self.compare_ifelse_op_and_numpy(fluid.CPUPlace())
def test_cuda(self):
if not core.is_compiled_with_cuda():
return
self.compare_ifelse_op_and_numpy(fluid.CUDAPlace(0))
class TestIfElseTrueBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 10.0
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseFalseBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = -10.0
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseError(unittest.TestCase):
def test_input_type_error(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
src = layers.data(name='data', shape=[1], dtype='float32')
const_value = layers.fill_constant(
[1], dtype='float32', value=123.0
)
ifcond = paddle.less_than(x=src, y=const_value)
with self.assertRaises(TypeError):
ie = layers.IfElse(set())
with self.assertRaises(TypeError):
ie = layers.IfElse(ifcond, set())
with self.assertRaises(TypeError):
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = paddle.exp(true_target)
ie.output([])
if __name__ == '__main__':
unittest.main()
......@@ -174,7 +174,7 @@ def get_program():
cond = paddle.less_than(x=i, y=loop_len)
auto.shard_tensor(cond, _g_process_mesh, [None])
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
pre_input = fluid.layers.array_read(array=input_array, i=i)
......
......@@ -84,7 +84,9 @@ class TestHybridParallelInferenceHelperClass(unittest.TestCase):
)
print(cond_int.shape)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond, is_test=True)
while_op = paddle.static.nn.control_flow.While(
cond, is_test=True
)
with while_op.block():
with paddle.fluid.device_guard(f'{device}:all'):
......
......@@ -1763,7 +1763,7 @@ def fast_decode(
shape=[1], dtype=start_tokens.dtype, value=0
)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
while_op = paddle.static.nn.control_flow.While(cond)
# array states will be stored for each step.
ids = layers.array_write(
paddle.reshape(start_tokens, (-1, 1)), step_idx
......
......@@ -161,7 +161,7 @@ def dyfunc_ifExp_with_while(x):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
i, ten, y = fluid.layers.while_loop(cond, body, [i, ten, y])
i, ten, y = paddle.static.nn.while_loop(cond, body, [i, ten, y])
return y[0]
......
......@@ -145,7 +145,7 @@ class TestSetValueItemSlice5(TestSetValueApi):
# return i, x
#
# i = paddle.zeros(shape=(1, ), dtype='int32')
# i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
# i, x = paddle.static.nn.while_loop(cond, body, [i, x])
#
# def _get_answer(self):
# self.data[0] = self.value
......
......@@ -147,7 +147,7 @@ class TestSetValueItemSlice4(TestSetValueApi):
# return i, x
# i = paddle.zeros(shape=(1, ), dtype='int32')
# i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
# i, x = paddle.static.nn.while_loop(cond, body, [i, x])
# def _get_answer(self):
# self.data[0] = self.value
......
......@@ -64,8 +64,8 @@ class TestWhileOp(unittest.TestCase):
cond2 = paddle.logical_or(x=j, y=array_len2)
cond2 = paddle.ones(shape=[1], dtype='int32')
cond2 = layers.cast(cond2, 'bool')
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......
......@@ -17,9 +17,20 @@ import unittest
import numpy as np
import paddle
sys.path.append("../")
from op_test import OpTest, skip_check_grad_ci
from test_reorder_lod_tensor import convert_to_offset
paddle.enable_static()
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
def compute_seqpool_sum(x, offset, out, pad_value=0.0):
......
......@@ -24,6 +24,8 @@ import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
class TestAPICase(unittest.TestCase):
def test_return_single_var(self):
......@@ -46,25 +48,29 @@ class TestAPICase(unittest.TestCase):
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
# call fn_1
out_0 = layers.case(
out_0 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
)
# call fn_2
out_1 = layers.case(
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
# call default fn_3
out_2 = layers.case(
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
)
# no default, call fn_2
out_3 = layers.case(pred_fn_pairs=[(pred_1, fn_2)])
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_2)]
)
# no default, call fn_2. but pred_2 is false
out_4 = layers.case(pred_fn_pairs=[(pred_2, fn_2)])
out_4 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_2)]
)
place = (
fluid.CUDAPlace(0)
......@@ -109,7 +115,9 @@ class TestAPICase(unittest.TestCase):
pred_1 = paddle.equal(x, y) # true
pred_2 = paddle.equal(x, z) # false
out = layers.case(((pred_1, fn_1), (pred_2, fn_2)), fn_3)
out = paddle.static.nn.control_flow.case(
((pred_1, fn_1), (pred_2, fn_2)), fn_3
)
place = (
fluid.CUDAPlace(0)
......@@ -132,7 +140,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_1(x=1):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(
var_5 < var_6,
......@@ -159,7 +167,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_2(x=2):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(var_5 < var_6, partial(fn_1, x=x)),
(
......@@ -178,7 +186,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_3():
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(var_5 < var_6, partial(fn_2, x=3)),
(
......@@ -203,15 +211,15 @@ class TestAPICase_Nested(unittest.TestCase):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
out_1 = layers.case(
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
out_3 = layers.case(
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3
)
......@@ -243,37 +251,49 @@ class TestAPICase_Error(unittest.TestCase):
# The type of 'pred_fn_pairs' in case must be list or tuple
def type_error_pred_fn_pairs():
layers.case(pred_fn_pairs=1, default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=1, default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_pairs)
# The elements' type of 'pred_fn_pairs' in Op(case) must be tuple
def type_error_pred_fn_1():
layers.case(pred_fn_pairs=[1], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[1], default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_1)
# The tuple's size of 'pred_fn_pairs' in Op(case) must be 2
def type_error_pred_fn_2():
layers.case(pred_fn_pairs=[(1, 2, 3)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(1, 2, 3)], default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_2)
# The pred's type of 'pred_fn_pairs' in Op(case) must be bool Variable
def type_error_pred():
layers.case(pred_fn_pairs=[(1, fn_1)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(1, fn_1)], default=fn_1
)
self.assertRaises(TypeError, type_error_pred)
# The function of pred_fn_pairs in case must be callable
def type_error_fn():
layers.case(pred_fn_pairs=[(pred_1, 2)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, 2)], default=fn_1
)
self.assertRaises(TypeError, type_error_fn)
# The default in Op(case) must be callable
def type_error_default():
layers.case(pred_fn_pairs=[(pred_1, fn_1)], default=fn_1())
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1)], default=fn_1()
)
self.assertRaises(TypeError, type_error_default)
......@@ -308,7 +328,9 @@ class TestMutiTask(unittest.TestCase):
loss = paddle.mean(sum, name="f_2_loss")
adagrad.minimize(loss)
layers.case(pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......
......@@ -19,6 +19,8 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
paddle.enable_static()
def execute(main_program, startup_program):
if paddle.is_compiled_with_cuda():
......@@ -153,7 +155,7 @@ class TestDeviceGuard(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.static.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
......
......@@ -20,6 +20,8 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
paddle.enable_static()
def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
fluid.default_startup_program().random_seed = 1
......@@ -37,7 +39,7 @@ def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
shape=[1], dtype="int64", value=10, force_cpu=True
)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
while_op = paddle.static.nn.control_flow.While(cond)
scores = layers.array_write(x, step_idx)
with while_op.block():
bs = layers.cast(paddle.shape(x)[0], "int64")
......
......@@ -103,8 +103,8 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......
......@@ -21,7 +21,14 @@ from sequence.test_sequence_pool import (
compute_seqpool_sqrt,
compute_seqpool_sum,
)
from test_reorder_lod_tensor import convert_to_offset
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestFusionSeqPoolConcatOp(OpTest):
......
......@@ -22,7 +22,14 @@ from sequence.test_sequence_pool import (
compute_seqpool_sum,
)
from test_cvm_op import cvm_compute
from test_reorder_lod_tensor import convert_to_offset
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestFusionSeqPoolCVMConcatOp(OpTest):
......
......@@ -24,6 +24,8 @@ from paddle.fluid import core, unique_name
LOADED_VAR_SUFFIX = ".load_0"
paddle.enable_static()
def while_softmax_regression(img):
def cond(i, times, pred):
......@@ -37,7 +39,7 @@ def while_softmax_regression(img):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
times = fluid.layers.fill_constant(shape=[1], dtype='int64', value=5)
pred = fluid.layers.fc(input=img, size=10, act='softmax')
i, times, pred = fluid.layers.while_loop(
i, times, pred = paddle.static.nn.while_loop(
cond=cond, body=body, loop_vars=[i, times, pred]
)
return pred
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# nlp model stack of op operate on lod. It's a classical test case in optimize pass.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import MomentumOptimizer
class TestIrMemoryOptimizeIfElseOp(unittest.TestCase):
def check_network_convergence(
self, use_cuda=True, use_mem_opt=False, iter_num=5
):
paddle.seed(100)
paddle.framework.random._manual_program_seed(100)
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
ie = layers.IfElse(cond)
with ie.true_block():
true_image = ie.input(image)
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
with ie.false_block():
false_image = ie.input(image)
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
prob = ie()
loss = layers.cross_entropy(input=prob[0], label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=200
)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = Executor(place)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy._use_device = (
core.DeviceType.CUDA if use_cuda else core.DeviceType.CPU
)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = use_mem_opt
train_cp = compiler.CompiledProgram(fluid.default_main_program())
train_cp = train_cp.with_data_parallel(
loss_name=avg_loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy,
)
fetch_list = [avg_loss.name]
exe.run(startup_prog)
PASS_NUM = 100
loop = 0
ret = []
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(
train_cp,
feed={'x': x_data, 'y': y_data},
fetch_list=[avg_loss],
)
loop += 1
ret.append(outs[0])
if iter_num == loop:
return ret
return ret
def test_ifelse(self):
ret1 = self.check_network_convergence(False, True)
print(ret1)
ret2 = self.check_network_convergence(False, False)
print(ret2)
np.testing.assert_allclose(ret1, ret2, rtol=1e-05)
if fluid.core.is_compiled_with_cuda():
ret1 = self.check_network_convergence(True, True)
print(ret1)
ret2 = self.check_network_convergence(True, False)
print(ret2)
np.testing.assert_allclose(ret1, ret2, rtol=1e-05)
if __name__ == "__main__":
unittest.main()
......@@ -1387,7 +1387,7 @@ class TestLayer(LayerTest):
def body(i):
return i + 1
out = layers.while_loop(cond, body, [i])
out = paddle.static.nn.while_loop(cond, body, [i])
static_ret = self.get_static_graph_result(feed={}, fetch_list=out)
with self.dynamic_graph():
......@@ -1400,14 +1400,14 @@ class TestLayer(LayerTest):
def body1(i):
return i + 1
dy_ret = layers.while_loop(cond1, body1, [i])
dy_ret = paddle.static.nn.while_loop(cond1, body1, [i])
with self.assertRaises(ValueError):
j = layers.fill_constant(shape=[1], dtype='int64', value=0)
def body2(i):
return i + 1, i + 2
layers.while_loop(cond1, body2, [j])
paddle.static.nn.while_loop(cond1, body2, [j])
np.testing.assert_array_equal(static_ret[0], dy_ret[0].numpy())
......@@ -1659,10 +1659,12 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
place = (
fluid.CUDAPlace(0)
......@@ -1682,10 +1684,10 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
eager_dynamic_res1 = out_1.numpy()
......@@ -1699,10 +1701,12 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
dynamic_res1 = out_1.numpy()
dynamic_res2 = out_2.numpy()
......@@ -1725,17 +1729,17 @@ class TestLayer(LayerTest):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......@@ -1759,17 +1763,17 @@ class TestLayer(LayerTest):
shape=[1], dtype='int32', value=2
)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......@@ -1781,17 +1785,17 @@ class TestLayer(LayerTest):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy
from paddle.fluid import Program, core, program_guard
from paddle.fluid.executor import Executor
from paddle.fluid.layers import data
from paddle.fluid.layers.control_flow import lod_rank_table
class TestLoDRankTable(unittest.TestCase):
def test_lod_rank_table(self):
x = data(name='x', shape=[100])
cpu = core.CPUPlace()
rank_table = lod_rank_table(x=x, level=1)
rank_table.persistable = True
exe = Executor(cpu)
scope = core.Scope()
tensor = core.LoDTensor()
tensor.set(numpy.random.random(size=(17, 100)), cpu)
tensor.set_recursive_sequence_lengths(
[[1, 2], [5, 1, 1], [3, 1, 5, 1, 3, 3, 1]]
)
exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items()))
class TestLoDRankTableError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x = numpy.random.random((2, 4)).astype("float32")
def test_Variable():
rank_table = lod_rank_table(x=x, level=1)
self.assertRaises(TypeError, test_Variable)
def test_list_Variable():
rank_table = lod_rank_table(x=[x], level=1)
self.assertRaises(TypeError, test_list_Variable)
x = data(name='x', shape=[10], dtype='float32', lod_level=1)
out = lod_rank_table(x=x, level=0)
out = lod_rank_table(x=[x], level=0)
if __name__ == '__main__':
unittest.main()
......@@ -101,7 +101,7 @@ def static(
mod_two = paddle.remainder(id, two) == 0
if loss_in_switch:
avg_loss = layers.case(
avg_loss = paddle.static.nn.case(
[(mod_two, lambda: fn_1(adam, None, prediction, label))],
lambda: fn_2(sgd, None, prediction, label),
)
......@@ -112,7 +112,7 @@ def static(
logits=prediction, label=label
)
avg_loss_2 = paddle.mean(loss_2)
avg_loss = layers.case(
avg_loss = paddle.static.nn.case(
[(mod_two, lambda: fn_1(adam, avg_loss_1))],
lambda: fn_2(sgd, avg_loss_2),
)
......@@ -264,7 +264,7 @@ class TestMultiOptimizersMultiCardsError(unittest.TestCase):
cond = layers.fill_constant([1], 'bool', True)
layers.case(
paddle.static.nn.case(
[(cond, lambda: fn_1(adam, avg_loss))],
lambda: fn_2(sgd, avg_loss),
)
......
......@@ -46,7 +46,7 @@ class TestProfiler(unittest.TestCase):
until = layers.fill_constant([1], dtype='int64', value=10)
data_arr = layers.array_write(hidden1, i)
cond = paddle.less_than(x=counter, y=until)
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
hidden_n = fluid.layers.fc(input=hidden1, size=64, act='relu')
layers.array_write(hidden_n, i, data_arr)
......
......@@ -100,7 +100,7 @@ def cond_net(use_feed=None):
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = two == 0
avg_loss = fluid.layers.case(
avg_loss = paddle.static.nn.case(
[(pred, lambda: loss1(prediction, label))],
lambda: loss2(prediction, label),
)
......@@ -132,7 +132,7 @@ def optimization_in_cond_net(with_optimize=False):
sgd = fluid.optimizer.SGD(learning_rate=0.1)
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = two == 0
avg_loss = fluid.layers.case(
avg_loss = paddle.static.nn.case(
[(pred, lambda: loss1(sgd, prediction, label, with_optimize))],
lambda: loss2(sgd, prediction, label, with_optimize),
)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.layers.control_flow import lod_rank_table
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestReorderLoDTensor(unittest.TestCase):
num_seq = 5
# [name, shape, lod_level] pair indicating data info of source and target
data_desc = (['input', [9], 0], ['ref', [5], 1])
@classmethod
def setUpClass(cls):
cls.set_program()
@classmethod
def set_program(cls):
dat = fluid.layers.data(
name=cls.data_desc[0][0], shape=cls.data_desc[0][1]
)
dat.stop_gradient = False
rank_dat = fluid.layers.data(
name=cls.data_desc[1][0], shape=cls.data_desc[1][1]
)
table = lod_rank_table(rank_dat)
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=dat, rank_table=table
)
loss = paddle.sum(new_dat)
fluid.backward.append_backward(loss=loss)
cls.fetch_list = [new_dat, cls.data_desc[0][0] + '@GRAD']
def run_program(self):
outputs = []
input_grads = []
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.set_inputs(place)
exe = fluid.Executor(place)
output, input_grad = exe.run(
fluid.default_main_program(),
feed=self.inputs,
fetch_list=self.fetch_list,
return_numpy=False,
)
outputs.append(output)
input_grads.append(input_grad)
self.actual_outputs = outputs
self.actual_grads = input_grads
def set_data(self):
self.data = {}
for desc in self.data_desc:
data_name = desc[0]
data_shape = desc[1]
data_lod_level = desc[2]
data_lod = []
for i in range(data_lod_level):
lod_level_i = np.random.randint(
low=1,
high=5,
size=self.num_seq
if i == 0
else sum(lod_level_i), # noqa: F821
).tolist()
data_lod.append(lod_level_i)
data_value = np.random.random(
size=[sum(data_lod[-1]) if data_lod else self.num_seq]
+ data_shape
).astype('float32')
self.data[data_name] = (data_value, data_lod)
def set_inputs(self, place):
self.inputs = {}
for desc in self.data_desc:
tensor = fluid.Tensor()
tensor.set(self.data[desc[0]][0], place)
if self.data[desc[0]][1]:
tensor.set_recursive_sequence_lengths(self.data[desc[0]][1])
self.inputs[desc[0]] = tensor
def reorder(self):
level = 0
# compute the rank_table according to ref_lod
ref_lod = self.data[self.data_desc[1][0]][1][level]
rank_table = [] # list of (index, length)
for i in range(len(ref_lod)):
rank_table.append((i, ref_lod[i]))
rank_table = sorted(
rank_table, key=functools.cmp_to_key(lambda x, y: y[1] - x[1])
)
# compute the input sequence info according to input_lod
input_value, input_lod = self.data[self.data_desc[0][0]]
offset_lod = convert_to_offset(input_lod)
input_table = [] # list of (offset, length, sub_lod)
if offset_lod:
for i in range(len(offset_lod[level]) - 1):
start_idx = i
end_idx = i + 1
sub_lod = []
for lod_level_i in offset_lod[level:]:
sub_lod_i = []
for idx in range(start_idx, end_idx):
sub_lod_i.append(
lod_level_i[idx + 1] - lod_level_i[idx]
)
sub_lod.append(sub_lod_i)
start_idx = lod_level_i[start_idx]
end_idx = lod_level_i[end_idx]
input_table.append((start_idx, end_idx - start_idx, sub_lod))
else:
input_table = [(i, 1, []) for i in range(len(rank_table))]
# reorder by rank_table
output_value = np.zeros_like(input_value)
output_lod = []
offset = 0
for index, length in rank_table:
input_seq_start = input_table[index][0]
input_seq_len = input_table[index][1]
input_seq_end = input_seq_start + input_seq_len
output_value[offset : offset + input_seq_len] = input_value[
input_seq_start:input_seq_end
]
offset += input_seq_len
input_seq_sub_lod = input_table[index][2]
if len(output_lod) == 0:
output_lod = [[] for i in input_seq_sub_lod]
for i, level in enumerate(input_seq_sub_lod):
output_lod[i].extend(level)
return output_value, output_lod
def test_reorder_lod_tensor(self):
self.data_desc[0][-1] = 2 # input is lod_tensor
self.set_data()
self.run_program()
# check output
expect_output, expect_output_lod = self.reorder()
for actual_output in self.actual_outputs:
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_output_lod, actual_output.recursive_sequence_lengths()
)
# check gradient
expect_grad = np.ones_like(self.data[self.data_desc[0][0]][0])
expect_grad_lod = self.data[self.data_desc[0][0]][1]
for actual_grad in self.actual_grads:
np.testing.assert_allclose(
np.array(actual_grad), expect_grad, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_grad_lod, actual_grad.recursive_sequence_lengths()
)
def test_reorder_tensor(self):
self.data_desc[0][-1] = 0 # input is tensor
self.set_data()
self.run_program()
# check output
expect_output, expect_output_lod = self.reorder()
for actual_output in self.actual_outputs:
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_output_lod, actual_output.recursive_sequence_lengths()
)
# check gradient
expect_grad = np.ones_like(self.data[self.data_desc[0][0]][0])
expect_grad_lod = self.data[self.data_desc[0][0]][1]
for actual_grad in self.actual_grads:
np.testing.assert_allclose(
np.array(actual_grad), expect_grad, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_grad_lod, actual_grad.recursive_sequence_lengths()
)
# compare outputs between LodTensors with explicit and implicit lod
# use the same data but set the input lod explicitly
input_lod = [[1] * len(self.data[self.data_desc[0][0]][0])]
self.inputs[self.data_desc[0][0]].set_recursive_sequence_lengths(
input_lod
)
# preserve the output of LodTensor with implicit lod to compare
expect_outputs = [
np.array(actual_output) for actual_output in self.actual_outputs
]
self.run_program()
for actual_output, expect_output in zip(
self.actual_outputs, expect_outputs
):
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
class TestReorderLoDTensorError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
def test_Variable():
# The input must be Variable.
x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
table1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=x1, rank_table=table1
)
self.assertRaises(TypeError, test_Variable)
def test_type():
x2 = fluid.layers.data(name='x1', shape=[4], dtype='float32')
table2 = fluid.layers.data(
name='table2', shape=[4], dtype='int32'
)
new_dat2 = fluid.layers.reorder_lod_tensor_by_rank(
x=x2, rank_table=table2
)
self.assertRaises(TypeError, test_type)
if __name__ == '__main__':
unittest.main()
......@@ -156,7 +156,7 @@ class TestSetValueItemSliceInWhile(TestSetValueApi):
return i, x
i = paddle.zeros(shape=(1,), dtype='int32')
i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
i, x = paddle.static.nn.while_loop(cond, body, [i, x])
def _get_answer(self):
self.data[0] = self.value
......
......@@ -17,11 +17,14 @@ from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
class TestAPISwitchCase(unittest.TestCase):
def test_return_single_var(self):
......@@ -42,29 +45,29 @@ class TestAPISwitchCase(unittest.TestCase):
index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
# call fn_1
out_0 = layers.switch_case(
out_0 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
)
# call default fn_3
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3,
)
# no default, call fn_2
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
)
# no default, call fn_2 but branch_index is 5
out_4 = layers.switch_case(
out_4 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
)
......@@ -132,7 +135,9 @@ class TestAPISwitchCase(unittest.TestCase):
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)
out = paddle.static.nn.switch_case(
index_1, ((1, fn_1), (2, fn_2)), fn_3
)
place = (
fluid.CUDAPlace(0)
......@@ -153,7 +158,7 @@ class TestAPISwitchCase(unittest.TestCase):
class TestAPISwitchCase_Nested(unittest.TestCase):
def test_nested_switch_case(self):
def fn_1(x=1):
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=x
),
......@@ -169,7 +174,7 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
return out
def fn_2(x=2):
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=2
),
......@@ -186,7 +191,7 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
return out
def fn_3():
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=3
),
......@@ -209,14 +214,14 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
......@@ -277,7 +282,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of 'branch_index' in Op(switch_case) must be Variable
def type_error_branch_index():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=1, branch_fns=[(1, fn_1)], default=fn_3
)
......@@ -285,7 +290,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8
def dtype_error_branch_index():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_float32,
branch_fns=[(1, fn_1)],
default=fn_3,
......@@ -295,7 +300,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict
def type_error_branch_fns():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=1, default=fn_3
)
......@@ -303,7 +308,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The elements' type of 'branch_fns' in Op(switch_case) must be tuple
def type_error_index_fn_pair_1():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[1], default=fn_3
)
......@@ -311,7 +316,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The tuple's size of 'branch_fns' in Op(switch_case) must be 2
def type_error_index_fn_pair_2():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[(1, 2, 3)], default=fn_3
)
......@@ -319,7 +324,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The key's type of 'branch_fns' in Op(switch_case) must be int
def type_error_key():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3
)
......@@ -327,7 +332,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The key in 'branch_fns' must be unique
def value_error_key():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(2, fn_1), (2, fn_2)],
default=fn_3,
......@@ -337,7 +342,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of function in 'branch_fns' must be callable
def type_error_fn():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(1, 1), (2, fn_2)],
default=fn_3,
......@@ -347,7 +352,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The default in Op(case) must be callable
def type_error_default():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(1, fn_1), (2, fn_2)],
default=1,
......
......@@ -21,6 +21,8 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestTensorArrayToTensorError(unittest.TestCase):
"""Tensor_array_to_tensor error message enhance"""
......@@ -288,7 +290,9 @@ class TestTensorArrayToTensorAPI(unittest.TestCase):
fluid.layers.array_write(prev, i, array)
return i + 1, end, array
_, _, array = fluid.layers.while_loop(cond, body, [i, ten, array])
_, _, array = paddle.static.nn.while_loop(
cond, body, [i, ten, array]
)
self.assertTrue(paddle.tensor.array_length(array), 10)
last = fluid.layers.fill_constant(shape=[1], dtype='int64', value=9)
......
......@@ -40,7 +40,7 @@ class TestApiWhileLoop(unittest.TestCase):
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
one = layers.fill_constant(shape=[1], dtype='int64', value=1)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
out = layers.while_loop(cond, body, (i,))
out = paddle.static.nn.while_loop(cond, body, (i,))
place = (
fluid.CUDAPlace(0)
......@@ -69,7 +69,7 @@ class TestApiWhileLoop(unittest.TestCase):
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
mem = fluid.data(name='mem', shape=[10], dtype='float32')
one = layers.fill_constant(shape=[10], dtype='float32', value=1)
out = layers.while_loop(cond, body, [i, mem])
out = paddle.static.nn.while_loop(cond, body, [i, mem])
data = np.random.rand(10).astype('float32')
data_one = np.ones(10).astype('float32')
......@@ -122,7 +122,13 @@ class TestApiWhileLoop(unittest.TestCase):
}
]
i, ten, test_dict, test_list, test_list_dict = layers.while_loop(
(
i,
ten,
test_dict,
test_list,
test_list_dict,
) = paddle.static.nn.while_loop(
cond, body, [i, ten, test_dict, test_list, test_list_dict]
)
place = (
......@@ -171,7 +177,7 @@ class TestApiWhileLoop_Nested(unittest.TestCase):
j = layers.increment(j)
return [j, init, sums]
result = layers.while_loop(
result = paddle.static.nn.while_loop(
internal_cond, internal_body, [j, init, sums]
)
j = result[0]
......@@ -192,7 +198,7 @@ class TestApiWhileLoop_Nested(unittest.TestCase):
loop_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
ones = layers.fill_constant(shape=[3, 3], dtype='float32', value=1)
out = layers.while_loop(
out = paddle.static.nn.while_loop(
external_cond, external_body, [i, j, init, sums]
)
......@@ -236,7 +242,7 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
x = fluid.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
out = layers.while_loop(cond, body, [i, x])
out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
append_backward(mean)
......@@ -277,7 +283,7 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
x = fluid.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
out = layers.while_loop(cond, body, [i, x])
out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
append_backward(mean)
......@@ -328,7 +334,7 @@ class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
outer_sum_1 = paddle.add(x=x, y=outer_sum_0)
i = layers.increment(x=i, in_place=True)
layers.array_write(outer_sum_1, i=i, array=mem_array)
j, x, mem_array = layers.while_loop(
j, x, mem_array = paddle.static.nn.while_loop(
internal_cond, internal_body, [j, x, mem_array]
)
return [i, j, x, mem_array]
......@@ -357,7 +363,7 @@ class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out = layers.while_loop(
out = paddle.static.nn.while_loop(
external_cond, external_body, [i, j, x, mem_array]
)
......@@ -405,7 +411,7 @@ class TestApiWhileLoopWithSwitchCase(unittest.TestCase):
data_add_one = paddle.add(x=i, y=one)
return data_add_one
return layers.switch_case(
return paddle.static.nn.switch_case(
branch_index=i,
branch_fns={2: fn_add_three, 5: fn_square},
default=fn_add_one,
......@@ -418,7 +424,7 @@ class TestApiWhileLoopWithSwitchCase(unittest.TestCase):
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
three = layers.fill_constant(shape=[1], dtype='int64', value=3)
one = layers.fill_constant(shape=[1], dtype='int64', value=1)
out = layers.while_loop(cond, body, [i])
out = paddle.static.nn.while_loop(cond, body, [i])
place = (
fluid.CUDAPlace(0)
......@@ -488,13 +494,13 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `cond` in Op(while_loop) must be callable
def type_error_cond():
out = layers.while_loop(data, body, [data_1d])
out = paddle.static.nn.while_loop(data, body, [data_1d])
self.assertRaises(TypeError, type_error_cond)
# The type of `body` in Op(while_loop) must be callable
def type_error_body():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, data, [data_1d]
)
......@@ -502,25 +508,31 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `loop_vars` in Op(while_loop) must be list or tuple
def type_error_loop_vars():
out = layers.while_loop(cond_returns_bool_tensor, body, data_1d)
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body, data_1d
)
self.assertRaises(TypeError, type_error_loop_vars)
# The value of `loop_vars` is empty
def value_error_loop_vars():
out = layers.while_loop(cond_returns_bool_tensor, body, [])
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body, []
)
self.assertRaises(ValueError, value_error_loop_vars)
# The type of `cond` returns in Op(while_loop) must be Variable
def type_error_cond_returns_not_variable():
out = layers.while_loop(cond_returns_constant, body, [data_1d])
out = paddle.static.nn.while_loop(
cond_returns_constant, body, [data_1d]
)
self.assertRaises(TypeError, type_error_cond_returns_not_variable)
# The type of `cond` returns in Op(while_loop) must be a bollean variable
def type_error_cond_returns_not_boolean():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_not_bool_tensor, body, [data_1d]
)
......@@ -528,13 +540,15 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The shape of `cond` returns in Op(while_loop) must be 1
def type_error_shape_cond_returns_2d():
out = layers.while_loop(cond_returns_2d_tensor, body, [data_2d])
out = paddle.static.nn.while_loop(
cond_returns_2d_tensor, body, [data_2d]
)
self.assertRaises(TypeError, type_error_shape_cond_returns_2d)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_length():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body_returns_error_length, [data]
)
......@@ -542,7 +556,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_type():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_receives_two_args, body_returns_error_type, [data, ten]
)
......@@ -555,7 +569,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
shape=[2, 2], dtype='int64', value=1
)
}
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_with_mutable_dict,
body_returns_with_mutable_dict,
[data, test_dict],
......@@ -569,7 +583,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
test_list = [
layers.fill_constant(shape=[2, 2], dtype='int64', value=1)
]
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_with_mutable_list,
body_returns_with_mutable_list,
[data, test_list],
......@@ -597,7 +611,7 @@ class TestApiWhileLoopSliceInBody(unittest.TestCase):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_shape = paddle.shape(x)
i = fluid.layers.fill_constant([1], 'int32', 0)
z, _ = fluid.layers.while_loop(cond, body, [z, i])
z, _ = paddle.static.nn.while_loop(cond, body, [z, i])
place = (
fluid.CUDAPlace(0)
......
......@@ -56,8 +56,8 @@ class TestWhileOp(unittest.TestCase):
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......@@ -122,10 +122,10 @@ class TestWhileOp(unittest.TestCase):
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = paddle.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
class BadInputTest(unittest.TestCase):
......@@ -157,7 +157,7 @@ class TestIgnoreVarNameInWhile(unittest.TestCase):
i = layers.fill_constant(shape=[1], value=0, dtype='int32')
num = layers.fill_constant(shape=[1], value=5, dtype='int32')
i, ten, shuffle_temp, y = layers.while_loop(
i, ten, shuffle_temp, y = paddle.static.nn.while_loop(
cond, body_func, [i, num, temp, y]
)
......
......@@ -159,7 +159,7 @@ class TestDeviceGuard(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.static.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
......
......@@ -55,8 +55,8 @@ class TestWhileOp(unittest.TestCase):
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......@@ -121,10 +121,10 @@ class TestWhileOp(unittest.TestCase):
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = paddle.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
if __name__ == '__main__':
......
......@@ -21,10 +21,14 @@ from .common import deform_conv2d # noqa: F401
from .common import conv3d # noqa: F401
from .common import conv2d_transpose # noqa: F401
from .common import conv3d_transpose # noqa: F401
from .control_flow import (
case,
while_loop,
switch_case,
)
from .common import bilinear_tensor_product # noqa: F401
from .common import py_func # noqa: F401
from ...tensor.creation import create_parameter # noqa: F401
from ...fluid.layers import case # noqa: F401
from ...fluid.layers import cond # noqa: F401
from ...fluid.layers import conv2d # noqa: F401
from ...fluid.layers import crf_decoding # noqa: F401
......@@ -34,8 +38,6 @@ from .loss import nce # noqa: F401
from .common import prelu # noqa: F401
from ...fluid.layers import row_conv # noqa: F401
from ...fluid.layers import spectral_norm # noqa: F401
from ...fluid.layers import switch_case # noqa: F401
from ...fluid.layers import while_loop # noqa: F401
from ...fluid.input import embedding # noqa: F401
from ...fluid.contrib.layers import sparse_embedding # noqa: F401
......
此差异已折叠。
......@@ -212,7 +212,6 @@ HIGH_PARALLEL_JOB_NEW = [
'check_reduce_rank_test',
'test_progressbar',
'test_seed_op',
'test_shrink_rnn_memory',
'test_fc_bf16_mkldnn_op',
'test_sequence_first_step',
'test_fusion_lstm_mkldnn_op',
......@@ -273,7 +272,6 @@ HIGH_PARALLEL_JOB_NEW = [
'test_fleet_graph_executor',
'decorator_test',
'test_collective_base',
'test_lod_rank_table',
'test_multi_gru_mkldnn_op',
'test_eager_deletion_conditional_block',
'op_proto_maker_test',
......@@ -868,7 +866,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_imperative_load_static_param',
'test_imperative_qat_user_defined',
'test_anchor_generator_op',
'test_if_else_op',
'test_prepare_op',
'test_conj_op',
'test_imperative_hook_for_layer',
......@@ -1099,7 +1096,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_sequence_mask',
'test_fill_op',
'test_imperative_deepcf',
'test_reorder_lod_tensor',
'test_multiply',
'test_partial_program',
'test_fetch_feed',
......@@ -1264,7 +1260,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_imperative_static_runner_mnist',
'test_nearest_interp_op',
'test_diag_embed',
'test_imperative_basic',
'test_merge_selectedrows_op',
'test_feed_data_check_shape_type',
'test_complex_trace_layer',
......@@ -1740,7 +1735,6 @@ CPU_PARALLEL_JOB = [
'test_simplify_with_basic_ops_pass',
'test_similarity_focus_op',
'test_shuffle_batch_op',
'test_shrink_rnn_memory',
'test_set_bool_attr',
'test_sequence_topk_avg_pooling',
'test_sequence_scatter_op',
......@@ -1846,7 +1840,6 @@ CPU_PARALLEL_JOB = [
'test_logger',
'test_lod_tensor_array_ops',
'test_lod_tensor_array',
'test_lod_rank_table',
'test_locality_aware_nms_op',
'test_load_vars_shape_check',
'test_load_op_xpu',
......@@ -2373,7 +2366,6 @@ TETRAD_PARALLEL_JOB = [
'test_trt_conv3d_op',
'test_parallel_executor_drop_scope',
'test_tensorrt_engine',
'test_ir_memory_optimize_ifelse_op',
'test_parallel_executor_mnist',
'test_load_state_dict_from_old_format',
'test_fuse_elewise_add_act_pass',
......@@ -2594,7 +2586,6 @@ TETRAD_PARALLEL_JOB = [
'test_imperative_hook_for_layer',
'test_complex_sum_layer',
'test_complex_cast',
'test_reorder_lod_tensor',
'test_complex_kron',
'test_complex_trace_layer',
'test_merge_selectedrows_op',
......@@ -2851,7 +2842,6 @@ TWO_PARALLEL_JOB = [
'test_imperative_data_parallel',
'test_norm_nn_grad',
'test_im2sequence_op',
'test_if_else_op',
'test_one_hot_v2_op',
'test_grid_sampler_op',
'test_pad_op',
......@@ -3068,7 +3058,6 @@ TWO_PARALLEL_JOB = [
'test_broadcast_tensors_op',
'test_pad3d_op',
'test_cumprod_op',
'test_imperative_basic',
'trt_fc_prelu_test',
'test_sigmoid_focal_loss',
'test_pixel_shuffle',
......
......@@ -263,7 +263,6 @@ STATIC_MODE_TESTING_LIST = [
'test_huber_loss_op',
'test_im2sequence_op',
'test_image_classification_layer',
'test_imperative_basic',
'test_imperative_deepcf',
'test_imperative_framework',
'test_imperative_gan',
......@@ -293,7 +292,6 @@ STATIC_MODE_TESTING_LIST = [
'test_inverse_op',
'test_io_save_load',
'test_iou_similarity_op',
'test_ir_memory_optimize_ifelse_op',
'test_ir_memory_optimize_pass',
'test_is_empty_op',
'test_isfinite_op',
......@@ -315,7 +313,6 @@ STATIC_MODE_TESTING_LIST = [
'test_load_vars_shape_check',
'test_locality_aware_nms_op',
'test_lod_array_length_op',
'test_lod_rank_table',
'test_lod_tensor_array_ops',
'test_log_loss_op',
'test_log_softmax',
......@@ -440,7 +437,6 @@ STATIC_MODE_TESTING_LIST = [
'test_registry',
'test_regularizer',
'test_regularizer_api',
'test_reorder_lod_tensor',
'test_reshape_op',
'test_reshape_bf16_op',
'test_retinanet_detection_output',
......@@ -472,7 +468,6 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op',
'test_shape_op',
'test_shard_index_op',
'test_shrink_rnn_memory',
'test_shuffle_batch_op',
'test_shuffle_channel_op',
'test_sigmoid_cross_entropy_with_logits_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册