Unverified commit 71d62207, authored by Leo Chen, committed by GitHub

Skip reader op in mixed_precision decorator (#28353)

* skip reader op in mixed_precision decorator

* add ut
Parent commit: 8b2436a7
......@@ -215,6 +215,14 @@ def rewrite_program(main_prog, amp_lists):
white_op_set = set()
black_op_set = set()
for op in ops:
# NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
# we don't need to handle reader op and the input of 'create_py_reader' is not
# in block, which may result in errors.
# See GeneratorLoader._init_non_iterable() for details.
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
......
......@@ -417,5 +417,42 @@ class TestImageClassification(unittest.TestCase):
yield
class TestAmpWithNonIterableDataLoader(unittest.TestCase):
    """Regression test for decorating a program that uses a non-iterable
    DataLoader with the mixed-precision (AMP) optimizer.

    A non-iterable DataLoader (``iterable=False``) inserts reader ops
    ('create_py_reader' / 'read') into the program; AMP rewriting must be
    able to skip them without error (see the matching change in
    ``rewrite_program``).
    """

    def decorate_with_data_loader(self):
        # Build a fresh pair of programs so this test does not leak state
        # into the global default program.
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, start_prog):
            # unique_name.guard keeps generated parameter names (e.g.
            # "conv2d_0.w_0") deterministic so custom_black_varnames below
            # can refer to them.
            with paddle.fluid.unique_name.guard():
                image = fluid.layers.data(
                    name='image', shape=[3, 224, 224], dtype='float32')
                label = fluid.layers.data(
                    name='label', shape=[1], dtype='int64')
                # iterable=False is the point of the test: it makes the
                # loader insert 'create_py_reader'/'read' ops into main_prog.
                py_reader = fluid.io.DataLoader.from_generator(
                    feed_list=[image, label],
                    capacity=4,
                    iterable=False,
                    use_double_buffer=False)

                # vgg16_bn_drop is a sibling helper defined elsewhere in
                # this test file.
                net = vgg16_bn_drop(image)
                logits = fluid.layers.fc(input=net, size=10, act="softmax")
                cost, predict = fluid.layers.softmax_with_cross_entropy(
                    logits, label, return_softmax=True)
                avg_cost = fluid.layers.mean(cost)

                optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
                amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
                    custom_black_varnames={"loss", "conv2d_0.w_0"})
                mp_optimizer = fluid.contrib.mixed_precision.decorate(
                    optimizer=optimizer,
                    amp_lists=amp_lists,
                    init_loss_scaling=8.0,
                    use_dynamic_loss_scaling=True)

                # The test passes if this does not raise while rewriting the
                # program that contains the reader ops.
                mp_optimizer.minimize(avg_cost)

    def test_non_iterable_dataloader(self):
        # Success criterion is simply that decoration completes without
        # raising; no numeric output is checked.
        self.decorate_with_data_loader()
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册