未验证 提交 49e4e2f9 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support rnn_decode switch to eager mode (#41333)

上级 42075ddc
...@@ -106,6 +106,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -106,6 +106,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}}, {"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
{"crf_decoding", {"Emission", "Transition", "Label", "Length"}}, {"crf_decoding", {"Emission", "Transition", "Label", "Length"}},
{"chunk_eval", {"Inference", "Label", "SeqLength"}}, {"chunk_eval", {"Inference", "Label", "SeqLength"}},
{"sequence_mask", {"X", "MaxLenTensor"}},
{"graph_reindex", {"graph_reindex",
{"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}}, {"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}},
{"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}}, {"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}},
......
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
from __future__ import print_function from __future__ import print_function
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from ..framework import Variable, _non_static_mode from ..framework import core, Variable, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype from ..data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..core import VarDesc from ..core import VarDesc
from paddle import _C_ops
__all__ = [ __all__ = [
'sequence_conv', 'sequence_conv',
...@@ -1380,6 +1381,21 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): ...@@ -1380,6 +1381,21 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
# [1 1 1 1 1 1 1 1 0 0]] # [1 1 1 1 1 1 1 1 0 0]]
""" """
if _non_static_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
if maxlen is not None:
if isinstance(maxlen, core.eager.Tensor):
attrs = ('out_dtype', dtype)
out = _C_ops.sequence_mask(x, maxlen, *attrs)
else:
attrs = ('out_dtype', dtype, 'maxlen', maxlen)
out = _C_ops.sequence_mask(x, None, *attrs)
out.stop_gradient = True
return out
helper = LayerHelper('sequence_mask', **locals()) helper = LayerHelper('sequence_mask', **locals())
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
......
...@@ -31,7 +31,7 @@ import paddle.fluid.core as core ...@@ -31,7 +31,7 @@ import paddle.fluid.core as core
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static() paddle.enable_static()
...@@ -554,7 +554,7 @@ class TestDynamicDecode(unittest.TestCase): ...@@ -554,7 +554,7 @@ class TestDynamicDecode(unittest.TestCase):
}, },
fetch_list=[output])[0] fetch_list=[output])[0]
def test_dynamic_basic_decoder(self): def func_dynamic_basic_decoder(self):
paddle.disable_static() paddle.disable_static()
src = paddle.to_tensor(np.random.randint(8, size=(8, 4))) src = paddle.to_tensor(np.random.randint(8, size=(8, 4)))
src_length = paddle.to_tensor(np.random.randint(8, size=(8))) src_length = paddle.to_tensor(np.random.randint(8, size=(8)))
...@@ -562,6 +562,11 @@ class TestDynamicDecode(unittest.TestCase): ...@@ -562,6 +562,11 @@ class TestDynamicDecode(unittest.TestCase):
probs, samples, sample_length = model(src, src_length) probs, samples, sample_length = model(src, src_length)
paddle.enable_static() paddle.enable_static()
def test_dynamic_basic_decoder(self):
with _test_eager_guard():
self.func_dynamic_basic_decoder()
self.func_dynamic_basic_decoder()
class ModuleApiTest(unittest.TestCase): class ModuleApiTest(unittest.TestCase):
@classmethod @classmethod
...@@ -708,9 +713,17 @@ class TestBeamSearch(ModuleApiTest): ...@@ -708,9 +713,17 @@ class TestBeamSearch(ModuleApiTest):
] ]
return inputs return inputs
def test_check_output(self): def func_check_output(self):
self.setUp()
self.make_inputs()
self.make_inputs()
self.check_output() self.check_output()
def test_check_output(self):
with _test_eager_guard():
self.func_check_output()
self.func_check_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册