未验证 提交 7fc9f433 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim]check prim forward in windows and mac (#50527)

* check win

* fix random error

* fix manage
上级 8c0d957c
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -134,17 +133,13 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -134,17 +133,13 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('softmax' not in fwd_ops) self.assertTrue('softmax' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system() dy_res = self.train(use_prim=False)
if plat == "Linux": cinn_res = self.train(use_prim=True)
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True) for i in range(len(dy_res)):
np.testing.assert_allclose(
for i in range(len(dy_res)): cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
np.testing.assert_allclose( )
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
)
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import unittest import unittest
import numpy as np import numpy as np
...@@ -98,31 +97,23 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -98,31 +97,23 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.assertTrue('gelu' not in fwd_ops) self.assertTrue('gelu' not in fwd_ops)
def test_cinn_prim(self): def test_cinn_prim(self):
plat = platform.system() for shape in self.shapes:
if plat == "Linux": for dtype in self.dtypes:
for shape in self.shapes: if paddle.device.get_device() == "cpu" and dtype == "float16":
for dtype in self.dtypes: print("need pass this case")
if ( continue
paddle.device.get_device() == "cpu" data = generate_data(shape, dtype)
and dtype == "float16" data_t = paddle.to_tensor(data)
): data_t.stop_gradient = False
print("need pass this case") dy_res = self.train(use_prim=False, data=data_t)
continue cinn_res = self.train(use_prim=True, data=data_t)
data = generate_data(shape, dtype) for i in range(len(dy_res)):
data_t = paddle.to_tensor(data) np.testing.assert_allclose(
data_t.stop_gradient = False cinn_res[i],
dy_res = self.train(use_prim=False, data=data_t) dy_res[i],
cinn_res = self.train(use_prim=True, data=data_t) rtol=TOLERANCE[dtype]['rtol'],
for i in range(len(dy_res)): atol=TOLERANCE[dtype]['atol'],
np.testing.assert_allclose( )
cinn_res[i],
dy_res[i],
rtol=TOLERANCE[dtype]['rtol'],
atol=TOLERANCE[dtype]['atol'],
)
else:
pass
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -10,11 +10,6 @@ file( ...@@ -10,11 +10,6 @@ file(
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}") string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}")
if(WIN32 OR APPLE)
# TODO: Fix these unittests failed on Windows and MAC.
list(REMOVE_ITEM TEST_OPS ${TEST_OPS_GRAD})
endif()
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach() endforeach()
......
...@@ -34,7 +34,7 @@ def softmax_composite(x, axis): ...@@ -34,7 +34,7 @@ def softmax_composite(x, axis):
"""define composite rule of op softmax""" """define composite rule of op softmax"""
if not x.shape: if not x.shape:
# do not return 1, to ensure gradients # do not return 1, to ensure gradients
res = divide(x + 1e-5, x + 1e-5) res = exp(x - x)
return res return res
max_temp = max(x, axis, keepdim=True) max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True max_temp.stop_gradient = True
......
...@@ -83,6 +83,7 @@ API_FILES=("CMakeLists.txt" ...@@ -83,6 +83,7 @@ API_FILES=("CMakeLists.txt"
"paddle/fluid/prim/api/composite_backward/composite_backward_api.h" "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
"paddle/fluid/prim/api/manual_prim/prim_manual_api.h" "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
"python/paddle/incubate/autograd/composite_rules.py" "python/paddle/incubate/autograd/composite_rules.py"
"python/paddle/incubate/autograd/primitives.py"
) )
approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000` approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000`
...@@ -210,7 +211,7 @@ for API_FILE in ${API_FILES[*]}; do ...@@ -210,7 +211,7 @@ for API_FILE in ${API_FILES[*]}; do
check_approval 1 JiabinYang cxxly xiaoguoguo626807 check_approval 1 JiabinYang cxxly xiaoguoguo626807
elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then
echo_line="You must have one RD (cyber-pioneer(chenzhuo), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n" echo_line="You must have one RD (cyber-pioneer(chenzhuo), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n"
check_approval cyber-pioneer JiabinYang check_approval 1 cyber-pioneer JiabinYang
else else
echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n" echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册