未验证 提交 7fc9f433 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim]check prim forward in windows and mac (#50527)

* check win

* fix random error

* fix manage
上级 8c0d957c
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -134,8 +133,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -134,8 +133,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('softmax' not in fwd_ops) self.assertTrue('softmax' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
dy_res = self.train(use_prim=False) dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True) cinn_res = self.train(use_prim=True)
...@@ -143,8 +140,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -143,8 +140,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6 cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
) )
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -98,14 +97,9 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -98,14 +97,9 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('gelu' not in fwd_ops) self.assertTrue('gelu' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
for shape in self.shapes: for shape in self.shapes:
for dtype in self.dtypes: for dtype in self.dtypes:
if ( if paddle.device.get_device() == "cpu" and dtype == "float16":
paddle.device.get_device() == "cpu"
and dtype == "float16"
):
print("need pass this case") print("need pass this case")
continue continue
data = generate_data(shape, dtype) data = generate_data(shape, dtype)
...@@ -121,9 +115,6 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -121,9 +115,6 @@ class TestPrimForwardAndBackward(unittest.TestCase):
atol=TOLERANCE[dtype]['atol'], atol=TOLERANCE[dtype]['atol'],
) )
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -10,11 +10,6 @@ file( ...@@ -10,11 +10,6 @@ file(
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}") string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}")
if(WIN32 OR APPLE)
# TODO: Fix these unittests failed on Windows and MAC.
list(REMOVE_ITEM TEST_OPS ${TEST_OPS_GRAD})
endif()
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach() endforeach()
......
...@@ -34,7 +34,7 @@ def softmax_composite(x, axis): ...@@ -34,7 +34,7 @@ def softmax_composite(x, axis):
"""define composite rule of op softmax""" """define composite rule of op softmax"""
if not x.shape: if not x.shape:
# do not return 1, to ensure gradients # do not return 1, to ensure gradients
res = divide(x + 1e-5, x + 1e-5) res = exp(x - x)
return res return res
max_temp = max(x, axis, keepdim=True) max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True max_temp.stop_gradient = True
......
...@@ -83,6 +83,7 @@ API_FILES=("CMakeLists.txt" ...@@ -83,6 +83,7 @@ API_FILES=("CMakeLists.txt"
"paddle/fluid/prim/api/composite_backward/composite_backward_api.h" "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
"paddle/fluid/prim/api/manual_prim/prim_manual_api.h" "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
"python/paddle/incubate/autograd/composite_rules.py" "python/paddle/incubate/autograd/composite_rules.py"
"python/paddle/incubate/autograd/primitives.py"
) )
approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000` approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000`
...@@ -210,7 +211,7 @@ for API_FILE in ${API_FILES[*]}; do ...@@ -210,7 +211,7 @@ for API_FILE in ${API_FILES[*]}; do
check_approval 1 JiabinYang cxxly xiaoguoguo626807 check_approval 1 JiabinYang cxxly xiaoguoguo626807
elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then
echo_line="You must have one RD (cyber-pioneer(chenzhuo), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n" echo_line="You must have one RD (cyber-pioneer(chenzhuo), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n"
check_approval cyber-pioneer JiabinYang check_approval 1 cyber-pioneer JiabinYang
else else
echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n" echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册