未验证 提交 ca364e14 编写于 作者: C cyber-pioneer 提交者: GitHub

rm restict of platform (#51806)

上级 c985b1ac
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -171,8 +170,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -171,8 +170,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('layer_norm' not in fwd_ops) self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for dtype in self.dtypes: for dtype in self.dtypes:
if paddle.device.get_device() == "cpu": if paddle.device.get_device() == "cpu":
print("need pass this case") print("need pass this case")
...@@ -191,8 +188,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -191,8 +188,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
rtol=TOLERANCE[dtype]['rtol'], rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'], atol=TOLERANCE[dtype]['atol'],
) )
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -185,15 +184,10 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -185,15 +184,10 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('reduce_mean' not in fwd_ops) self.assertTrue('reduce_mean' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for shape in self.shapes: for shape in self.shapes:
for dtype in self.dtypes: for dtype in self.dtypes:
# mean-kernel on cpu not support float16 # mean-kernel on cpu not support float16
if ( if paddle.device.get_device() == "cpu" and dtype == "float16":
paddle.device.get_device() == "cpu"
and dtype == "float16"
):
print("need pass this case") print("need pass this case")
continue continue
data = generate_data(shape, dtype) data = generate_data(shape, dtype)
...@@ -208,8 +202,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -208,8 +202,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
rtol=TOLERANCE[dtype]['rtol'], rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'], atol=TOLERANCE[dtype]['atol'],
) )
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import math import math
import os import os
import platform
import tempfile import tempfile
import time import time
import unittest import unittest
...@@ -442,8 +441,6 @@ class TestResnet(unittest.TestCase): ...@@ -442,8 +441,6 @@ class TestResnet(unittest.TestCase):
) )
def test_resnet_composite_forward_backward(self): def test_resnet_composite_forward_backward(self):
plat = platform.system()
if plat == "Linux":
core._set_prim_all_enabled(True) core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True) static_loss = self.train(to_static=True)
core._set_prim_all_enabled(False) core._set_prim_all_enabled(False)
...@@ -456,8 +453,6 @@ class TestResnet(unittest.TestCase): ...@@ -456,8 +453,6 @@ class TestResnet(unittest.TestCase):
static_loss, dygraph_loss static_loss, dygraph_loss
), ),
) )
else:
pass
def test_in_static_mode_mkldnn(self): def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True}) fluid.set_flags({'FLAGS_use_mkldnn': True})
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import time import time
import unittest import unittest
...@@ -63,9 +62,7 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -63,9 +62,7 @@ def train(to_static, enable_prim, enable_cinn):
np.random.seed(SEED) np.random.seed(SEED)
paddle.seed(SEED) paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED) paddle.framework.random._manual_program_seed(SEED)
fluid.core._set_prim_all_enabled( fluid.core._set_prim_all_enabled(enable_prim)
enable_prim and platform.system() == 'Linux'
)
train_reader = paddle.batch( train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import os import os
import platform
import unittest import unittest
import paddle import paddle
...@@ -95,78 +94,48 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -95,78 +94,48 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True" os.environ["FLAGS_prim_all"] = "True"
self.flag = "cinn_prim_all" self.flag = "cinn_prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True) _ = self.train(use_cinn=True)
else:
pass
def test_prim_all(self): def test_prim_all(self):
"""prim forward + prim backward""" """prim forward + prim backward"""
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True" os.environ["FLAGS_prim_all"] = "True"
self.flag = "prim_all" self.flag = "prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False) _ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_forward(self): def test_cinn_prim_forward(self):
"""cinn + prim forward""" """cinn + prim forward"""
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True" os.environ["FLAGS_prim_forward"] = "True"
self.flag = "cinn_prim_forward" self.flag = "cinn_prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True) _ = self.train(use_cinn=True)
else:
pass
def test_prim_forward(self): def test_prim_forward(self):
"""only prim forward""" """only prim forward"""
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True" os.environ["FLAGS_prim_forward"] = "True"
self.flag = "prim_forward" self.flag = "prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False) _ = self.train(use_cinn=False)
else:
pass
def test_cinn_prim_backward(self): def test_cinn_prim_backward(self):
"""cinn + prim_backward""" """cinn + prim_backward"""
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True" os.environ["FLAGS_prim_backward"] = "True"
self.flag = "cinn_prim_backward" self.flag = "cinn_prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True) _ = self.train(use_cinn=True)
else:
pass
def test_prim_backward(self): def test_prim_backward(self):
"""only prim backward""" """only prim backward"""
self.reset_env_flag() self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True" os.environ["FLAGS_prim_backward"] = "True"
self.flag = "prim_backward" self.flag = "prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False) _ = self.train(use_cinn=False)
else:
pass
def test_cinn(self): def test_cinn(self):
"""only cinn""" """only cinn"""
self.reset_env_flag() self.reset_env_flag()
self.flag = "cinn" self.flag = "cinn"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True) _ = self.train(use_cinn=True)
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册