Unverified · Commit 94c17a0f authored by caozhou, committed by GitHub

[Auto Parallel] Add mul dist op cost (#44973)

* add mul dist op cost

* add mul unittest
Parent 2c77b575
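
This commit adds a cost model for the distributed `mul` operator in auto-parallel, plus unit tests that exercise it. The sketch below summarizes the cost-query pattern the new tests rely on; the import paths and the `estimate_dist_op_costs` helper name are illustrative assumptions, and only the individual calls (`Cluster`, `gen_default_config_cluster`, `get_distributed_operator_impl_container`, `calc_cost`) are taken from the diff itself.

# Minimal sketch of the cost-query pattern used by the tests below.
# Import paths and the helper name are assumptions for illustration;
# only the call pattern comes from the diff.
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container, is_elementwise_op)


def estimate_dist_op_costs(main_program, dist_context, device_count=2):
    """Hypothetical helper: collect a cost object for every op in the program."""
    cluster = Cluster()
    cluster.gen_default_config_cluster(device_count=device_count)

    costs = []
    for op in main_program.global_block().ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        op_dist_attr = dist_op.dist_attr
        # Elementwise ops share one impl container; other ops (including the
        # newly covered mul) are looked up by the impl_type recorded in their
        # distributed attributes.
        if is_elementwise_op(op.type):
            container = get_distributed_operator_impl_container("elementwise")
        else:
            container = get_distributed_operator_impl_container(
                op_dist_attr.impl_type)
        dist_impl = container.impls[op_dist_attr.impl_idx]
        costs.append(
            dist_impl.calc_cost(op.attr('op_role'), dist_op, dist_context,
                                cluster))
    return costs

A missing or falsy cost returned by `calc_cost` is exactly what the `assertTrue` checks in the tests below would catch.
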
@@ -215,6 +215,215 @@ class TestDistOpCost(unittest.TestCase):
                                               dist_context, cluster)
            self.assertTrue(dist_op_cost)
    def test_dist_op_cost_part3(self):

        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(name="label",
                                           shape=[8, 1],
                                           dtype='float32')
                label.stop_gradient = True
                auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0]
                                  })
                auto.shard_tensor(label,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0, -1]
                                  })

                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32')
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)

                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(W,
                                          dist_attr={
                                              "process_mesh":
                                              auto.ProcessMesh([0, 1]),
                                              "dims_mapping": [0, -1]
                                          })

                out = paddle.fluid.layers.transpose(out,
                                                    [1, 0])  # [8, 2] [-1, 0]

                # matmul_v2
                param1 = paddle.fluid.layers.create_parameter(
                    [4, 8], paddle.float32)  # [2, 8] [0, -1]
                auto.shard_tensor(param1,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0, -1]
                                  })
                param2 = paddle.fluid.layers.create_parameter(
                    [8, 8], paddle.float32)  # [8, 4] [-1, 0]
                auto.shard_tensor(param2,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [-1, 0]
                                  })
                out1 = paddle.matmul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.fluid.layers.create_parameter(
                    [8, 8], paddle.float32)  # [8, 8] [-1, -1]
                auto.shard_tensor(param2,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [-1, -1]
                                  })
                tmp_out = paddle.matmul(out1, tmp_param)
                out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]
                out8 = paddle.fluid.layers.transpose(out2,
                                                     [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(tmp_reshape_out,
                                       [8, 8])  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label)
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.processes
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise")
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type)

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
                                               dist_context, cluster)
            self.assertTrue(dist_op_cost)
    def test_dist_op_cost_part4(self):

        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(name="label",
                                           shape=[8, 1],
                                           dtype='float32')
                label.stop_gradient = True
                auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0]
                                  })
                auto.shard_tensor(label,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0, -1]
                                  })

                # embedding
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[4], value=1, dtype='int32')
                embedding = paddle.nn.Embedding(10, 8)
                out = embedding(tmp)

                # row parallel embedding
                for op in main_program.global_block().ops:
                    if op.type == "lookup_table_v2":
                        W = main_program.global_block().vars[op.input("W")[0]]
                        auto.shard_tensor(W,
                                          dist_attr={
                                              "process_mesh":
                                              auto.ProcessMesh([0, 1]),
                                              "dims_mapping": [0, -1]
                                          })

                out = paddle.fluid.layers.transpose(out,
                                                    [1, 0])  # [8, 2] [-1, 0]

                # mul
                param1 = paddle.fluid.layers.create_parameter(
                    [4, 8], paddle.float32)  # [2, 8] [0, -1]
                auto.shard_tensor(param1,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0, -1]
                                  })
                param2 = paddle.fluid.layers.create_parameter(
                    [8, 8], paddle.float32)  # [8, 4] [-1, 0]
                auto.shard_tensor(param2,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [-1, 0]
                                  })
                out1 = paddle.fluid.layers.mul(out, param1)  # [8, 8] [-1, -1]
                tmp_param = paddle.fluid.layers.create_parameter(
                    [8, 8], paddle.float32)  # [8, 8] [-1, -1]
                auto.shard_tensor(param2,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [-1, -1]
                                  })
                tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
                out2 = paddle.fluid.layers.mul(tmp_out,
                                               param2)  # [8, 4] [-1, 0]
                out8 = paddle.fluid.layers.transpose(out2,
                                                     [1, 0])  # [4, 8] [0, -1]

                # reshape
                out9 = paddle.reshape(out8, [8, 2, 4])  # [4, 2, 4] [0, -1, -1]
                tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
                out10 = paddle.reshape(tmp_reshape_out,
                                       [8, 8])  # [4, 8] [0, -1]

                # softmax
                softmax = paddle.nn.Softmax()
                out11 = softmax(out10)
                error_cost = paddle.nn.functional.square_error_cost(
                    out11, label)
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            dist_op = dist_context.get_dist_op_for_program(op)
            op_dist_attr = dist_op.dist_attr
            processes = op_dist_attr.process_mesh.processes
            if is_elementwise_op(op.type):
                container = get_distributed_operator_impl_container(
                    "elementwise")
            else:
                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type)

            dist_impl = container.impls[op_dist_attr.impl_idx]
            dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
                                               dist_context, cluster)
            self.assertTrue(dist_op_cost)


if __name__ == "__main__":
    unittest.main()
@@ -76,6 +76,14 @@ class MLPLayer(nn.Layer):
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        param = paddle.fluid.layers.create_parameter([1024, 4096],
                                                     paddle.float32)
        auto.shard_tensor(param,
                          dist_attr={
                              "process_mesh": PP_MESH_1,
                              "dims_mapping": [-1, 1]
                          })
        out = paddle.fluid.layers.mul(out, param)
        return out
@@ -93,6 +93,14 @@ class MLPLayer(nn.Layer):
                          })
        w_out = self.word_embeddings(input)
        out = self.linear0(w_out)
        param = paddle.fluid.layers.create_parameter([4096, 4096],
                                                     paddle.float32)
        auto.shard_tensor(param,
                          dist_attr={
                              "process_mesh": PP_MESH_0,
                              "dims_mapping": [0, -1]
                          })
        out = paddle.fluid.layers.mul(out, param)
        gelu_out = F.gelu(out, approximate=True)
        out = self.linear1(gelu_out)
        out1 = self.linear2(gelu_out)
@@ -228,7 +236,7 @@ class TestMLPReshard(unittest.TestCase):
        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                              dist_context, dist_params_grads)
        resharder.reshard()
        print_program_with_dist_attr(dist_main_prog, dist_context)
        # check send and recv result
        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))