未验证 提交 4d1b9f04 编写于 作者: L Leo Chen 提交者: GitHub

[Semi-Auto] LayerNorm Parallel Rule (#55130)

* add layernorm spmd rule

* add ut

* follow comments
上级 2ff949da
file(GLOB SPMD_SRCS "*.cc")
cc_library( cc_library(
spmd_rule spmd_rule
SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc replicated_spmd_rule.cc SRCS ${SPMD_SRCS}
DEPS phi) DEPS phi)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
LayerNormSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: verify input args based on layer_norm logic
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
3,
phi::errors::InvalidArgument(
"The size of InputSpec of layer_norm should be 3, but got [%d].",
input_specs_size));
auto x_shape = input_specs[0].shape();
auto scale_shape = input_specs[1].shape();
auto bias_shape = input_specs[2].shape();
int x_ndim = x_shape.size();
int scale_ndim = scale_shape.size();
int bias_ndim = bias_shape.size();
PADDLE_ENFORCE_EQ(
scale_ndim,
1,
phi::errors::InvalidArgument(
"The ndim of scale in layer_norm should be 1, but got [%d].",
scale_ndim));
PADDLE_ENFORCE_EQ(
bias_ndim,
1,
phi::errors::InvalidArgument(
"The ndim of bias in layer_norm should be 1, but got [%d].",
bias_ndim));
auto x_dims_mapping = input_specs[0].dist_attr().dims_mapping();
auto scale_dims_mapping = input_specs[1].dist_attr().dims_mapping();
auto bias_dims_mapping = input_specs[2].dist_attr().dims_mapping();
auto x_dist_attr_src = input_specs[0].dist_attr();
std::vector<TensorDistAttr> input_dist_attrs;
input_dist_attrs.reserve(input_specs.size());
int begin_norm_axis = ExtractAttr<int>("begin_norm_axis", attrs);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "LayerNormSPMDRule InferForward Inputs: "
<< "x shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
<< str_join(x_dims_mapping) << "]; scale shape: ["
<< str_join(scale_shape) << "], scale_dims_mapping: ["
<< str_join(scale_dims_mapping) << "]; bias shape: ["
<< str_join(bias_shape) << "], bias_dims_mapping: ["
<< str_join(bias_dims_mapping) << "]; begin_norm_axis: ["
<< begin_norm_axis << "]; ";
// step1: build Einsum Notation
// ijk,k,k->ijk,x,x (x,scale,bias->out,mean,variance, begin_norm_axis=2, x=ij)
// ijkl,y(kl),y(kl)->ijkl,x(ij),x(ij) (x,scale,bias->out,mean,variance,
// begin_norm_axis=2, x=ij, y=kl)
std::string x_axes = "";
for (auto i = 0; i < x_ndim; ++i) {
x_axes += static_cast<char>(static_cast<int>('k') - begin_norm_axis + i);
}
std::string scale_axes;
std::string bias_axes;
if (x_ndim - begin_norm_axis == 1) {
scale_axes = "k";
bias_axes = "k";
} else {
// z = x_axes.substr(begin_norm_axis, x_ndim - begin_norm_axis)
scale_axes = "y";
bias_axes = "y";
}
std::string mean_axes;
std::string variance_axes;
if (begin_norm_axis > 1) {
mean_axes = "x";
variance_axes = "x";
} else {
mean_axes = "j";
variance_axes = "j";
}
std::string out_axes = x_axes;
VLOG(4) << "LayerNormSPMDRule build Einsum notation (x,scale,bias->out): ["
<< x_axes << "," << scale_axes << "," << bias_axes << " --> "
<< out_axes << "," << mean_axes << "," << variance_axes
<< "](begin_norm_axis:" << begin_norm_axis
<< ",x=" << x_axes.substr(0, begin_norm_axis)
<< ",y=" << x_axes.substr(begin_norm_axis, x_ndim - begin_norm_axis)
<< ").";
// step2: Sharding Propogation
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr mean_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr varience_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
int64_t mean_shard_dim = -1;
for (size_t i = 0; i < out_axes.size(); ++i) {
if (i < static_cast<size_t>(begin_norm_axis)) {
out_dims_mapping.push_back(x_dims_mapping[i]);
// if ijk,k,k->ijk,x,x (x,scale,bias->out,mean,variance,
// begin_norm_axis=2, x=ij), and the dims_mapping of input is (0,1,-1),
// the mean and varience is sharded by dim 0 and 1,
// which is not supported currently.
mean_shard_dim =
ShardingMergeForAxis(mean_axes, mean_shard_dim, x_dims_mapping[i]);
} else {
out_dims_mapping.push_back(-1);
}
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
mean_dist_attr_dst.set_dims_mapping({mean_shard_dim});
varience_dist_attr_dst.set_dims_mapping({mean_shard_dim});
// step2.3: Merge and get Inputs' New Dims Mapping.
x_dist_attr_dst.set_dims_mapping(out_dims_mapping);
input_dist_attrs.emplace_back(x_dist_attr_dst);
// TODO(zhiqiu): support shardding on scale and bias
// Now, apply replicating.
input_dist_attrs.emplace_back(ReplicatedOnMesh(input_specs[1].dist_attr()));
input_dist_attrs.emplace_back(ReplicatedOnMesh(input_specs[2].dist_attr()));
// Step2.4. handle input and out tensor partial
// LayerNorm not support
VLOG(4) << "LayerNormSPMDRule InferForward: "
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]; scale shape: ["
<< str_join(scale_shape) << "], src_dims_mapping: ["
<< str_join(scale_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(input_dist_attrs[1].dims_mapping()) << "]; bias shape: ["
<< str_join(bias_shape) << "], src_dims_mapping: ["
<< str_join(bias_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(input_dist_attrs[2].dims_mapping())
<< "]; out dims_mapping: [" << str_join(out_dims_mapping)
<< "]; mean dims_mapping: [" << mean_shard_dim
<< "]; varience dims_mapping: [" << mean_shard_dim
<< "], partial_on_dims: []";
return {input_dist_attrs,
{output_dist_attr_dst, mean_dist_attr_dst, varience_dist_attr_dst}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
LayerNormSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of LayerNormSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
class LayerNormSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
...@@ -26,6 +27,9 @@ namespace auto_parallel { ...@@ -26,6 +27,9 @@ namespace auto_parallel {
// matmul rule // matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule); REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
// matmul rule
REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule);
// replicated rule // replicated rule
REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule); REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);
......
...@@ -38,13 +38,11 @@ TEST(MatmulSPMDRule, Ctor) { ...@@ -38,13 +38,11 @@ TEST(MatmulSPMDRule, Ctor) {
TensorDistAttr x_dist_attr = TensorDistAttr(); TensorDistAttr x_dist_attr = TensorDistAttr();
x_dist_attr.set_process_mesh(process_mesh); x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(std::vector<int64_t>({1, -1})); x_dist_attr.set_dims_mapping(std::vector<int64_t>({1, -1}));
x_dist_attr.set_batch_dim(-1);
x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false})); x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
TensorDistAttr y_dist_attr = TensorDistAttr(); TensorDistAttr y_dist_attr = TensorDistAttr();
y_dist_attr.set_process_mesh(process_mesh); y_dist_attr.set_process_mesh(process_mesh);
y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, -1})); y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, -1}));
y_dist_attr.set_batch_dim(-1);
y_dist_attr.set_dynamic_dims(std::vector<bool>({false, false})); y_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr);
...@@ -201,6 +199,101 @@ TEST(MatmulSPMDRule, Ctor) { ...@@ -201,6 +199,101 @@ TEST(MatmulSPMDRule, Ctor) {
VLOG(4) << "test10 done." << std::endl << std::endl << std::endl; VLOG(4) << "test10 done." << std::endl << std::endl << std::endl;
} }
TEST(LayerNormSPMDRule, Ctor) {
// build input data class
std::vector<int64_t> x_shape = {64, 32, 1024};
std::vector<int64_t> scale_shape = {1024};
std::vector<int64_t> bias_shape = {1024};
std::vector<int64_t> mesh_shape = {2, 3};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
TensorDistAttr x_dist_attr = TensorDistAttr();
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(std::vector<int64_t>({1, -1, -1}));
x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
TensorDistAttr scale_dist_attr = TensorDistAttr();
scale_dist_attr.set_process_mesh(process_mesh);
scale_dist_attr.set_dims_mapping(std::vector<int64_t>({-1}));
scale_dist_attr.set_dynamic_dims(std::vector<bool>({false}));
TensorDistAttr bias_dist_attr = TensorDistAttr();
bias_dist_attr.set_process_mesh(process_mesh);
bias_dist_attr.set_dims_mapping(std::vector<int64_t>({-1}));
bias_dist_attr.set_dynamic_dims(std::vector<bool>({false}));
DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr);
DistTensorSpec scale_dist_tensor_spec =
DistTensorSpec(scale_shape, scale_dist_attr);
DistTensorSpec bias_dist_tensor_spec =
DistTensorSpec(bias_shape, bias_dist_attr);
paddle::framework::AttributeMap attrs;
attrs["begin_norm_axis"] = 2;
SPMDRuleBase* layer_norm_rule = SPMDRuleMap::Instance().Get("layer_norm");
// ijk[1, -1, -1],k[-1],k[-1] --> ijk[1, -1, -1] partial[1]
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infered_dist_attrs = layer_norm_rule->InferForward(
{x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
attrs);
size_t input_size = 3;
size_t output_size = 3;
EXPECT_EQ(infered_dist_attrs.first.size(), input_size);
EXPECT_EQ(infered_dist_attrs.second.size(), output_size);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1}));
EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(),
std::vector<int64_t>({-1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(),
std::vector<int64_t>({1}));
EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(),
std::vector<int64_t>({1}));
VLOG(4) << "test1 done.";
// ijk[1, 0, -1],k[0],k[0] --> ijk[1, 0, -1]
x_dist_tensor_spec.set_dims_mapping({1, 0, -1});
scale_dist_tensor_spec.set_dims_mapping({0});
bias_dist_tensor_spec.set_dims_mapping({0});
EXPECT_ANY_THROW(
infered_dist_attrs = layer_norm_rule->InferForward(
{x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
attrs););
VLOG(4) << "test2 done.";
// ijk[0, -1, -1],z[-1],z[1] --> ijk[0, 1, -1, -1], z=jk
x_dist_tensor_spec.set_dims_mapping({0, -1, -1});
scale_dist_tensor_spec.set_dims_mapping({-1});
bias_dist_tensor_spec.set_dims_mapping({1});
attrs["begin_norm_axis"] = 1;
infered_dist_attrs = layer_norm_rule->InferForward(
{x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({0, -1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1}));
EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(),
std::vector<int64_t>({-1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({0, -1, -1}));
EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(),
std::vector<int64_t>({0}));
EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(),
std::vector<int64_t>({0}));
VLOG(4) << "test2 done.";
}
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册