From 8212874f4780f19316f3379b365e41faaa3b48b1 Mon Sep 17 00:00:00 2001 From: guofei <52460041+gfwm2013@users.noreply.github.com> Date: Tue, 29 Dec 2020 16:25:46 +0800 Subject: [PATCH] Fix test_imperative_skip_out (#29939) * Fix unittest:test_imperative_skip_out --- .../fluid/contrib/slim/tests/test_imperative_skip_op.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py index d030d1eb51..0561055e6e 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py @@ -38,11 +38,9 @@ if core.is_compiled_with_cuda(): _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -quant_skip_pattern_list = ['skip_qat', 'skip_quant'] - class ImperativeLenet(fluid.dygraph.Layer): - def __init__(self, num_classes=10, classifier_activation='softmax'): + def __init__(self, num_classes=10): super(ImperativeLenet, self).__init__() conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") @@ -135,7 +133,7 @@ class TestImperativeOutSclae(unittest.TestCase): np.random.seed(seed) reader = paddle.batch( - paddle.dataset.mnist.test(), batch_size=32, drop_last=True) + paddle.dataset.mnist.test(), batch_size=512, drop_last=True) lenet = ImperativeLenet() fixed_state = {} for name, param in lenet.named_parameters(): -- GitLab