Unverified · Commit e57051b4 · authored by WangZhen · committed by GitHub

Polish ResNet and Bert prim_cinn test (#52030)

* Polish ResNet and Bert prim_cinn test
Parent 2ba4515e
CMakeLists.txt
@@ -13,7 +13,5 @@ set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500)
 if(WITH_CINN)
   set_tests_properties(test_resnet_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN")
-  set_tests_properties(
-    test_bert_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT
-    "FLAGS_deny_cinn_ops=dropout")
+  set_tests_properties(test_bert_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN")
 endif()
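This CMake change drops the per-test ENVIRONMENT property that exported FLAGS_deny_cinn_ops=dropout for the Bert test via ctest; as the Python diff below shows, the test now sets the flag itself and resets it in tearDown. A minimal sketch of that pattern, assuming only paddle.set_flags and the flag values visible in this diff (the class name is illustrative, not from the commit):

import unittest

import paddle


class TestWithCinnDenyList(unittest.TestCase):  # illustrative name
    def tearDown(self):
        # Reset the deny-list so the flag does not leak into later tests
        # running in the same process.
        paddle.set_flags({'FLAGS_deny_cinn_ops': ''})

    def test_cinn(self):
        # Keep the listed ops out of CINN compilation; multiple ops are
        # separated by semicolons, e.g. "gaussian_random;uniform_random".
        paddle.set_flags({'FLAGS_deny_cinn_ops': "dropout"})
        # ... run the dy2st + CINN training here ...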
test_bert_prim_cinn.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
-import platform
 import time
 import unittest
@@ -33,6 +32,44 @@ MODULE_NAME = 'test_bert_prim_cinn'
 MD5SUM = '71e730ee8d7aa77a215b7e898aa089af'
 SAVE_NAME = 'bert_training_data.npz'
+DY2ST_PRIM_GT = [
+    11.144556999206543,
+    10.343620300292969,
+    10.330279350280762,
+    10.276118278503418,
+    10.222086906433105,
+    10.194628715515137,
+    10.14902114868164,
+    10.096250534057617,
+    10.104615211486816,
+    9.985644340515137,
+]
+DY2ST_CINN_GT = [
+    10.649632453918457,
+    10.333406448364258,
+    10.33541202545166,
+    10.260543823242188,
+    10.219606399536133,
+    10.176884651184082,
+    10.124699592590332,
+    10.072620391845703,
+    10.112163543701172,
+    9.969393730163574,
+]
+DY2ST_PRIM_CINN_GT = [
+    11.144556999206543,
+    10.343620300292969,
+    10.330279350280762,
+    10.276118278503418,
+    10.222086906433105,
+    10.194628715515137,
+    10.149020195007324,
+    10.096250534057617,
+    10.104615211486816,
+    9.985644340515137,
+]
 if core.is_compiled_with_cuda():
     paddle.set_flags({'FLAGS_cudnn_deterministic': True})
@@ -42,9 +79,7 @@ def train(to_static, enable_prim, enable_cinn):
         paddle.set_device('gpu')
     else:
         paddle.set_device('cpu')
-    fluid.core._set_prim_all_enabled(
-        enable_prim and platform.system() == 'Linux'
-    )
+    fluid.core._set_prim_all_enabled(enable_prim)
     np.random.seed(SEED)
     paddle.seed(SEED)
@@ -95,7 +130,7 @@ def train(to_static, enable_prim, enable_cinn):
         loss.backward()
         optimizer.minimize(loss)
         bert.clear_gradients()
-        losses.append(loss)
+        losses.append(loss.numpy().item())
         print(
             "step: {}, loss: {}, batch_cost: {:.5}".format(
@@ -106,6 +141,7 @@ def train(to_static, enable_prim, enable_cinn):
         )
         if step >= 9:
             break
+    print(losses)
     return losses
@@ -113,28 +149,42 @@ class TestBert(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         download(URL, MODULE_NAME, MD5SUM, SAVE_NAME)
-        cls.dy2st = train(to_static=True, enable_prim=False, enable_cinn=False)
+
+    def tearDown(self):
+        paddle.set_flags({'FLAGS_deny_cinn_ops': ''})
+
+    @unittest.skipIf(
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
+    )
     def test_prim(self):
         dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False)
-        np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-1)
+        np.testing.assert_allclose(dy2st_prim, DY2ST_PRIM_GT, rtol=1e-5)

     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "paddle is not compiled with CINN"
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
     )
     def test_cinn(self):
+        paddle.set_flags({'FLAGS_deny_cinn_ops': "dropout"})
         dy2st_cinn = train(to_static=True, enable_prim=False, enable_cinn=True)
-        np.testing.assert_allclose(self.dy2st, dy2st_cinn, rtol=1e-6)
+        np.testing.assert_allclose(dy2st_cinn, DY2ST_CINN_GT, rtol=1e-5)

     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "paddle is not compiled with CINN"
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
     )
     def test_prim_cinn(self):
+        paddle.set_flags(
+            {'FLAGS_deny_cinn_ops': "gaussian_random;uniform_random"}
+        )
         core._add_skip_comp_ops("layer_norm")
         dy2st_prim_cinn = train(
             to_static=True, enable_prim=True, enable_cinn=True
         )
-        np.testing.assert_allclose(self.dy2st, dy2st_prim_cinn, rtol=1e-1)
+        np.testing.assert_allclose(
+            dy2st_prim_cinn, DY2ST_PRIM_CINN_GT, rtol=1e-5
+        )

 if __name__ == '__main__':
...
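The Bert tests now assert per-step losses against hard-coded ground-truth lists at rtol=1e-5, instead of re-training a baseline in setUpClass and comparing at loose tolerances. A short sketch of what np.testing.assert_allclose checks (the passing values are the first two entries of DY2ST_PRIM_GT above; the failing pair is invented for illustration):

import numpy as np

# assert_allclose passes when |actual - desired| <= atol + rtol * |desired|
# holds element-wise; atol defaults to 0, so rtol=1e-5 is a pure relative check.
actual = [11.144556999206543, 10.343620300292969]   # losses collected by train()
desired = [11.144556999206543, 10.343620300292969]  # DY2ST_PRIM_GT[:2]
np.testing.assert_allclose(actual, desired, rtol=1e-5)

# A relative error above 1e-5 raises AssertionError with a mismatch report:
try:
    np.testing.assert_allclose([10.345], [10.343620300292969], rtol=1e-5)
except AssertionError:
    print("mismatch beyond rtol=1e-5 detected")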
test_resnet_prim_cinn.py
@@ -29,6 +29,82 @@ l2_decay = 1e-4
 batch_size = 2
 epoch_num = 1
+# On V100, 16G, CUDA 11.2, the results are as follows:
+# DY2ST_PRIM_GT = [
+#     5.8473358154296875,
+#     8.354944229125977,
+#     5.098367691040039,
+#     8.533346176147461,
+#     8.179085731506348,
+#     7.285282135009766,
+#     9.824585914611816,
+#     8.56928825378418,
+#     8.539499282836914,
+#     10.256929397583008,
+# ]
+# DY2ST_CINN_GT = [
+#     5.847336769104004,
+#     8.336246490478516,
+#     5.108744144439697,
+#     8.316713333129883,
+#     8.175262451171875,
+#     7.590441703796387,
+#     9.895681381225586,
+#     8.196207046508789,
+#     8.438933372497559,
+#     10.305074691772461,
+# ]
+# DY2ST_PRIM_CINN_GT = [
+#     5.8473358154296875,
+#     8.322463989257812,
+#     5.169863700866699,
+#     8.399882316589355,
+#     7.859550476074219,
+#     7.4672698974609375,
+#     9.828727722167969,
+#     8.270355224609375,
+#     8.456792831420898,
+#     9.919631958007812,
+# ]
+# The results in CI are as follows:
+DY2ST_PRIM_GT = [
+    5.82879114151001,
+    8.333706855773926,
+    5.07769250869751,
+    8.66937255859375,
+    8.411705017089844,
+    7.252340793609619,
+    9.683248519897461,
+    8.177335739135742,
+    8.195427894592285,
+    10.219732284545898,
+]
+DY2ST_CINN_GT = [
+    5.828789710998535,
+    8.340764999389648,
+    4.998944282531738,
+    8.474305152893066,
+    8.09157943725586,
+    7.440057754516602,
+    9.907357215881348,
+    8.304681777954102,
+    8.383116722106934,
+    10.120304107666016,
+]
+DY2ST_PRIM_CINN_GT = [
+    5.828784942626953,
+    8.341737747192383,
+    5.113619327545166,
+    8.625601768493652,
+    8.082450866699219,
+    7.4913249015808105,
+    9.858025550842285,
+    8.287693977355957,
+    8.435812950134277,
+    10.372406005859375,
+]
 if core.is_compiled_with_cuda():
     paddle.set_flags({'FLAGS_cudnn_deterministic': True})
@@ -109,7 +185,7 @@ def train(to_static, enable_prim, enable_cinn):
             total_acc1 += acc_top1
             total_acc5 += acc_top5
             total_sample += 1
-            losses.append(avg_loss.numpy())
+            losses.append(avg_loss.numpy().item())
             end_time = time.time()
             print(
@@ -123,49 +199,42 @@ def train(to_static, enable_prim, enable_cinn):
                     end_time - start_time,
                 )
             )
-            if batch_id == 10:
+            if batch_id >= 9:
                 # avoid dataloader throwing abort signal
                 data_loader._reset()
                 break
+    print(losses)
     return losses

 class TestResnet(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.dy2st = train(to_static=True, enable_prim=False, enable_cinn=False)
-
+    @unittest.skipIf(
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
+    )
     def test_prim(self):
-        # todo: to be removed after adjust of rtol
-        core._set_prim_forward_blacklist("batch_norm")
-        core._add_skip_comp_ops("batch_norm")
         dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False)
-        # NOTE: Now dy2st is equal to dy2st_prim. With the splitting of kernels, the threshold here may need to be adjusted
-        np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-6)
+        np.testing.assert_allclose(dy2st_prim, DY2ST_PRIM_GT, rtol=1e-5)

     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
     )
     def test_cinn(self):
         dy2st_cinn = train(to_static=True, enable_prim=False, enable_cinn=True)
-        # TODO(0x45f): The following is only temporary thresholds, and the final thresholds needs to be discussed
-        np.testing.assert_allclose(self.dy2st[0:2], dy2st_cinn[0:2], rtol=1e-3)
-        np.testing.assert_allclose(self.dy2st, dy2st_cinn, rtol=1e-1)
+        np.testing.assert_allclose(dy2st_cinn, DY2ST_CINN_GT, rtol=1e-5)

     @unittest.skipIf(
-        not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
+        not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
+        "paddle is not compiled with CINN and CUDA",
     )
     def test_prim_cinn(self):
-        core._set_prim_forward_blacklist("flatten_contiguous_range")
         dy2st_prim_cinn = train(
             to_static=True, enable_prim=True, enable_cinn=True
         )
-        # TODO(0x45f): The following is only temporary thresholds, and the final thresholds need to be discussed
-        np.testing.assert_allclose(
-            self.dy2st[0:2], dy2st_prim_cinn[0:2], rtol=1e-2
-        )
-        np.testing.assert_allclose(self.dy2st, dy2st_prim_cinn, rtol=1e-1)
+        np.testing.assert_allclose(
+            dy2st_prim_cinn, DY2ST_PRIM_CINN_GT, rtol=1e-5
+        )

 if __name__ == '__main__':
...
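In both files the collected losses changed from live Tensors (losses.append(loss)) to plain Python floats (loss.numpy().item()), so print(losses) and assert_allclose operate on scalars rather than Tensor objects. A hedged sketch of the conversion, assuming a 1-element loss Tensor as in these training loops:

import paddle

loss = paddle.to_tensor([0.5])      # stand-in for a scalar training loss

as_float = loss.numpy().item()      # copy the value out as a Python float
print(type(as_float), as_float)     # <class 'float'> 0.5

# Appending the Tensor itself would keep the Tensor objects alive in the list
# and print as Tensor reprs; plain floats compare cleanly with the GT lists.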