From 4cb7d32c9b40616755fbf6badacd4c7ed433c06c Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Thu, 6 Jun 2019 20:39:44 +0800 Subject: [PATCH] test=develop, add dygraph_not_support and refine ocr (#17868) * test=develop, add dygraph_not_support and refine ocr * test=develop, shrink name of dygraph_not_support --- python/paddle/fluid/dygraph/base.py | 11 +++++++++++ .../test_imperative_ocr_attention_model.py | 17 +++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 63f861e38..13b09f145 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -22,6 +22,7 @@ from .tracer import Tracer __all__ = [ 'enabled', 'no_grad', + 'not_support', 'guard', 'to_variable', ] @@ -43,6 +44,15 @@ def _switch_tracer_mode_guard_(is_train=True): yield +def _dygraph_not_support_(func): + def __impl__(*args, **kwargs): + assert not framework.in_dygraph_mode( + ), "We don't support %s in Dygraph mode" % func.__name__ + return func(*args, **kwargs) + + return __impl__ + + def _no_grad_(func): def __impl__(*args, **kwargs): with _switch_tracer_mode_guard_(is_train=False): @@ -52,6 +62,7 @@ def _no_grad_(func): no_grad = wrap_decorator(_no_grad_) +not_support = wrap_decorator(_dygraph_not_support_) @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py index f3a231e7b..3f53552ba 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py @@ -13,16 +13,11 @@ # limitations under the License. from __future__ import print_function -import contextlib import unittest import numpy as np import six -import os -from PIL import Image -import paddle import paddle.fluid as fluid from paddle.fluid import core -from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm, Embedding, GRUUnit from paddle.fluid.dygraph.base import to_variable from test_imperative_base import new_program_scope @@ -37,13 +32,13 @@ class Config(object): # size for word embedding word_vector_dim = 128 # max length for label padding - max_length = 15 + max_length = 5 # optimizer setting LR = 1.0 learning_rate_decay = None # batch size to train - batch_size = 32 + batch_size = 16 # class number to classify num_classes = 481 @@ -445,10 +440,7 @@ class TestDygraphOCRAttention(unittest.TestCase): (i - 1) * Config.max_length, i * Config.max_length, dtype='int64').reshape([1, Config.max_length]))) - #if Config.use_gpu: - # place = fluid.CUDAPlace(0) - #else: - # place = fluid.CPUPlace() + with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -461,10 +453,7 @@ class TestDygraphOCRAttention(unittest.TestCase): [50000], [Config.LR, Config.LR * 0.01]) else: learning_rate = Config.LR - #optimizer = fluid.optimizer.Adadelta(learning_rate=learning_rate, - # epsilon=1.0e-6, rho=0.9) optimizer = fluid.optimizer.SGD(learning_rate=0.001) - # place = fluid.CPUPlace() dy_param_init_value = {} for param in ocr_attention.parameters(): dy_param_init_value[param.name] = param.numpy() -- GitLab