未验证 提交 7fc9f433 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim]check prim forward in windows and mac (#50527)

* check win

* fix random error

* fix manage
上级 8c0d957c
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -134,17 +133,13 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('softmax' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
)
else:
pass
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
)
if __name__ == '__main__':
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
......@@ -98,31 +97,23 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('gelu' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for shape in self.shapes:
for dtype in self.dtypes:
if (
paddle.device.get_device() == "cpu"
and dtype == "float16"
):
print("need pass this case")
continue
data = generate_data(shape, dtype)
data_t = paddle.to_tensor(data)
data_t.stop_gradient = False
dy_res = self.train(use_prim=False, data=data_t)
cinn_res = self.train(use_prim=True, data=data_t)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i],
dy_res[i],
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
for shape in self.shapes:
for dtype in self.dtypes:
if paddle.device.get_device() == "cpu" and dtype == "float16":
print("need pass this case")
continue
data = generate_data(shape, dtype)
data_t = paddle.to_tensor(data)
data_t.stop_gradient = False
dy_res = self.train(use_prim=False, data=data_t)
cinn_res = self.train(use_prim=True, data=data_t)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i],
dy_res[i],
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
if __name__ == '__main__':
......
......@@ -10,11 +10,6 @@ file(
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}")
if(WIN32 OR APPLE)
# TODO: Fix these unittests failed on Windows and MAC.
list(REMOVE_ITEM TEST_OPS ${TEST_OPS_GRAD})
endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
......
......@@ -34,7 +34,7 @@ def softmax_composite(x, axis):
"""define composite rule of op softmax"""
if not x.shape:
# do not return 1, to ensure gradients
res = divide(x + 1e-5, x + 1e-5)
res = exp(x - x)
return res
max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True
......
......@@ -83,6 +83,7 @@ API_FILES=("CMakeLists.txt"
"paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
"paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
"python/paddle/incubate/autograd/composite_rules.py"
"python/paddle/incubate/autograd/primitives.py"
)
approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000`
......@@ -210,7 +211,7 @@ for API_FILE in ${API_FILES[*]}; do
check_approval 1 JiabinYang cxxly xiaoguoguo626807
elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then
echo_line="You must have one RD (cyber-pioneer(chenzhuo), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n"
check_approval cyber-pioneer JiabinYang
check_approval 1 cyber-pioneer JiabinYang
else
echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册