Unverified commit 30b66f03, authored by zhaoyingli, committed by GitHub

fix conflict (#44891)

Parent 247002ec
@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys

class DistributedOperator:
+
    def __init__(self, serial_op, dist_attr=None):
        self._serial_op = serial_op
        self._serial_inputs = {}

@@ -248,6 +249,7 @@ class DistributedOperator:

class DistributedModule:
+
    def __init__(self, serial_module, dist_attr=None):
        self._serial_module = serial_module
        self._dist_attr = dist_attr

@@ -265,6 +267,4 @@ class DistributedModule:
        dist_op = DistributedOperator(op, self._dist_attr)
        dist_op.dist_attr.mark_annotated_as(self._dist_attr)
        default_dist_ctx.add_dist_op_for_program(dist_op)
-        if isinstance(output, Variable):
-            output = [output]
-        return list(output)
+        return output
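The behavioral part of this hunk is that `DistributedModule.__call__` now passes the wrapped layer's return value through unchanged instead of always packing it into a list, which is why the call sites in the tests below drop their trailing `[0]`. A minimal sketch of the calling convention before and after, with an illustrative layer `layer`, input `x`, and process mesh `mesh` (names not taken from this diff):

# before this commit: auto.shard_op(...) handed back a list, so callers indexed it
# out = auto.shard_op(layer, dist_attr={"process_mesh": mesh})(x)[0]
# after this commit: the wrapped layer's own return value comes back directly
# out = auto.shard_op(layer, dist_attr={"process_mesh": mesh})(x)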
@@ -47,6 +47,7 @@ paddle.seed(44)

class MyDataset(Dataset):
+
    def __init__(self, num_samples):
        super(MyDataset, self).__init__()
        self.num_samples = num_samples

@@ -61,6 +62,7 @@ class MyDataset(Dataset):

class MLPLayer(nn.Layer):
+
    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,

@@ -69,43 +71,45 @@ class MLPLayer(nn.Layer):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
-        out = auto.shard_op(
-            self.norm, dist_attr={"process_mesh": PP_MESH_0})(input)[0]
-        out = self.linear0(input)
+        out = auto.shard_op(self.norm,
+                            dist_attr={"process_mesh": PP_MESH_0})(input)
+        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
-        out = auto.shard_op(
-            self.linear1, dist_attr={"process_mesh": PP_MESH_1})(out)[0]
+        out = auto.shard_op(self.linear1,
+                            dist_attr={"process_mesh": PP_MESH_1})(out)
        out = self.dropout(out)
        out = self.linear2(out)
        return out


def train():
    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   dropout_ratio=0.1,
                   initializer_range=0.02)
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                     beta1=0.9,
                                                     beta2=0.999,
                                                     epsilon=1e-08,
                                                     grad_clip=None)
    dataset = MyDataset(batch_num * batch_size)
    inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')

@@ -119,11 +123,10 @@ def train():
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)
    engine = Engine(mlp,
                    inputs_spec=inputs_spec,
                    labels_spec=labels_spec,
                    strategy=dist_strategy)
    engine.prepare(optimizer, loss)
    engine.fit(dataset,
               batch_size=batch_size,
......
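For readers following the annotations in these tests: in this dist_attr-based API, "dims_mapping" has one entry per tensor axis, naming the process-mesh axis that shards it, with -1 meaning the axis is replicated. A minimal, self-contained sketch of the same annotation pattern (the shapes and the two-rank mesh are illustrative, not taken from the tests):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
with static.program_guard(static.Program(), static.Program()):
    x = static.data(name="x", shape=[8, 1024], dtype='float32')
    # Shard the second axis of x along mesh axis 0; keep the first axis replicated.
    auto.shard_tensor(x,
                      dist_attr={
                          "process_mesh": auto.ProcessMesh([0, 1]),
                          "dims_mapping": [-1, 0]
                      })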
@@ -76,26 +76,27 @@ class MultiHeadAttention(nn.Layer):
        if self.fuse:
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            self.qkv_proj = nn.Linear(embed_dim,
                                      3 * embed_dim,
                                      weight_attr,
                                      bias_attr=bias_attr)
        else:
            self.q_proj = nn.Linear(embed_dim,
                                    embed_dim,
                                    weight_attr=weight_attr,
                                    bias_attr=bias_attr)
            self.k_proj = nn.Linear(self.kdim,
                                    embed_dim,
                                    weight_attr=weight_attr,
                                    bias_attr=bias_attr)
            self.v_proj = nn.Linear(self.vdim,
                                    embed_dim,
                                    weight_attr=weight_attr,
                                    bias_attr=bias_attr)
            self.out_proj = nn.Linear(embed_dim,
                                      embed_dim,
                                      weight_attr=weight_attr,
                                      bias_attr=bias_attr)

    def _fuse_prepare_qkv(self, query):
        mix_layer = self.qkv_proj(query)
@@ -113,33 +114,30 @@ class MultiHeadAttention(nn.Layer):
        """
        q = self.q_proj(query)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 1]
                              })
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        if isinstance(cache, self.StaticCache):
@@ -167,62 +165,56 @@ class MultiHeadAttention(nn.Layer):
        """
        k = self.k_proj(key)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 1]
                              })
        v = self.v_proj(value)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 1]
                              })
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
@@ -276,17 +268,18 @@ class MultiHeadAttention(nn.Layer):
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                               cache)
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)
        if attn_mask is not None:
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(weights,
                                self.dropout,
                                training=self.training,
                                mode="upscale_in_train")
        out = tensor.matmul(weights, v)
        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
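The matmul/softmax/dropout sequence above is ordinary scaled dot-product attention; a small standalone NumPy sketch of the same computation (shapes and names are illustrative):

import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, heads, seq_len, head_dim]
    head_dim = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) * head_dim**-0.5
    if mask is not None:
        scores = scores + mask
    # softmax over the last axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v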
@@ -294,33 +287,30 @@ class MultiHeadAttention(nn.Layer):
        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [1, -1]
                              })
        outs = [out]
        if self.need_weights:
            outs.append(weights)
@@ -362,36 +352,37 @@ class TransformerDecoder(nn.Layer):
        new_caches = []
        self.checkpoints = []
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(output,
                              dist_attr={
                                  "process_mesh": PP_MESH_LIST[0],
                                  "dims_mapping":
                                  [-1 for i in range(len(output.shape))]
                              })
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(output,
                              dist_attr={
                                  "process_mesh": DPPP_MESH_LIST[0],
                                  "dims_mapping": [0] +
                                  [-1 for i in range(len(output.shape) - 1)]
                              })
        if _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(output,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[0],
                                  "dims_mapping": [-1] +
                                  [-1 for i in range(len(output.shape) - 1)]
                              })
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(output,
                              dist_attr={
                                  "process_mesh": DPMPPP_MESH_LIST[0],
                                  "dims_mapping": [0] +
                                  [-1 for i in range(len(output.shape) - 1)]
                              })
        for i, mod in enumerate(self.layers):
            if cache is None:
                if use_cache:
@@ -400,11 +391,12 @@ class TransformerDecoder(nn.Layer):
                            mod,
                            dist_attr={
                                "process_mesh": PP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                PP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping":
                                [-1 for i in range(len(output.shape))]
                            })

@@ -413,11 +405,12 @@ class TransformerDecoder(nn.Layer):
                            mod,
                            dist_attr={
                                "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                DPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [0] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })

@@ -426,11 +419,12 @@ class TransformerDecoder(nn.Layer):
                            mod,
                            dist_attr={
                                "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                MPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [-1] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })

@@ -439,11 +433,12 @@ class TransformerDecoder(nn.Layer):
                            mod,
                            dist_attr={
                                "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                DPMPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [0] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })
@@ -456,41 +451,47 @@ class TransformerDecoder(nn.Layer):
                        new_caches.append(new_cache)
                else:
                    if _global_parallel_strategy == "pp":
-                        output = auto.shard_op(
-                            mod,
-                            dist_attr={
-                                "process_mesh": PP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                        output = auto.shard_op(mod,
+                                               dist_attr={
+                                                   "process_mesh":
+                                                   PP_MESH_LIST[mod.mesh_idx]
+                                               })(output, memory, tgt_mask,
+                                                  use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                PP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping":
                                [-1 for i in range(len(output.shape))]
                            })
                    elif _global_parallel_strategy == "dp_pp":
-                        output = auto.shard_op(
-                            mod,
-                            dist_attr={
-                                "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                        output = auto.shard_op(mod,
+                                               dist_attr={
+                                                   "process_mesh":
+                                                   DPPP_MESH_LIST[mod.mesh_idx]
+                                               })(output, memory, tgt_mask,
+                                                  use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                DPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [0] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })
                    elif _global_parallel_strategy == "mp_pp":
-                        output = auto.shard_op(
-                            mod,
-                            dist_attr={
-                                "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                        output = auto.shard_op(mod,
+                                               dist_attr={
+                                                   "process_mesh":
+                                                   MPPP_MESH_LIST[mod.mesh_idx]
+                                               })(output, memory, tgt_mask,
+                                                  use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                MPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [-1] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })

@@ -499,11 +500,12 @@ class TransformerDecoder(nn.Layer):
                            mod,
                            dist_attr={
                                "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                        auto.shard_tensor(
                            output,
                            dist_attr={
                                "process_mesh":
                                DPMPPP_MESH_LIST[mod.mesh_idx],
                                "dims_mapping": [0] +
                                [-1 for i in range(len(output.shape) - 1)]
                            })
@@ -517,8 +519,9 @@ class TransformerDecoder(nn.Layer):
                if _global_parallel_strategy == "pp":
                    output, new_cache = auto.shard_op(
                        mod,
                        dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]
                                   })(output, memory, tgt_mask, use_cache,
                                      cache)
                    auto.shard_tensor(
                        output,
                        dist_attr={

@@ -535,7 +538,8 @@ class TransformerDecoder(nn.Layer):
                    auto.shard_tensor(
                        output,
                        dist_attr={
                            "process_mesh":
                            DPPP_MESH_LIST[mod.mesh_idx],
                            "dims_mapping":
                            [0] + [-1 for i in range(len(output.shape) - 1)]
                        })

@@ -548,7 +552,8 @@ class TransformerDecoder(nn.Layer):
                    auto.shard_tensor(
                        output,
                        dist_attr={
                            "process_mesh":
                            MPPP_MESH_LIST[mod.mesh_idx],
                            "dims_mapping":
                            [-1] + [-1 for i in range(len(output.shape) - 1)]
                        })

@@ -561,7 +566,8 @@ class TransformerDecoder(nn.Layer):
                    auto.shard_tensor(
                        output,
                        dist_attr={
                            "process_mesh":
                            DPMPPP_MESH_LIST[mod.mesh_idx],
                            "dims_mapping":
                            [0] + [-1 for i in range(len(output.shape) - 1)]
                        })
@@ -619,17 +625,20 @@ class TransformerDecoderLayer(nn.Layer):
        self.normalize_before = normalize_before
        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
        self.self_attn = MultiHeadAttention(d_model,
                                            nhead,
                                            dropout=attn_dropout,
                                            weight_attr=weight_attrs[0],
                                            bias_attr=bias_attrs[0],
                                            mesh_idx=self.mesh_idx)
        self.linear1 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attrs[2],
                                 bias_attr=bias_attrs[2])
        self.linear2 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attrs[2],
                                 bias_attr=bias_attrs[2])
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
@@ -652,72 +661,65 @@ class TransformerDecoderLayer(nn.Layer):
        if self.normalize_before:
            tgt = self.norm2(tgt)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 0]
                              })
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [-1, 1]
                              })
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear2.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.linear2.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.linear2.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.linear2.weight,
                              dist_attr={
                                  "process_mesh":
                                  DPMPPP_MESH_LIST[self.mesh_idx],
                                  "dims_mapping": [1, -1]
                              })
        tgt = self.dropout2(
            self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
        tgt = residual + tgt
        if not self.normalize_before:
            tgt = self.norm2(tgt)
        return tgt if use_cache is False else (tgt, incremental_cache)

    def gen_cache(self, memory):
        incremental_cache = self.self_attn.gen_cache(memory,
                                                     type=self.self_attn.Cache)
        return incremental_cache
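The [-1, 0] mapping on linear1.weight and the [0, -1] mapping on linear2.weight above form the usual column-parallel / row-parallel FFN split: the first weight is sharded along its output axis and the second along its input axis, so the only collective needed is a sum after the second matmul. A small NumPy sketch of that identity (ReLU stands in for GELU, two shards, sizes illustrative):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
w1 = rng.standard_normal((8, 16))   # split along columns, like dims_mapping [-1, 0]
w2 = rng.standard_normal((16, 8))   # split along rows, like dims_mapping [0, -1]

full = np.maximum(x @ w1, 0) @ w2
# Each shard applies the elementwise nonlinearity locally; summing the partial
# results (an all-reduce in a real run) recovers the unsharded output.
parts = [np.maximum(x @ w1[:, i * 8:(i + 1) * 8], 0) @ w2[i * 8:(i + 1) * 8, :]
         for i in range(2)]
assert np.allclose(full, sum(parts))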
@@ -737,17 +739,15 @@ class GPTEmbeddings(nn.Layer):
        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(name="word_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0, std=initializer_range)))
        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
            weight_attr=paddle.ParamAttr(name="pos_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0, std=initializer_range)))
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None):

@@ -757,33 +757,29 @@ class GPTEmbeddings(nn.Layer):
            position_ids = seq_length - ones
        input_embedings = self.word_embeddings(input_ids)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })
        elif _global_parallel_strategy == "mp_pp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": MPPP_MESH_LIST[0],
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": DPMPPP_MESH_LIST[0],
                                  "dims_mapping": [1, -1]
                              })
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embedings + position_embeddings
        embeddings = self.dropout(embeddings)
@@ -821,9 +817,10 @@ class GPTModel(nn.Layer):
        self.pipline_mode = (pp_degree is not None and pp_degree > 1)
        if self.pipline_mode:
            self.layer_per_stage = num_hidden_layers // pp_degree
        self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
                                        hidden_dropout_prob,
                                        max_position_embeddings,
                                        type_vocab_size, self.initializer_range)
        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            mesh_index = None

@@ -831,25 +828,23 @@ class GPTModel(nn.Layer):
            if self.layer_per_stage is not None:
                mesh_index = i // self.layer_per_stage
            decoder_layers.append(
                DecoderLayer(d_model=hidden_size,
                             nhead=num_attention_heads,
                             dim_feedforward=intermediate_size,
                             dropout=hidden_dropout_prob,
                             activation=hidden_act,
                             attn_dropout=attention_probs_dropout_prob,
                             act_dropout=hidden_dropout_prob,
                             weight_attr=paddle.ParamAttr(
                                 initializer=nn.initializer.Normal(
                                     mean=0.0, std=self.initializer_range)),
                             bias_attr=None,
                             mesh_idx=mesh_index))
        Decoder = TransformerDecoder
        self.decoder = Decoder(decoder_layers,
                               num_hidden_layers,
                               norm="LayerNorm",
                               hidden_size=hidden_size)
        self.checkpoints = []

    def forward(self,
@@ -863,44 +858,44 @@ class GPTModel(nn.Layer):
        past_length = 0
        if cache is not None:
            past_length = paddle.shape(cache[0].k)[-2]
        position_ids = paddle.arange(past_length,
                                     paddle.shape(input_ids)[-1] + past_length,
                                     dtype='int64')
        position_ids = position_ids.unsqueeze(0)
        position_ids = paddle.fluid.layers.expand_as(position_ids, input_ids)
        embedding_output = self.embeddings(input_ids=input_ids,
                                           position_ids=position_ids)
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": PP_MESH_LIST[0],
                                  "dims_mapping":
                                  [-1 for i in range(len(input_ids.shape))]
                              })
        if _global_parallel_strategy == "dp_pp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": DPPP_MESH_LIST[0],
                                  "dims_mapping": [0] +
                                  [-1 for i in range(len(input_ids.shape) - 1)]
                              })
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": DPMPPP_MESH_LIST[0],
                                  "dims_mapping": [0] +
                                  [-1 for i in range(len(input_ids.shape) - 1)]
                              })
        encoder_outputs = self.decoder(embedding_output,
                                       memory=None,
                                       tgt_mask=attention_mask,
                                       use_cache=use_cache,
                                       cache=cache)
        self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs
@@ -912,19 +907,19 @@ class GPTForPretraining(nn.Layer):
    """

    def __init__(
        self,
        gpt,
        vocab_size=50304,
        hidden_size=768,
        initializer_range=0.02,
    ):
        super(GPTForPretraining, self).__init__()
        self.output_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(name="output_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0, std=initializer_range)))
        self.gpt = gpt

    def forward(self,

@@ -943,8 +938,9 @@ class GPTForPretraining(nn.Layer):
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs
        logits = paddle.matmul(encoder_outputs,
                               self.output_embeddings.weight,
                               transpose_y=True)
        if use_cache:
            return logits, cached_kvs
        else:
......
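GPTForPretraining.forward above produces logits by multiplying the decoder output with the transposed output-embedding table. A minimal sketch of that projection with illustrative shapes:

import paddle

hidden = paddle.randn([2, 8, 16])      # [batch, seq, hidden]
emb_weight = paddle.randn([100, 16])   # [vocab, hidden]
# Project hidden states onto the vocabulary, as in the forward pass above.
logits = paddle.matmul(hidden, emb_weight, transpose_y=True)  # [2, 8, 100]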
@@ -38,6 +38,7 @@ PP_MESH_1 = auto.ProcessMesh([2, 3])

class MLPLayer(nn.Layer):
+
    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,

@@ -45,42 +46,51 @@ class MLPLayer(nn.Layer):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None
        self.word_embeddings = nn.Embedding(
            hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(name="word_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0, std=initializer_range)))
        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear2 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)

    def forward(self, input):
        auto.shard_tensor(self.word_embeddings.weight,
                          dist_attr={
                              "process_mesh": PP_MESH_0,
                              "dims_mapping": [0, -1]
                          })
        auto.shard_tensor(self.linear0.weight,
                          dist_attr={
                              "process_mesh": PP_MESH_0,
                              "dims_mapping": [-1, 0]
                          })
        auto.shard_tensor(self.linear1.weight,
                          dist_attr={
                              "process_mesh": PP_MESH_1,
                              "dims_mapping": [0, -1]
                          })
        auto.shard_tensor(self.linear2.weight,
                          dist_attr={
                              "process_mesh": PP_MESH_1,
                              "dims_mapping": [0, -1]
                          })
        w_out = self.word_embeddings(input)
        out = self.linear0(w_out)
        gelu_out = F.gelu(out, approximate=True)
@@ -98,21 +108,24 @@ def mlp_forward(train_program, start_program):
    hidden_size = 1024
    sequence_len = 512
    input = static.data(name="input", shape=[batch_size], dtype='int32')
    label = static.data(name="label",
                        shape=[batch_size, 1],
                        dtype='float32')

    auto.shard_tensor(input,
                      dist_attr={
                          "process_mesh": PP_MESH_0,
                          "dims_mapping": [-1]
                      })
    auto.shard_tensor(label,
                      dist_attr={
                          "process_mesh": PP_MESH_1,
                          "dims_mapping": [-1, -1]
                      })

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   initializer_range=0.02)
    predict = mlp(input)
    error_cost = paddle.nn.functional.square_error_cost(predict, label)
@@ -137,13 +150,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)

@@ -171,8 +183,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
        if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[
                0]:
            send_result = True
-        if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
-                0]:
+        if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]:
            recv_result = True
    return send_result and recv_result
@@ -206,6 +217,7 @@ def check_allgather(dist_main_program):

class TestMLPReshard(unittest.TestCase):
+
    def test_mlp_mppp(self):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()

@@ -230,38 +242,29 @@ class TestMLPReshard(unittest.TestCase):
        process_mesh = auto.ProcessMesh(mesh=[0, 3])
        with static.program_guard(train_program, startup_program):
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            x = auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [0, -1]
                                  })
            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
            w = auto.shard_tensor(w,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [-1, -1]
                                  })
-            # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
-            #     x.name: [-1, -1],
-            #     w.name: [-1, -1]
-            # }, **{"x": x,
-            #       "y": w})[0]
-            y = paddle.distributed.shard_op(
-                paddle.matmul,
-                dist_attr={
-                    "process_mesh": process_mesh,
-                    x: {
-                        "dims_mapping": [-1, -1]
-                    },
-                    w: {
-                        "dims_mapping": [-1, -1]
-                    }
-                })(x, w)[0]
+            y = paddle.distributed.shard_op(paddle.matmul,
+                                            dist_attr={
+                                                "process_mesh": process_mesh,
+                                                x: {
+                                                    "dims_mapping": [-1, -1]
+                                                },
+                                                w: {
+                                                    "dims_mapping": [-1, -1]
+                                                }
+                                            })(x, w)
        rank_id = 0
        dist_context = DistributedContext()
......
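check_send_recv_result above looks for send_v2/recv_v2 ops on a specific intermediate tensor after partitioning; a simplified sketch of the same kind of program scan over a hypothetical distributed main program (it ignores tensor names and only checks op types):

def has_send_recv(dist_main_prog):
    # Walk every op in the program and record whether the partitioner
    # inserted a point-to-point send/recv pair at the pipeline boundary.
    found_send = found_recv = False
    for op in dist_main_prog.global_block().ops:
        if op.type == "send_v2":
            found_send = True
        elif op.type == "recv_v2":
            found_recv = True
    return found_send and found_recv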