未验证 提交 4812c707 编写于 作者: H hong 提交者: GitHub

[NewIR]enable dy2st new ir test in ci (#56034)

* support new ir dy2st

* revert code

* skip test with stride

* chang import file
上级 03ca04fe
......@@ -121,6 +121,8 @@ def test_with_new_ir(func):
@wraps(func)
def impl(*args, **kwargs):
ir_outs = None
if os.environ.get('FLAGS_use_stride_kernel', False):
return
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
try:
......
......@@ -20,7 +20,7 @@ import unittest
import numpy as np
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
from dygraph_to_static_util import ast_only_test
from dygraph_to_static_util import ast_only_test, test_with_new_ir
from predictor_utils import PredictorTools
import paddle
......@@ -263,6 +263,17 @@ class TestBert(unittest.TestCase):
out = output()
return out
@test_with_new_ir
def test_train_new_ir(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)
@ast_only_test
def test_train(self):
static_loss, static_ppl = self.train_static(
......
......@@ -19,6 +19,7 @@ import time
import unittest
import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools
import paddle
......@@ -731,6 +732,13 @@ class TestMobileNet(unittest.TestCase):
),
)
@test_with_new_ir
def test_mobile_net_new_ir(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
# MobileNet-V2
self.assert_same_loss("MobileNetV2")
def test_mobile_net(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
......
......@@ -19,6 +19,7 @@ import time
import unittest
import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools
import paddle
......@@ -426,6 +427,19 @@ class TestResnet(unittest.TestCase):
),
)
@test_with_new_ir
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
......
......@@ -19,6 +19,7 @@ import time
import unittest
import numpy as np
from dygraph_to_static_util import test_with_new_ir
from predictor_utils import PredictorTools
import paddle
......@@ -431,6 +432,19 @@ class TestResnet(unittest.TestCase):
),
)
@test_with_new_ir
def test_resnet_new_ir(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册