diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index a2c2748a8cea390003dfec857a252b7df3ee1b05..b6a77b778885f552438fb654ab3c4a88ebd79a05 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys class DistributedOperator: + def __init__(self, serial_op, dist_attr=None): self._serial_op = serial_op self._serial_inputs = {} @@ -248,6 +249,7 @@ class DistributedOperator: class DistributedModule: + def __init__(self, serial_module, dist_attr=None): self._serial_module = serial_module self._dist_attr = dist_attr @@ -265,6 +267,4 @@ class DistributedModule: dist_op = DistributedOperator(op, self._dist_attr) dist_op.dist_attr.mark_annotated_as(self._dist_attr) default_dist_ctx.add_dist_op_for_program(dist_op) - if isinstance(output, Variable): - output = [output] - return list(output) + return output diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index 3a16f5d70d78a1a919b083d9266601378a3f558d..73aae9b4b18de47b4129a4df993f02bfce3e4ff8 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -47,6 +47,7 @@ paddle.seed(44) class MyDataset(Dataset): + def __init__(self, num_samples): super(MyDataset, self).__init__() self.num_samples = num_samples @@ -61,6 +62,7 @@ class MyDataset(Dataset): class MLPLayer(nn.Layer): + def __init__(self, hidden_size=1024, intermediate_size=4 * 1024, @@ -69,43 +71,45 @@ class MLPLayer(nn.Layer): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size - weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range)) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) bias_attr = None - self.linear0 = nn.Linear( - d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) - self.linear1 = nn.Linear( - dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.linear0 = nn.Linear(d_model, + dim_feedforward, + weight_attr, + bias_attr=bias_attr) + self.linear1 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - out = auto.shard_op( - self.norm, dist_attr={"process_mesh": PP_MESH_0})(input)[0] - out = self.linear0(input) + out = auto.shard_op(self.norm, dist_attr={"process_mesh": + PP_MESH_0})(input) + out = self.linear0(out) out = F.gelu(out, approximate=True) - out = auto.shard_op( - self.linear1, dist_attr={"process_mesh": PP_MESH_1})(out)[0] + out = auto.shard_op(self.linear1, dist_attr={"process_mesh": + PP_MESH_1})(out) out = self.dropout(out) out = self.linear2(out) return out def train(): - mlp = MLPLayer( - hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) dataset = MyDataset(batch_num * batch_size) inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') @@ -119,11 +123,10 @@ def train(): dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) - engine = Engine( - mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy) + engine = Engine(mlp, + inputs_spec=inputs_spec, + labels_spec=labels_spec, + strategy=dist_strategy) engine.prepare(optimizer, loss) engine.fit(dataset, batch_size=batch_size, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index b1c15c5ce62652eb61b981b4a587b539a19ba63c..87c746ab5d3b506ba865904d15bf04ac0310f85d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -76,26 +76,27 @@ class MultiHeadAttention(nn.Layer): if self.fuse: assert self.kdim == embed_dim assert self.vdim == embed_dim - self.qkv_proj = nn.Linear( - embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr) + self.qkv_proj = nn.Linear(embed_dim, + 3 * embed_dim, + weight_attr, + bias_attr=bias_attr) else: - self.q_proj = nn.Linear( - embed_dim, - embed_dim, - weight_attr=weight_attr, - bias_attr=bias_attr) - self.k_proj = nn.Linear( - self.kdim, - embed_dim, - weight_attr=weight_attr, - bias_attr=bias_attr) - self.v_proj = nn.Linear( - self.vdim, - embed_dim, - weight_attr=weight_attr, - bias_attr=bias_attr) - self.out_proj = nn.Linear( - embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr) + self.q_proj = nn.Linear(embed_dim, + embed_dim, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.k_proj = nn.Linear(self.kdim, + embed_dim, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.v_proj = nn.Linear(self.vdim, + embed_dim, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.out_proj = nn.Linear(embed_dim, + embed_dim, + weight_attr=weight_attr, + bias_attr=bias_attr) def _fuse_prepare_qkv(self, query): mix_layer = self.qkv_proj(query) @@ -113,33 +114,30 @@ class MultiHeadAttention(nn.Layer): """ q = self.q_proj(query) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.q_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.q_proj.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.q_proj.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.q_proj.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 1] + }) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) if isinstance(cache, self.StaticCache): @@ -167,62 +165,56 @@ class MultiHeadAttention(nn.Layer): """ k = self.k_proj(key) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.k_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.k_proj.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.k_proj.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.k_proj.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 1] + }) v = self.v_proj(value) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.v_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.v_proj.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.v_proj.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.v_proj.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) @@ -276,17 +268,18 @@ class MultiHeadAttention(nn.Layer): else: q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache) - product = layers.matmul( - x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + product = layers.matmul(x=q, + y=k, + transpose_y=True, + alpha=self.head_dim**-0.5) if attn_mask is not None: product = product + attn_mask weights = F.softmax(product) if self.dropout: - weights = F.dropout( - weights, - self.dropout, - training=self.training, - mode="upscale_in_train") + weights = F.dropout(weights, + self.dropout, + training=self.training, + mode="upscale_in_train") out = tensor.matmul(weights, v) # combine heads out = tensor.transpose(out, perm=[0, 2, 1, 3]) @@ -294,33 +287,30 @@ class MultiHeadAttention(nn.Layer): # project to output out = self.out_proj(out) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.out_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.out_proj.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.out_proj.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.out_proj.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [1, -1] + }) outs = [out] if self.need_weights: outs.append(weights) @@ -362,36 +352,37 @@ class TransformerDecoder(nn.Layer): new_caches = [] self.checkpoints = [] if _global_parallel_strategy == "pp": - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": PP_MESH_LIST[0], - "dims_mapping": [-1 for i in range(len(output.shape))] - }) + auto.shard_tensor(output, + dist_attr={ + "process_mesh": + PP_MESH_LIST[0], + "dims_mapping": + [-1 for i in range(len(output.shape))] + }) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": DPPP_MESH_LIST[0], - "dims_mapping": - [0] + [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, + dist_attr={ + "process_mesh": + DPPP_MESH_LIST[0], + "dims_mapping": [0] + + [-1 for i in range(len(output.shape) - 1)] + }) if _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[0], - "dims_mapping": - [-1] + [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, + dist_attr={ + "process_mesh": + MPPP_MESH_LIST[0], + "dims_mapping": [-1] + + [-1 for i in range(len(output.shape) - 1)] + }) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[0], - "dims_mapping": - [0] + [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[0], + "dims_mapping": [0] + + [-1 for i in range(len(output.shape) - 1)] + }) for i, mod in enumerate(self.layers): if cache is None: if use_cache: @@ -400,11 +391,12 @@ class TransformerDecoder(nn.Layer): mod, dist_attr={ "process_mesh": PP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + })(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": PP_MESH_LIST[mod.mesh_idx], + "process_mesh": + PP_MESH_LIST[mod.mesh_idx], "dims_mapping": [-1 for i in range(len(output.shape))] }) @@ -413,11 +405,12 @@ class TransformerDecoder(nn.Layer): mod, dist_attr={ "process_mesh": DPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + })(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) @@ -426,11 +419,12 @@ class TransformerDecoder(nn.Layer): mod, dist_attr={ "process_mesh": MPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + })(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + MPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [-1] + [-1 for i in range(len(output.shape) - 1)] }) @@ -439,11 +433,12 @@ class TransformerDecoder(nn.Layer): mod, dist_attr={ "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + })(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPMPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) @@ -456,41 +451,47 @@ class TransformerDecoder(nn.Layer): new_caches.append(new_cache) else: if _global_parallel_strategy == "pp": - output = auto.shard_op( - mod, - dist_attr={ - "process_mesh": PP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + output = auto.shard_op(mod, + dist_attr={ + "process_mesh": + PP_MESH_LIST[mod.mesh_idx] + })(output, memory, tgt_mask, + use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": PP_MESH_LIST[mod.mesh_idx], + "process_mesh": + PP_MESH_LIST[mod.mesh_idx], "dims_mapping": [-1 for i in range(len(output.shape))] }) elif _global_parallel_strategy == "dp_pp": - output = auto.shard_op( - mod, - dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + output = auto.shard_op(mod, + dist_attr={ + "process_mesh": + DPPP_MESH_LIST[mod.mesh_idx] + })(output, memory, tgt_mask, + use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) elif _global_parallel_strategy == "mp_pp": - output = auto.shard_op( - mod, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + output = auto.shard_op(mod, + dist_attr={ + "process_mesh": + MPPP_MESH_LIST[mod.mesh_idx] + })(output, memory, tgt_mask, + use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + MPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [-1] + [-1 for i in range(len(output.shape) - 1)] }) @@ -499,11 +500,12 @@ class TransformerDecoder(nn.Layer): mod, dist_attr={ "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache)[0] + })(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( output, dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPMPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) @@ -517,8 +519,9 @@ class TransformerDecoder(nn.Layer): if _global_parallel_strategy == "pp": output, new_cache = auto.shard_op( mod, - dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]})( - output, memory, tgt_mask, use_cache, cache) + dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx] + })(output, memory, tgt_mask, use_cache, + cache) auto.shard_tensor( output, dist_attr={ @@ -535,7 +538,8 @@ class TransformerDecoder(nn.Layer): auto.shard_tensor( output, dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) @@ -548,7 +552,8 @@ class TransformerDecoder(nn.Layer): auto.shard_tensor( output, dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + MPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [-1] + [-1 for i in range(len(output.shape) - 1)] }) @@ -561,7 +566,8 @@ class TransformerDecoder(nn.Layer): auto.shard_tensor( output, dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], + "process_mesh": + DPMPPP_MESH_LIST[mod.mesh_idx], "dims_mapping": [0] + [-1 for i in range(len(output.shape) - 1)] }) @@ -619,17 +625,20 @@ class TransformerDecoderLayer(nn.Layer): self.normalize_before = normalize_before weight_attrs = _convert_param_attr_to_list(weight_attr, 3) bias_attrs = _convert_param_attr_to_list(bias_attr, 3) - self.self_attn = MultiHeadAttention( - d_model, - nhead, - dropout=attn_dropout, - weight_attr=weight_attrs[0], - bias_attr=bias_attrs[0], - mesh_idx=self.mesh_idx) - self.linear1 = nn.Linear( - d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]) - self.linear2 = nn.Linear( - dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]) + self.self_attn = MultiHeadAttention(d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0], + mesh_idx=self.mesh_idx) + self.linear1 = nn.Linear(d_model, + dim_feedforward, + weight_attrs[2], + bias_attr=bias_attrs[2]) + self.linear2 = nn.Linear(dim_feedforward, + d_model, + weight_attrs[2], + bias_attr=bias_attrs[2]) self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") @@ -652,72 +661,65 @@ class TransformerDecoderLayer(nn.Layer): if self.normalize_before: tgt = self.norm2(tgt) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.linear1.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 0] + }) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.linear1.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [-1, 1] + }) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.linear2.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear2.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.linear2.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear2.weight, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[self.mesh_idx], + "dims_mapping": [1, -1] + }) tgt = self.dropout2( - self.linear2(F.gelu( - self.linear1(tgt), approximate=True))) + self.linear2(F.gelu(self.linear1(tgt), approximate=True))) tgt = residual + tgt if not self.normalize_before: tgt = self.norm2(tgt) return tgt if use_cache is False else (tgt, incremental_cache) def gen_cache(self, memory): - incremental_cache = self.self_attn.gen_cache( - memory, type=self.self_attn.Cache) + incremental_cache = self.self_attn.gen_cache(memory, + type=self.self_attn.Cache) return incremental_cache @@ -737,17 +739,15 @@ class GPTEmbeddings(nn.Layer): self.word_embeddings = nn.Embedding( vocab_size, hidden_size, - weight_attr=paddle.ParamAttr( - name="word_embeddings", - initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range))) + weight_attr=paddle.ParamAttr(name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) self.position_embeddings = nn.Embedding( max_position_embeddings, hidden_size, - weight_attr=paddle.ParamAttr( - name="pos_embeddings", - initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range))) + weight_attr=paddle.ParamAttr(name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) self.dropout = nn.Dropout(hidden_dropout_prob) def forward(self, input_ids, position_ids=None): @@ -757,33 +757,29 @@ class GPTEmbeddings(nn.Layer): position_ids = seq_length - ones input_embedings = self.word_embeddings(input_ids) if _global_parallel_strategy == "mp": - auto.shard_tensor( - self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor( - self.word_embeddings.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[0], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, + dist_attr={ + "process_mesh": MPPP_MESH_LIST[0], + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - self.word_embeddings.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[0], - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, + dist_attr={ + "process_mesh": DPMPPP_MESH_LIST[0], + "dims_mapping": [1, -1] + }) position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings embeddings = self.dropout(embeddings) @@ -821,9 +817,10 @@ class GPTModel(nn.Layer): self.pipline_mode = (pp_degree is not None and pp_degree > 1) if self.pipline_mode: self.layer_per_stage = num_hidden_layers // pp_degree - self.embeddings = GPTEmbeddings( - vocab_size, hidden_size, hidden_dropout_prob, - max_position_embeddings, type_vocab_size, self.initializer_range) + self.embeddings = GPTEmbeddings(vocab_size, hidden_size, + hidden_dropout_prob, + max_position_embeddings, + type_vocab_size, self.initializer_range) decoder_layers = nn.LayerList() for i in range(num_hidden_layers): mesh_index = None @@ -831,25 +828,23 @@ class GPTModel(nn.Layer): if self.layer_per_stage is not None: mesh_index = i // self.layer_per_stage decoder_layers.append( - DecoderLayer( - d_model=hidden_size, - nhead=num_attention_heads, - dim_feedforward=intermediate_size, - dropout=hidden_dropout_prob, - activation=hidden_act, - attn_dropout=attention_probs_dropout_prob, - act_dropout=hidden_dropout_prob, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Normal( - mean=0.0, std=self.initializer_range)), - bias_attr=None, - mesh_idx=mesh_index)) + DecoderLayer(d_model=hidden_size, + nhead=num_attention_heads, + dim_feedforward=intermediate_size, + dropout=hidden_dropout_prob, + activation=hidden_act, + attn_dropout=attention_probs_dropout_prob, + act_dropout=hidden_dropout_prob, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)), + bias_attr=None, + mesh_idx=mesh_index)) Decoder = TransformerDecoder - self.decoder = Decoder( - decoder_layers, - num_hidden_layers, - norm="LayerNorm", - hidden_size=hidden_size) + self.decoder = Decoder(decoder_layers, + num_hidden_layers, + norm="LayerNorm", + hidden_size=hidden_size) self.checkpoints = [] def forward(self, @@ -863,44 +858,44 @@ class GPTModel(nn.Layer): past_length = 0 if cache is not None: past_length = paddle.shape(cache[0].k)[-2] - position_ids = paddle.arange( - past_length, - paddle.shape(input_ids)[-1] + past_length, - dtype='int64') + position_ids = paddle.arange(past_length, + paddle.shape(input_ids)[-1] + + past_length, + dtype='int64') position_ids = position_ids.unsqueeze(0) - position_ids = paddle.fluid.layers.expand_as(position_ids, - input_ids) - embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids) + position_ids = paddle.fluid.layers.expand_as( + position_ids, input_ids) + embedding_output = self.embeddings(input_ids=input_ids, + position_ids=position_ids) if _global_parallel_strategy == "pp": - auto.shard_tensor( - input_ids, - dist_attr={ - "process_mesh": PP_MESH_LIST[0], - "dims_mapping": [-1 for i in range(len(input_ids.shape))] - }) + auto.shard_tensor(input_ids, + dist_attr={ + "process_mesh": + PP_MESH_LIST[0], + "dims_mapping": + [-1 for i in range(len(input_ids.shape))] + }) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor( - input_ids, - dist_attr={ - "process_mesh": DPPP_MESH_LIST[0], - "dims_mapping": - [0] + [-1 for i in range(len(input_ids.shape) - 1)] - }) + auto.shard_tensor(input_ids, + dist_attr={ + "process_mesh": + DPPP_MESH_LIST[0], + "dims_mapping": [0] + + [-1 for i in range(len(input_ids.shape) - 1)] + }) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - input_ids, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[0], - "dims_mapping": - [0] + [-1 for i in range(len(input_ids.shape) - 1)] - }) - encoder_outputs = self.decoder( - embedding_output, - memory=None, - tgt_mask=attention_mask, - use_cache=use_cache, - cache=cache) + auto.shard_tensor(input_ids, + dist_attr={ + "process_mesh": + DPMPPP_MESH_LIST[0], + "dims_mapping": [0] + + [-1 for i in range(len(input_ids.shape) - 1)] + }) + encoder_outputs = self.decoder(embedding_output, + memory=None, + tgt_mask=attention_mask, + use_cache=use_cache, + cache=cache) self.checkpoints.extend(self.decoder.checkpoints) return encoder_outputs @@ -912,19 +907,19 @@ class GPTForPretraining(nn.Layer): """ def __init__( - self, - gpt, - vocab_size=50304, - hidden_size=768, - initializer_range=0.02, ): + self, + gpt, + vocab_size=50304, + hidden_size=768, + initializer_range=0.02, + ): super(GPTForPretraining, self).__init__() self.output_embeddings = nn.Embedding( vocab_size, hidden_size, - weight_attr=paddle.ParamAttr( - name="output_embeddings", - initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range))) + weight_attr=paddle.ParamAttr(name="output_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) self.gpt = gpt def forward(self, @@ -943,8 +938,9 @@ class GPTForPretraining(nn.Layer): encoder_outputs, cached_kvs = outputs[:2] else: encoder_outputs = outputs - logits = paddle.matmul( - encoder_outputs, self.output_embeddings.weight, transpose_y=True) + logits = paddle.matmul(encoder_outputs, + self.output_embeddings.weight, + transpose_y=True) if use_cache: return logits, cached_kvs else: diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 5f9c2ec2371a5088ce319651ebf7e1e791103fb2..dfb314796a9ff2166fd6970c46ebec2455280a1c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -38,6 +38,7 @@ PP_MESH_1 = auto.ProcessMesh([2, 3]) class MLPLayer(nn.Layer): + def __init__(self, hidden_size=1024, intermediate_size=4 * 1024, @@ -45,42 +46,51 @@ class MLPLayer(nn.Layer): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size - weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range)) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) bias_attr = None self.word_embeddings = nn.Embedding( hidden_size, hidden_size, - weight_attr=paddle.ParamAttr( - name="word_embeddings", - initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range))) - - self.linear0 = nn.Linear( - d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) - self.linear1 = nn.Linear( - dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) - self.linear2 = nn.Linear( - dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + weight_attr=paddle.ParamAttr(name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + + self.linear0 = nn.Linear(d_model, + dim_feedforward, + weight_attr, + bias_attr=bias_attr) + self.linear1 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) + self.linear2 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) def forward(self, input): - auto.shard_tensor( - self.word_embeddings.weight, - dist_attr={"process_mesh": PP_MESH_0, - "dims_mapping": [0, -1]}) - auto.shard_tensor( - self.linear0.weight, - dist_attr={"process_mesh": PP_MESH_0, - "dims_mapping": [-1, 0]}) - auto.shard_tensor( - self.linear1.weight, - dist_attr={"process_mesh": PP_MESH_1, - "dims_mapping": [0, -1]}) - auto.shard_tensor( - self.linear2.weight, - dist_attr={"process_mesh": PP_MESH_1, - "dims_mapping": [0, -1]}) + auto.shard_tensor(self.word_embeddings.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [0, -1] + }) + auto.shard_tensor(self.linear0.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor(self.linear1.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [0, -1] + }) + auto.shard_tensor(self.linear2.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [0, -1] + }) w_out = self.word_embeddings(input) out = self.linear0(w_out) gelu_out = F.gelu(out, approximate=True) @@ -98,21 +108,24 @@ def mlp_forward(train_program, start_program): hidden_size = 1024 sequence_len = 512 input = static.data(name="input", shape=[batch_size], dtype='int32') - label = static.data( - name="label", shape=[batch_size, 1], dtype='float32') - - auto.shard_tensor( - input, dist_attr={"process_mesh": PP_MESH_0, - "dims_mapping": [-1]}) - auto.shard_tensor( - label, - dist_attr={"process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1]}) - - mlp = MLPLayer( - hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - initializer_range=0.02) + label = static.data(name="label", + shape=[batch_size, 1], + dtype='float32') + + auto.shard_tensor(input, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1] + }) + auto.shard_tensor(label, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) + + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) predict = mlp(input) error_cost = paddle.nn.functional.square_error_cost(predict, label) @@ -137,13 +150,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): complete_train_program = completer.complete_forward_annotation( train_program) dist_context.block_state.parse_forward_blocks(complete_train_program) - params_grads = parallelizer._generate_backward( - complete_train_program, - startup_program, - loss, - parameter_list=None, - no_grad_set=None, - callbacks=None) + params_grads = parallelizer._generate_backward(complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) # logical partition partitioner = Partitioner(dist_context, rank_id) @@ -171,8 +183,7 @@ def check_send_recv_result(dist_main_prog, rank_id): if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[ 0]: send_result = True - if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ - 0]: + if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]: recv_result = True return send_result and recv_result @@ -206,6 +217,7 @@ def check_allgather(dist_main_program): class TestMLPReshard(unittest.TestCase): + def test_mlp_mppp(self): train_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -230,38 +242,29 @@ class TestMLPReshard(unittest.TestCase): process_mesh = auto.ProcessMesh(mesh=[0, 3]) with static.program_guard(train_program, startup_program): x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') - x = auto.shard_tensor( - x, - dist_attr={ - "process_mesh": process_mesh, - "dims_mapping": [0, -1] - }) + x = auto.shard_tensor(x, + dist_attr={ + "process_mesh": process_mesh, + "dims_mapping": [0, -1] + }) w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') - w = auto.shard_tensor( - w, - dist_attr={ - "process_mesh": process_mesh, - "dims_mapping": [-1, -1] - }) - - # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { - # x.name: [-1, -1], - # w.name: [-1, -1] - # }, **{"x": x, - # "y": w})[0] - - y = paddle.distributed.shard_op( - paddle.matmul, - dist_attr={ - "process_mesh": process_mesh, - x: { - "dims_mapping": [-1, -1] - }, - w: { - "dims_mapping": [-1, -1] - } - })(x, w)[0] + w = auto.shard_tensor(w, + dist_attr={ + "process_mesh": process_mesh, + "dims_mapping": [-1, -1] + }) + + y = paddle.distributed.shard_op(paddle.matmul, + dist_attr={ + "process_mesh": process_mesh, + x: { + "dims_mapping": [-1, -1] + }, + w: { + "dims_mapping": [-1, -1] + } + })(x, w) rank_id = 0 dist_context = DistributedContext()