未验证 提交 4812c707 编写于 作者: H hong 提交者: GitHub

[NewIR]enable dy2st new ir test in ci (#56034)

* support new ir dy2st

* revert code

* skip test with stride

* chang import file
上级 03ca04fe
...@@ -121,6 +121,8 @@ def test_with_new_ir(func): ...@@ -121,6 +121,8 @@ def test_with_new_ir(func):
@wraps(func) @wraps(func)
def impl(*args, **kwargs): def impl(*args, **kwargs):
ir_outs = None ir_outs = None
if os.environ.get('FLAGS_use_stride_kernel', False):
return
with static.scope_guard(static.Scope()): with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()): with static.program_guard(static.Program()):
try: try:
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import numpy as np import numpy as np
from bert_dygraph_model import PretrainModelLayer from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader from bert_utils import get_bert_config, get_feed_data_reader
from dygraph_to_static_util import ast_only_test from dygraph_to_static_util import ast_only_test, test_with_new_ir
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -263,6 +263,17 @@ class TestBert(unittest.TestCase): ...@@ -263,6 +263,17 @@ class TestBert(unittest.TestCase):
out = output() out = output()
return out return out
@test_with_new_ir
def test_train_new_ir(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)
@ast_only_test @ast_only_test
def test_train(self): def test_train(self):
static_loss, static_ppl = self.train_static( static_loss, static_ppl = self.train_static(
......
...@@ -19,6 +19,7 @@ import time ...@@ -19,6 +19,7 @@ import time
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -731,6 +732,13 @@ class TestMobileNet(unittest.TestCase): ...@@ -731,6 +732,13 @@ class TestMobileNet(unittest.TestCase):
), ),
) )
@test_with_new_ir
def test_mobile_net_new_ir(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
# MobileNet-V2
self.assert_same_loss("MobileNetV2")
def test_mobile_net(self): def test_mobile_net(self):
# MobileNet-V1 # MobileNet-V1
self.assert_same_loss("MobileNetV1") self.assert_same_loss("MobileNetV1")
......
...@@ -19,6 +19,7 @@ import time ...@@ -19,6 +19,7 @@ import time
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -426,6 +427,19 @@ class TestResnet(unittest.TestCase): ...@@ -426,6 +427,19 @@ class TestResnet(unittest.TestCase):
), ),
) )
@test_with_new_ir
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_resnet(self): def test_resnet(self):
static_loss = self.train(to_static=True) static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False) dygraph_loss = self.train(to_static=False)
......
...@@ -19,6 +19,7 @@ import time ...@@ -19,6 +19,7 @@ import time
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools from predictor_utils import PredictorTools
import paddle import paddle
...@@ -431,6 +432,19 @@ class TestResnet(unittest.TestCase): ...@@ -431,6 +432,19 @@ class TestResnet(unittest.TestCase):
), ),
) )
@test_with_new_ir
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_resnet(self): def test_resnet(self):
static_loss = self.train(to_static=True) static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False) dygraph_loss = self.train(to_static=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册