未验证 提交 ca364e14 编写于 作者: C cyber-pioneer 提交者: GitHub

rm restict of platform (#51806)

上级 c985b1ac
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -171,8 +170,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
print("need pass this case")
......@@ -191,8 +188,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
if __name__ == '__main__':
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -185,15 +184,10 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('reduce_mean' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for shape in self.shapes:
for dtype in self.dtypes:
# mean-kernel on cpu not support float16
if (
paddle.device.get_device() == "cpu"
and dtype == "float16"
):
if paddle.device.get_device() == "cpu" and dtype == "float16":
print("need pass this case")
continue
data = generate_data(shape, dtype)
......@@ -208,8 +202,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
if __name__ == '__main__':
......
......@@ -14,7 +14,6 @@
import math
import os
import platform
import tempfile
import time
import unittest
......@@ -442,8 +441,6 @@ class TestResnet(unittest.TestCase):
)
def test_resnet_composite_forward_backward(self):
plat = platform.system()
if plat == "Linux":
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
core._set_prim_all_enabled(False)
......@@ -456,8 +453,6 @@ class TestResnet(unittest.TestCase):
static_loss, dygraph_loss
),
)
else:
pass
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import time
import unittest
......@@ -63,9 +62,7 @@ def train(to_static, enable_prim, enable_cinn):
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
fluid.core._set_prim_all_enabled(
enable_prim and platform.system() == 'Linux'
)
fluid.core._set_prim_all_enabled(enable_prim)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import platform
import unittest
import paddle
......@@ -95,78 +94,48 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "cinn_prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_all(self):
"""prim forward + prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_forward(self):
"""cinn + prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "cinn_prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_forward(self):
"""only prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_backward(self):
"""cinn + prim_backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "cinn_prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
def test_prim_backward(self):
"""only prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
def test_cinn(self):
"""only cinn"""
self.reset_env_flag()
self.flag = "cinn"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册