diff --git a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt index 364f17c2e0d0aeda8a1ab4a33d1f56a61fe5f966..5a58bd25d27e73809b5c65dc74d3b9ea0a911389 100644 --- a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt @@ -3,6 +3,7 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp") list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_amp") +list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_sharding") foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) @@ -12,3 +13,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_with_asp MODULES test_fleet_with_asp ENVS ${dist_ENVS}) py_test_modules(test_fleet_with_asp_amp MODULES test_fleet_with_asp_amp ENVS ${dist_ENVS}) endif() + +if((WITH_DISTRIBUTE) AND (NOT WIN32) AND (NOT APPLE)) + py_test_modules(test_fleet_with_asp_sharding MODULES test_fleet_with_asp_sharding ENVS ${dist_ENVS}) +endif() diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py index dd609d3ae2e1136a632083dfc229550aade33547..26170015ae8c249fb3a36d13285f5b34491acb3a 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_sharding.py @@ -20,7 +20,6 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import os -import sys from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np @@ -78,9 +77,6 @@ class TestFleetWithASPSharding(unittest.TestCase): return avg_cost, dist_strategy, input_x, input_y def test_with_asp_sharding(self): - if sys.platform == 'win32': - return - print(sys.platform) fleet.init(is_collective=True) train_prog, startup_prog = fluid.Program(), fluid.Program() avg_cost, strategy, input_x, input_y = self.net(train_prog,