未验证 提交 4cb7d32c 编写于 作者: J Jiabin Yang 提交者: GitHub

test=develop, add dygraph_not_support and refine ocr (#17868)

* test=develop, add dygraph_not_support and refine ocr

* test=develop, shrink name of dygraph_not_support
上级 545afb2d
......@@ -22,6 +22,7 @@ from .tracer import Tracer
__all__ = [
'enabled',
'no_grad',
'not_support',
'guard',
'to_variable',
]
......@@ -43,6 +44,15 @@ def _switch_tracer_mode_guard_(is_train=True):
yield
def _dygraph_not_support_(func):
def __impl__(*args, **kwargs):
assert not framework.in_dygraph_mode(
), "We don't support %s in Dygraph mode" % func.__name__
return func(*args, **kwargs)
return __impl__
def _no_grad_(func):
def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
......@@ -52,6 +62,7 @@ def _no_grad_(func):
no_grad = wrap_decorator(_no_grad_)
not_support = wrap_decorator(_dygraph_not_support_)
@signature_safe_contextmanager
......
......@@ -13,16 +13,11 @@
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import os
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
......@@ -37,13 +32,13 @@ class Config(object):
# size for word embedding
word_vector_dim = 128
# max length for label padding
max_length = 15
max_length = 5
# optimizer setting
LR = 1.0
learning_rate_decay = None
# batch size to train
batch_size = 32
batch_size = 16
# class number to classify
num_classes = 481
......@@ -445,10 +440,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
(i - 1) * Config.max_length,
i * Config.max_length,
dtype='int64').reshape([1, Config.max_length])))
#if Config.use_gpu:
# place = fluid.CUDAPlace(0)
#else:
# place = fluid.CPUPlace()
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -461,10 +453,7 @@ class TestDygraphOCRAttention(unittest.TestCase):
[50000], [Config.LR, Config.LR * 0.01])
else:
learning_rate = Config.LR
#optimizer = fluid.optimizer.Adadelta(learning_rate=learning_rate,
# epsilon=1.0e-6, rho=0.9)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
# place = fluid.CPUPlace()
dy_param_init_value = {}
for param in ocr_attention.parameters():
dy_param_init_value[param.name] = param.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册