Unverified commit f6161d1e authored by Yichen Zhang, committed by GitHub

[Semi-Auto] Add transpose spmd rule (#55350)

* [Semi-Auto] Add transpose spmd rule

* add unit test in cmake file

* log perm info
Parent a9f877ff
@@ -77,7 +77,7 @@ class SPMDRuleBase {
PADDLE_ENFORCE_NE(iter,
attrs.end(),
paddle::platform::errors::NotFound(
"(%s) is not found in AttributeMap."));
"(%s) is not found in AttributeMap.", name));
return iter->second;
}
};
@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
// TODO(ljz) Automate this process in the cmake file.
namespace paddle {
@@ -155,6 +156,9 @@ REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);
REGISTER_SPMD_RULE(split, SplitSPMDRule);
REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
// transpose rule
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Transpose Logic
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in transpose must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "transpose");
// step1: Build Einsum Notation
std::vector<int64_t> perm_dims =
ExtractAttr<std::vector<int64_t>>("perm", attrs);
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// get einsum notation for input
int64_t ndim = input_specs[0].shape().size();
std::vector<std::string> input_axes_vec;
std::string input_axes = alphabet.substr(0, ndim);
input_axes_vec.emplace_back(input_axes);
// get einsum notation for output
for (int64_t i = 0, n = perm_dims.size(); i < n; ++i) {
// convert the negative dim value to normal dim value
if (perm_dims[i] < 0) {
perm_dims[i] = ndim + perm_dims[i];
}
}
std::string output_axes = "";
for (int64_t i = 0; i < ndim; i++) {
output_axes.append(1, input_axes[perm_dims[i]]);
}
// step2: Sharding Propagation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
// step2.2: infer output dimsmapping from merged input dimsmapping
std::vector<int64_t> output_dims_mapping =
GetDimsMappingForAxes(output_axes, axis_to_dim_map);
// initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
// input dist_attr.
TensorDistAttr output_dist_attr =
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(output_dims_mapping);
// Step2.3 handle input tensor partial (TODO)
VLOG(4) << "TransposeSPMDRule InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "] "
<< "perm: [" << str_join(perm_dims) << "] "
<< "dst_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "]";
}
VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "]\n\n";
return {{input_specs[0].dist_attr()}, {output_dist_attr}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
TransposeSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of TransposeSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
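For readers who want to trace the rule outside C++, the following is a minimal Python sketch of the same forward-inference steps. Because transpose has a single input, the sharding merge in step 2.1 is trivial, so the sketch reads the input dims mapping directly. The function name infer_transpose_forward and its plain-list inputs are illustrative only, not part of Paddle's API:

def infer_transpose_forward(input_shape, input_dims_mapping, perm):
    # Minimal sketch of TransposeSPMDRule::InferForward (illustrative only).
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    ndim = len(input_shape)
    # einsum notation for the input: one letter per tensor dimension
    input_axes = alphabet[:ndim]
    # normalize negative entries in perm, then build the output notation
    perm = [p + ndim if p < 0 else p for p in perm]
    output_axes = "".join(input_axes[p] for p in perm)
    # map each axis letter to its mesh dimension and read off the output mapping
    axis_to_dim = dict(zip(input_axes, input_dims_mapping))
    output_dims_mapping = [axis_to_dim[a] for a in output_axes]
    return input_dims_mapping, output_dims_mapping

# Mirrors the perm = [0, 2, 3, 1] case in the unit test further below:
# prints ([-1, -1, 0, -1], [-1, 0, -1, -1])
print(infer_transpose_forward([64, 48, 36, 24], [-1, -1, 0, -1], [0, 2, 3, 1]))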
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
class TransposeSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_matmul_rule MODULES test_replicated_rule)
py_test_modules(test_matmul_rule MODULES test_softmax_rule)
py_test_modules(test_split_rule MODULES test_split_rule)
py_test_modules(test_transpose_rule MODULES test_transpose_rule)
# End of unittests WITH single card WITHOUT timeout
endif()
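Once Paddle is built with WITH_DISTRIBUTE and WITH_GPU, the new test can be run on its own, for example with ctest -R test_transpose_rule from the build directory (a usage note assuming the standard CMake/ctest workflow, not part of this change).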
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestTransposeSPMDRule(unittest.TestCase):
"""
Unit tests for the transpose spmd rule.
"""
def setUp(self):
self.rule = get_spmd_rule("transpose")
x_shape = [64, 36]
process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [1, 0]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.attrs = {
'perm': [0, 1, 2, 3],
}
def test_single_mesh_dim(self):
# perm = [1, 0]
# [0, -1] --> [0, -1], [-1, 0]
self.attrs['perm'] = [1, 0]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
# perm = [0, 1]
# [0, -1] --> [0, -1], [0, -1]
self.attrs['perm'] = [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
# perm = [0, 2, 3, 1]
# [-1, -1, 0, -1] --> [-1, -1, 0, -1], [-1, 0, -1, -1]
self.x_dist_tensor_spec.shape = [64, 48, 36, 24]
self.attrs['perm'] = [0, 2, 3, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
)
def test_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [64, 48, 36, 24]
# perm = [0, 2, 3, 1]
# [-1, 0, 1, -1] --> [-1, 0, 1, -1], [-1, 1, -1, 0]
self.attrs['perm'] = [0, 2, 3, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
# perm = [0, 2, 3, 1]
# [-1, -1, -1, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1]
self.attrs['perm'] = [0, 2, 3, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
# perm = [-1, 0, -2, 1]
# [-1, -1, 0, 1] --> [-1, -1, 0, 1], [1, -1, 0, -1]
self.attrs['perm'] = [-1, 0, -2, 1]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)
if __name__ == "__main__":
unittest.main()
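As a quick sanity check on the last multi-mesh case above: negative entries in perm are normalized the same way the C++ rule does (by adding ndim), and the dims mapping is then simply permuted. A small standalone snippet, illustrative only and not part of the test:

ndim = 4
perm = [-1, 0, -2, 1]
perm = [p + ndim if p < 0 else p for p in perm]  # -> [3, 0, 2, 1]
src_dims_mapping = [-1, -1, 0, 1]
dst_dims_mapping = [src_dims_mapping[p] for p in perm]
print(dst_dims_mapping)  # [1, -1, 0, -1], matching the expected result above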