未验证 提交 4cb7d32c 编写于 作者: J Jiabin Yang 提交者: GitHub

test=develop, add dygraph_not_support and refine ocr (#17868)

* test=develop, add dygraph_not_support and refine ocr

* test=develop, shrink name of dygraph_not_support
上级 545afb2d
...@@ -22,6 +22,7 @@ from .tracer import Tracer ...@@ -22,6 +22,7 @@ from .tracer import Tracer
__all__ = [ __all__ = [
'enabled', 'enabled',
'no_grad', 'no_grad',
'not_support',
'guard', 'guard',
'to_variable', 'to_variable',
] ]
...@@ -43,6 +44,15 @@ def _switch_tracer_mode_guard_(is_train=True): ...@@ -43,6 +44,15 @@ def _switch_tracer_mode_guard_(is_train=True):
yield yield
def _dygraph_not_support_(func):
    """Decorator that forbids calling ``func`` while dygraph mode is active.

    The wrapped callable raises ``AssertionError`` (with the original
    function's name in the message) if ``framework.in_dygraph_mode()`` is
    true; otherwise it delegates to ``func`` unchanged.

    Args:
        func (callable): the API to guard against dygraph-mode use.

    Returns:
        callable: a wrapper with the same call signature as ``func``.
    """
    import functools

    # functools.wraps keeps __name__/__doc__ of the guarded API intact so
    # error messages and introspection still point at the real function.
    @functools.wraps(func)
    def __impl__(*args, **kwargs):
        # NOTE: assert (not raise) matches the surrounding file's style;
        # it is stripped under `python -O`.
        assert not framework.in_dygraph_mode(
        ), "We don't support %s in Dygraph mode" % func.__name__
        return func(*args, **kwargs)

    return __impl__
def _no_grad_(func): def _no_grad_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False): with _switch_tracer_mode_guard_(is_train=False):
...@@ -52,6 +62,7 @@ def _no_grad_(func): ...@@ -52,6 +62,7 @@ def _no_grad_(func):
no_grad = wrap_decorator(_no_grad_) no_grad = wrap_decorator(_no_grad_)
not_support = wrap_decorator(_dygraph_not_support_)
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -13,16 +13,11 @@ ...@@ -13,16 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import contextlib
import unittest import unittest
import numpy as np import numpy as np
import six import six
import os
from PIL import Image
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm, Embedding, GRUUnit from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
...@@ -37,13 +32,13 @@ class Config(object): ...@@ -37,13 +32,13 @@ class Config(object):
# size for word embedding # size for word embedding
word_vector_dim = 128 word_vector_dim = 128
# max length for label padding # max length for label padding
max_length = 15 max_length = 5
# optimizer setting # optimizer setting
LR = 1.0 LR = 1.0
learning_rate_decay = None learning_rate_decay = None
# batch size to train # batch size to train
batch_size = 32 batch_size = 16
# class number to classify # class number to classify
num_classes = 481 num_classes = 481
...@@ -445,10 +440,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -445,10 +440,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
(i - 1) * Config.max_length, (i - 1) * Config.max_length,
i * Config.max_length, i * Config.max_length,
dtype='int64').reshape([1, Config.max_length]))) dtype='int64').reshape([1, Config.max_length])))
#if Config.use_gpu:
# place = fluid.CUDAPlace(0)
#else:
# place = fluid.CPUPlace()
with fluid.dygraph.guard(): with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
...@@ -461,10 +453,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ...@@ -461,10 +453,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
[50000], [Config.LR, Config.LR * 0.01]) [50000], [Config.LR, Config.LR * 0.01])
else: else:
learning_rate = Config.LR learning_rate = Config.LR
#optimizer = fluid.optimizer.Adadelta(learning_rate=learning_rate,
# epsilon=1.0e-6, rho=0.9)
optimizer = fluid.optimizer.SGD(learning_rate=0.001) optimizer = fluid.optimizer.SGD(learning_rate=0.001)
# place = fluid.CPUPlace()
dy_param_init_value = {} dy_param_init_value = {}
for param in ocr_attention.parameters(): for param in ocr_attention.parameters():
dy_param_init_value[param.name] = param.numpy() dy_param_init_value[param.name] = param.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册