未验证 提交 881e55e4 编写于 作者: L lilong12 提交者: GitHub

add checker, test=develop (#35109)

上级 5b737834
...@@ -338,6 +338,15 @@ def set_shard_mask(x, mask): ...@@ -338,6 +338,15 @@ def set_shard_mask(x, mask):
""" """
_static_mode_check() _static_mode_check()
assert isinstance(mask, list) assert isinstance(mask, list)
np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask)
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask') attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask)) x._set_attr(attr_name, _flatten_nested_list(mask))
return x return x
...@@ -425,6 +434,7 @@ def set_offload_device(x, device): ...@@ -425,6 +434,7 @@ def set_offload_device(x, device):
""" """
_static_mode_check() _static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device") attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device) x._set_attr(attr_name, device)
return x return x
......
...@@ -37,7 +37,7 @@ def _append_attr_suffix(name): ...@@ -37,7 +37,7 @@ def _append_attr_suffix(name):
LAST_PP_STAGE = 3 LAST_PP_STAGE = 3
MASK = [[0, 1], [1, 0], [1, 1]] MASK = [[0, 1, 1], [0, 1, 1]]
MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]]) MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
...@@ -58,7 +58,7 @@ class SimpleNet(nn.Layer): ...@@ -58,7 +58,7 @@ class SimpleNet(nn.Layer):
dist.set_pipeline_stage(LAST_PP_STAGE) dist.set_pipeline_stage(LAST_PP_STAGE)
y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1]) y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1])
dist.set_offload_device(y, "gpu:3") dist.set_offload_device(y, "cpu")
linear1 = self.dense1(y) linear1 = self.dense1(y)
out = self.dense2(linear1) out = self.dense2(linear1)
...@@ -86,9 +86,9 @@ class TestAutoParallelAPI(unittest.TestCase): ...@@ -86,9 +86,9 @@ class TestAutoParallelAPI(unittest.TestCase):
x._get_attr(shard_mask_attr), _flatten_nested_list(MASK)) x._get_attr(shard_mask_attr), _flatten_nested_list(MASK))
self.assertEqual(x.shard_mask, _flatten_nested_list(MASK)) self.assertEqual(x.shard_mask, _flatten_nested_list(MASK))
offload_attr = _append_attr_suffix('offload_device') offload_attr = _append_attr_suffix('offload_device')
self.assertEqual(y._get_attr(offload_attr), "gpu:3") self.assertEqual(y._get_attr(offload_attr), "cpu")
self.assertEqual(y.desc.has_attr(offload_attr), True) self.assertEqual(y.desc.has_attr(offload_attr), True)
self.assertEqual(y.offload_device, "gpu:3") self.assertEqual(y.offload_device, "cpu")
y._remove_attr(offload_attr) y._remove_attr(offload_attr)
self.assertEqual(y._has_attr(offload_attr), False) self.assertEqual(y._has_attr(offload_attr), False)
ops = paddle.static.default_main_program().block(0).ops ops = paddle.static.default_main_program().block(0).ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册