未验证 提交 ca364e14 编写于 作者: C cyber-pioneer 提交者: GitHub

rm restict of platform (#51806)

上级 c985b1ac
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -171,28 +170,24 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
print("need pass this case")
continue
x_n, w_n, b_n = generate_data(dtype)
self.x = paddle.to_tensor(x_n)
self.w = paddle.to_tensor(w_n)
self.b = paddle.to_tensor(b_n)
self.x.stop_gradient = False
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu":
print("need pass this case")
continue
x_n, w_n, b_n = generate_data(dtype)
self.x = paddle.to_tensor(x_n)
self.w = paddle.to_tensor(w_n)
self.b = paddle.to_tensor(b_n)
self.x.stop_gradient = False
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
if __name__ == '__main__':
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -185,31 +184,24 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('reduce_mean' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for shape in self.shapes:
for dtype in self.dtypes:
# mean-kernel on cpu not support float16
if (
paddle.device.get_device() == "cpu"
and dtype == "float16"
):
print("need pass this case")
continue
data = generate_data(shape, dtype)
data_t = paddle.to_tensor(data)
data_t.stop_gradient = False
dy_res = self.train(use_prim=False, data=data_t)
cinn_res = self.train(use_prim=True, data=data_t)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
for shape in self.shapes:
for dtype in self.dtypes:
# mean-kernel on cpu not support float16
if paddle.device.get_device() == "cpu" and dtype == "float16":
print("need pass this case")
continue
data = generate_data(shape, dtype)
data_t = paddle.to_tensor(data)
data_t.stop_gradient = False
dy_res = self.train(use_prim=False, data=data_t)
cinn_res = self.train(use_prim=True, data=data_t)
np.testing.assert_allclose(
cinn_res,
dy_res,
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
if __name__ == '__main__':
......
......@@ -14,7 +14,6 @@
import math
import os
import platform
import tempfile
import time
import unittest
......@@ -442,22 +441,18 @@ class TestResnet(unittest.TestCase):
)
def test_resnet_composite_forward_backward(self):
plat = platform.system()
if plat == "Linux":
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
core._set_prim_all_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-02,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
else:
pass
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
core._set_prim_all_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-02,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import time
import unittest
......@@ -63,9 +62,7 @@ def train(to_static, enable_prim, enable_cinn):
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
fluid.core._set_prim_all_enabled(
enable_prim and platform.system() == 'Linux'
)
fluid.core._set_prim_all_enabled(enable_prim)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import platform
import unittest
import paddle
......@@ -95,78 +94,48 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "cinn_prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
_ = self.train(use_cinn=True)
def test_prim_all(self):
"""prim forward + prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_all"] = "True"
self.flag = "prim_all"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
_ = self.train(use_cinn=False)
def test_cinn_prim_forward(self):
"""cinn + prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "cinn_prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
_ = self.train(use_cinn=True)
def test_prim_forward(self):
"""only prim forward"""
self.reset_env_flag()
os.environ["FLAGS_prim_forward"] = "True"
self.flag = "prim_forward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
_ = self.train(use_cinn=False)
def test_cinn_prim_backward(self):
"""cinn + prim_backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "cinn_prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
_ = self.train(use_cinn=True)
def test_prim_backward(self):
"""only prim backward"""
self.reset_env_flag()
os.environ["FLAGS_prim_backward"] = "True"
self.flag = "prim_backward"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=False)
else:
pass
_ = self.train(use_cinn=False)
def test_cinn(self):
"""only cinn"""
self.reset_env_flag()
self.flag = "cinn"
plat = platform.system()
if plat == "Linux":
_ = self.train(use_cinn=True)
else:
pass
_ = self.train(use_cinn=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册