Unverified commit 885d1aec, authored by JZ-LIANG, committed by GitHub

[Semi Auto] Softmax SPMD Rule (#55196)

* resolve potential input sharding conflicts

* fixed comment

---------
Co-authored-by: Yichen Zhang <zhangyichen03@baidu.com>
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>
Parent bb0df468
......@@ -17,11 +17,14 @@ limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
......@@ -40,7 +43,8 @@ SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<std::string, std::vector<int64_t>>>&
tensor_axes_to_dim_pairs) {
tensor_axes_to_dim_pairs,
const bool merge_conflicts) {
std::unordered_map<std::string, int64_t> axis_to_dim_map;
std::unordered_map<int64_t, std::string> dim_to_axis_map;
int64_t merge_dim;
......@@ -74,11 +78,18 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
// memory or communication or computation).
for (auto& it : dim_to_axis_map) {
if (it.second.size() > 1) {
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (size_t i = 1; i < it.second.size(); ++i) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
if (merge_conflicts) {
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (size_t i = 1; i < it.second.size(); ++i) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
}
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Multiple tensor axes [%s] are sharded by the same mesh dimension [%d].",
str_join(it.second),
it.first));
}
}
}
......
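To make the new merge_conflicts switch concrete, here is a minimal Python sketch of the tensor-level merge semantics above (an illustration under simplifying assumptions, not Paddle's implementation): each tensor contributes an einsum-axes string and a dims_mapping, and when one mesh dimension ends up claimed by several axes, the flag decides between picking the first axis and raising.

def sharding_merge(tensor_axes_to_dim_pairs, merge_conflicts=True):
    # axis -> mesh dim; -1 means replicated. max() is a simplified per-axis
    # merge (a sharded claim wins over -1); see Rule1 in the header below.
    axis_to_dim = {}
    for axes, dims_mapping in tensor_axes_to_dim_pairs:
        for axis, dim in zip(axes, dims_mapping):
            axis_to_dim[axis] = max(axis_to_dim.get(axis, -1), dim)
    # invert: which tensor axes claim each mesh dim?
    dim_to_axes = {}
    for axis, dim in axis_to_dim.items():
        if dim >= 0:
            dim_to_axes.setdefault(dim, []).append(axis)
    for dim, axes in dim_to_axes.items():
        if len(axes) > 1:
            if not merge_conflicts:
                raise ValueError(
                    f"Multiple tensor axes {axes} are sharded by mesh dim {dim}.")
            for axis in axes[1:]:  # first axis keeps the shard, rest replicate
                axis_to_dim[axis] = -1
    return axis_to_dim

# e.g. x "ab" -> [1, -1] and w "jk" -> [1, 0]: axes "a" and "j" both claim
# mesh dim 1; with merge_conflicts=True, "j" is demoted to -1 (replicated).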
......@@ -86,7 +86,8 @@ class SPMDRuleBase {
// The same axes of different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<std::string, std::vector<int64_t>>>&
tensor_axes_to_dim_pairs);
tensor_axes_to_dim_pairs,
const bool merge_conflicts = true);
// Merge the sharding specification (dims mapping) for one tensor Axis.
// Rule1: A replicated dimension could be merged by any sharded dimension.
......
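The per-axis merge named by the header comment above (Rule1) can be sketched in a few lines; this is a hedged illustration of the stated rule, not the exact ShardingMergeForAxis signature:

def merge_axis_sharding(dim_a, dim_b):
    # Rule1: a replicated mapping (-1) can be merged by any sharded mapping.
    if dim_a == -1:
        return dim_b
    if dim_b == -1 or dim_a == dim_b:
        return dim_a
    # two different sharded mesh dims for one tensor axis is a hard conflict
    raise ValueError(f"axis is sharded by both mesh dims {dim_a} and {dim_b}")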
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
// step0: verify input args based on embedding logic
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument(
"The size of InputSpec of embedding should be 2, but got [%d].",
input_specs_size));
auto x_shape = input_specs[0].shape();
auto weight_shape = input_specs[1].shape();
int x_ndim = x_shape.size();
int weight_ndim = weight_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
auto weight_dist_attr_src = input_specs[1].dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> weight_dims_mapping =
weight_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
x_ndim,
x_dims_mapping.size()));
PADDLE_ENFORCE_EQ(
weight_ndim,
weight_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of W's tensor size: [%d] and W's dims_mapping size [%d].",
weight_ndim,
weight_dims_mapping.size()));
PADDLE_ENFORCE_EQ(
weight_ndim,
2,
phi::errors::InvalidArgument("Embedding table should have TWO dimension, "
"but got a tensor with [%d] dimension.",
weight_ndim));
int64_t padding_idx = ExtractAttr<int64_t>("padding_idx", attrs);
bool sparse = ExtractAttr<bool>("sparse", attrs);
// determine parallel mode
int64_t weight_row_axis_mapping = weight_dims_mapping[0];
// padding_idx is not supported by the c_embedding kernel.
// (TODO) might be resharded as replicated when padding_idx != -1
if (padding_idx != -1) {
PADDLE_ENFORCE_EQ(
weight_row_axis_mapping,
-1,
phi::errors::InvalidArgument(
"Row-wise parallel of embedding table does NOT support padding_idx, "
"but got padding_idx [%d] while the row axis of the embedding table "
"is sharded by mesh dimension [%d].",
padding_idx,
weight_row_axis_mapping));
}
// (TODO) might be resharded as replicated when sparse
if (sparse) {
PADDLE_ENFORCE_EQ(
weight_row_axis_mapping,
-1,
phi::errors::InvalidArgument(
"Row-wise parallel of embedding table does NOT support sparse mode; "
"the row axis of the embedding table is sharded by mesh dimension "
"[%d].",
weight_row_axis_mapping));
}
VLOG(6) << "EmbeddingSPMDRule InferForward Inputs: "
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
<< str_join(x_dims_mapping) << "]; Weight shape: ["
<< str_join(weight_shape) << "], weight_dims_mapping: ["
<< str_join(weight_dims_mapping) << "]; padding_idx: "
<< "[" << padding_idx << "]; "
<< "sparse: "
<< "[" << (sparse ? "true" : "false") << "]; ";
// step1: build Einsum Notation
std::string alphabet = "abcdefghilmnopqrstuvwxyz";
std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
std::string weight_axes = "jk";
std::string out_axes = x_axes + "k";
// step2: Sharding Propagation
auto axis_to_dim_map = ShardingMergeForTensors(
{{x_axes, x_dims_mapping}, {weight_axes, weight_dims_mapping}}, false);
// step3: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (size_t i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step3.1: Handle Partial
// (TODO) support the case where the embedding table is partial at the very beginning.
std::vector<int64_t> partial_on_dims;
if (weight_row_axis_mapping > -1) {
partial_on_dims.push_back(weight_row_axis_mapping);
}
// step4: merge potential conflict in inputs
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
TensorDistAttr weight_dist_attr_dst =
CopyTensorDistAttrForOutput(weight_dist_attr_src);
weight_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(weight_axes, axis_to_dim_map));
VLOG(4) << "EmbeddingSPMDRule InferForward: "
<< "Einsum notation: [" << x_axes << "," << weight_axes << " --> "
<< out_axes << "]. " << std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]; Y shape: ["
<< str_join(weight_shape) << "], src_dims_mapping: ["
<< str_join(weight_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(weight_dist_attr_dst.dims_mapping())
<< "]; Output dims_mapping: [" << str_join(out_dims_mapping)
<< "], partial_on_dims: [" << str_join(partial_on_dims) << "]";
return {{x_dist_attr_dst, weight_dist_attr_dst}, {output_dist_attr_dst}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
EmbeddingSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of EmbeddingSPMDRule is NOT implemented yet."));
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
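A worked example of the notation built in InferForward above may help. Note that the alphabet string deliberately omits 'j' and 'k', which are reserved for the table axes, so the ids axes can never collide with them. The sketch below (illustrative Python, assuming 2-D ids and a 2-D table) traces the row-wise parallel case:

alphabet = "abcdefghilmnopqrstuvwxyz"  # no 'j', no 'k': reserved for the table
x_ndim = 2
x_axes = alphabet[:x_ndim]             # "ab" (batch/sequence axes of the ids)
weight_axes = "jk"                     # "j" = vocab rows, "k" = hidden columns
out_axes = x_axes + "k"                # "abk": ids axes plus the hidden axis
print(f"{x_axes},{weight_axes} --> {out_axes}")  # ab,jk --> abk
# Row-wise parallel: table dims_mapping [0, -1] shards axis "j" on mesh dim 0.
# "j" does not appear in out_axes, so the output is partial on mesh dim 0
# (each rank holds partial results for ids outside its vocab shard) and the
# partial_on_dims bookkeeping above records that an allreduce is needed.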
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
// (TODO) Support the following parallel cases for embedding:
// 1. Batch dimensions of input ids are sharded on the mesh.
// 2. Row-wise Parallel of the embedding table. (NOTE: Row-wise Parallel needs
// to change the embedding kernel to handle missing ids.)
// 3. Column-wise Parallel of the embedding table.
// 4. Hybrid Parallelism of the above 3 cases.
class EmbeddingSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
......@@ -51,14 +51,13 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
y_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].",
x_ndim,
x_dims_mapping.size()));
y_ndim,
y_dims_mapping.size()));
bool trans_x = ExtractAttr<bool>("trans_x", attrs);
bool trans_y = ExtractAttr<bool>("trans_y", attrs);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward Inputs: "
VLOG(6) << "MatmulSPMDRule InferForward Inputs: "
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
<< str_join(x_dims_mapping) << "]; Y shape: [" << str_join(y_shape)
<< "], y_dims_mapping: [" << str_join(y_dims_mapping)
......@@ -117,9 +116,6 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
y_ndim));
}
VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_axes << ","
<< y_axes << " --> " << out_axes << "].";
// step2: Sharding Propogation
if (trans_x) {
PADDLE_ENFORCE_GE(
......@@ -167,6 +163,8 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: "
<< "Einsum notation: [" << x_axes << "," << y_axes << " --> "
<< out_axes << "]. " << std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dist_attr_src.dims_mapping())
<< "], dst_dims_mapping: ["
......
......@@ -16,10 +16,12 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
// TODO(ljz) Automate this process in the cmake file.
namespace paddle {
......@@ -135,6 +137,14 @@ REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule);
// replicated rule
REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);
// embedding rule
REGISTER_SPMD_RULE(embedding, EmbeddingSPMDRule);
REGISTER_SPMD_RULE(lookup_table_v2, EmbeddingSPMDRule);
// softmax rule
REGISTER_SPMD_RULE(softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
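On the Python side these registrations are looked up by op type; the snippet below mirrors how the unit tests later in this diff fetch a rule (both softmax op types resolve to the same SoftmaxSPMDRule registered above):

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule

softmax_rule = get_spmd_rule("softmax")          # -> SoftmaxSPMDRule
log_softmax_rule = get_spmd_rule("log_softmax")  # same rule class, reused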
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
// step0: verify input args based on softmax logic
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
1,
phi::errors::InvalidArgument(
"The size of InputSpec of softmax should be 1, but got [%d].",
input_specs_size));
auto x_shape = input_specs[0].shape();
int x_ndim = x_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
x_ndim,
x_dims_mapping.size()));
int axis = ExtractAttr<int>("axis", attrs);
VLOG(6) << "SoftmaxSPMDRule InferForward Inputs: "
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
<< str_join(x_dims_mapping) << "]; axis: "
<< "[" << axis << "]; ";
// normalize axis
if (axis < 0) {
axis = x_ndim + axis;
}
// step1: build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
std::string out_axes = x_axes;
// step2: Sharding Propagation
// naive support for sharding on softmax_axis:
// softmax_axis should be resharded as replicated (TODO: support sharding on
// softmax_axis efficiently)
if (x_dims_mapping[axis] >= 0) {
x_dims_mapping[axis] = -1;
VLOG(6) << "SoftmaxSPMDRule InferForward: softmax axis is reshard to be "
"replicated: "
<< "original dims_mapping["
<< str_join(x_dist_attr_src.dims_mapping()) << "], "
<< "resharded dims_mapping[" << str_join(x_dims_mapping) << "].";
}
// Avoid multiple tensor axes being sharded by the same mesh dimension
auto axis_to_dim_map =
ShardingMergeForTensors({{x_axes, x_dims_mapping}}, false);
// step3: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (size_t i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// Update x's dist_attr
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
VLOG(4) << "EmbeddingSPMDRule InferForward: "
<< "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
<< std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dist_attr_src.dims_mapping())
<< "], dst_dims_mapping: [" << str_join(x_dims_mapping)
<< "]; Output dims_mapping: [" << str_join(out_dims_mapping) << "]";
return {{x_dist_attr_dst}, {output_dist_attr_dst}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of SoftmaxSPMDRule is NOT implemented yet."));
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
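The propagation in InferForward above boils down to two steps: normalize a negative axis, then force the softmax axis to be replicated, since every rank needs the full normalized row. A minimal Python sketch (illustrative only, not the C++ implementation):

def softmax_spmd_forward(x_dims_mapping, axis):
    ndim = len(x_dims_mapping)
    if axis < 0:
        axis += ndim                 # normalize, e.g. -1 -> ndim - 1
    dst = list(x_dims_mapping)
    if dst[axis] >= 0:
        dst[axis] = -1               # reshard the softmax axis to replicated
    return dst, list(dst)            # (input dst mapping, output mapping)

# e.g. x sharded [1, -1, 0] with axis=-1: the last axis is unsharded,
# giving [1, -1, -1] for both input and output (matches the unit test below).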
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
// (TODO) Support 2 kinds of parallelism:
// 1. sharding on batch axes (any axis that is not to be softmax normalized)
// of the tensor.
// 2. sharding on the normalized axis of the tensor. (naive support by now;
// efficient support requires changing the softmax kernel).
class SoftmaxSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
......@@ -5,6 +5,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
# NOTE(zyl): unittests WITH single card and WITHOUT timeout
py_test_modules(test_matmul_rule MODULES test_matmul_rule)
py_test_modules(test_embedding_rule MODULES test_embedding_rule)
py_test_modules(test_replicated_rule MODULES test_replicated_rule)
py_test_modules(test_softmax_rule MODULES test_softmax_rule)
# End of unittests WITH single card WITHOUT timeout
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestEmbeddingSPMDRule(unittest.TestCase):
def setUp(self):
self.rule1 = get_spmd_rule("lookup_table_v2")
x_shape = [4, 1024] # [B,S]
table_shape = [512, 768] # [V,H]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
table_tensor_dist_attr = TensorDistAttr()
table_tensor_dist_attr.process_mesh = process_mesh
self.table_dist_tensor_spec = DistTensorSpec(
table_shape, table_tensor_dist_attr
)
self.attrs = {
'padding_idx': -1,
'sparse': False,
}
def test_embedding_infer_forward(self):
# data parallel
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
# table col-wise parallel & dp
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([-1, 0])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, 0])
# table row-wise parallel & dp
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
# table row-wise parallel & padding_idx
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([0, -1])
self.attrs['padding_idx'] = 128
with self.assertRaises(ValueError):
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec],
self.attrs,
)
# table row-wise parallel & sparse
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([0, -1])
self.attrs['padding_idx'] = -1
self.attrs['sparse'] = True
with self.assertRaises(ValueError):
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec],
self.attrs,
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestSoftmaxSPMDRule(unittest.TestCase):
def setUp(self):
self.rule1 = get_spmd_rule("softmax")
self.rule2 = get_spmd_rule("log_softmax")
x_shape = [8, 16, 48]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.attrs = {
'axis': -1,
}
def test_softmax_infer_forward(self):
# sharding on batch axis I
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
self.assertEqual(len(result_dist_attrs), 2)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
# sharding on batch axis II
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1])
# sharding on softmax_axis
self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
# sharding on softmax_axis + axis = 1
self.attrs = {
'axis': 1,
}
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
# sharding on softmax_axis + axis = -2
self.attrs = {
'axis': -2,
}
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
result_dist_attrs = self.rule1.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
if __name__ == "__main__":
unittest.main()