未验证 提交 c03186f9 编写于 作者: 0 0x45f 提交者: GitHub

Refine test_lac.py for eager mode (#40951)

* Refine test_lac.py for eager mode

* refine code

* Fix test_program_translator for eager
上级 0d0d76eb
......@@ -98,6 +98,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"CustomDistAlias", "CustomDistAliasProbs"}},
{"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}},
{"group_norm", {"X", "Scale", "Bias"}},
{"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
{"crf_decoding", {"Emission", "Transition", "Label", "Length"}},
{"chunk_eval", {"Inference", "Label", "SeqLength"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
......@@ -27,6 +27,8 @@ from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph import Embedding, Linear, GRUUnit
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.fluid.framework import _non_static_mode
from paddle import _C_ops
SEED = 2020
......@@ -167,6 +169,11 @@ class LinearChainCRF(fluid.dygraph.Layer):
self._transition = value
def forward(self, input, label, length=None):
if _non_static_mode():
_, _, _, log_likelihood = _C_ops.linear_chain_crf(
input, self._transition, label, length, "is_test",
self._is_test)
return log_likelihood
alpha = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
......@@ -218,6 +225,9 @@ class CRFDecoding(fluid.dygraph.Layer):
self._transition = value
def forward(self, input, label=None, length=None):
if _non_static_mode():
return _C_ops.crf_decoding(input, self._transition, label, length,
"is_test", self._is_test)
viterbi_path = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
......@@ -245,6 +255,11 @@ class ChunkEval(fluid.dygraph.Layer):
self.excluded_chunk_types = excluded_chunk_types
def forward(self, input, label, seq_length=None):
if _non_static_mode():
return _C_ops.chunk_eval(
input, label, seq_length, "num_chunk_types",
self.num_chunk_types, "chunk_scheme", self.chunk_scheme,
"excluded_chunk_types", self.excluded_chunk_types or [])
precision = self._helper.create_variable_for_type_inference(
dtype="float32")
......
......@@ -232,7 +232,9 @@ class TestEnableDeclarative(unittest.TestCase):
dygraph_func = self.program_translator.get_func(simple_func)
self.assertTrue(callable(dygraph_func))
dygraph_output = dygraph_func(self.x, self.weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
self.assertTrue(
isinstance(dygraph_output, (fluid.core.VarBase,
fluid.core.eager.Tensor)))
def test_enable_disable_get_program(self):
......@@ -254,7 +256,9 @@ class TestEnableDeclarative(unittest.TestCase):
with fluid.dygraph.guard():
dygraph_output = self.program_translator.get_program(
simple_func, self.x, self.weight)
self.assertTrue(isinstance(dygraph_output, fluid.core.VarBase))
self.assertTrue(
isinstance(dygraph_output, (fluid.core.VarBase,
fluid.core.eager.Tensor)))
def test_enable_disable_declarative(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册