未验证 提交 881e55e4 编写于 作者: L lilong12 提交者: GitHub

add checker, test=develop (#35109)

上级 5b737834
......@@ -338,6 +338,15 @@ def set_shard_mask(x, mask):
"""
_static_mode_check()
assert isinstance(mask, list)
np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask)
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
return x
......@@ -425,6 +434,7 @@ def set_offload_device(x, device):
"""
_static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
......
......@@ -37,7 +37,7 @@ def _append_attr_suffix(name):
LAST_PP_STAGE = 3
MASK = [[0, 1], [1, 0], [1, 1]]
MASK = [[0, 1, 1], [0, 1, 1]]
MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
......@@ -58,7 +58,7 @@ class SimpleNet(nn.Layer):
dist.set_pipeline_stage(LAST_PP_STAGE)
y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1])
dist.set_offload_device(y, "gpu:3")
dist.set_offload_device(y, "cpu")
linear1 = self.dense1(y)
out = self.dense2(linear1)
......@@ -86,9 +86,9 @@ class TestAutoParallelAPI(unittest.TestCase):
x._get_attr(shard_mask_attr), _flatten_nested_list(MASK))
self.assertEqual(x.shard_mask, _flatten_nested_list(MASK))
offload_attr = _append_attr_suffix('offload_device')
self.assertEqual(y._get_attr(offload_attr), "gpu:3")
self.assertEqual(y._get_attr(offload_attr), "cpu")
self.assertEqual(y.desc.has_attr(offload_attr), True)
self.assertEqual(y.offload_device, "gpu:3")
self.assertEqual(y.offload_device, "cpu")
y._remove_attr(offload_attr)
self.assertEqual(y._has_attr(offload_attr), False)
ops = paddle.static.default_main_program().block(0).ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册