未验证 提交 a97b507e 编写于 作者: Y Yichen Zhang 提交者: GitHub

[Semi-Auto] Add reshape spmd rule (#55177)

* add reshape spmd rule

* add unit test for reshape spmd rule

* bug fix

* replace the print_info function with to_string

* fix typo

* bug fix

* add handling for "0" in target shape

* remove the part of computing size in dim_trans.cc
上级 e075a0dd
...@@ -125,14 +125,14 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, ...@@ -125,14 +125,14 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr); TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr);
// Check whether the given DistTensorSpec objects are valid. For each // Check whether the given DistTensorSpec objects are valid. For each
// DistTensorSpec, the rank of its dimsmapping must be equal to the rank of its // DistTensorSpec, the rank of its dims mapping must be equal to the rank of its
// corresponding tensor shape. the parameter op_name is used for logging error // corresponding tensor shape. the parameter op_name is used for logging error
// message. // message.
void VerifySpecs(const std::vector<DistTensorSpec>& specs, void VerifySpecs(const std::vector<DistTensorSpec>& specs,
const std::string& op_name); const std::string& op_name);
// Get dimsmapping for the given tensors. Return the pair of each // Get dims mapping for the given tensors. Return the pair of each
// tensor's einsum notation and the corresponding dimsmapping. // tensor's einsum notation and the corresponding dims mapping.
std::vector<std::pair<std::string, std::vector<int64_t>>> std::vector<std::pair<std::string, std::vector<int64_t>>>
GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes, GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
const std::vector<DistTensorSpec>& specs); const std::vector<DistTensorSpec>& specs);
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
#include <assert.h>
#include <cstdio>
#include <numeric>
#include <set>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
static std::vector<DimTrans*> all_dim_trans;
DimTrans::DimTrans(Type type) : type_(type) {}
DimTrans::~DimTrans() {}
DimTrans::Type DimTrans::type() const { return type_; }
void DimTrans::set_type(Type type) { type_ = type; }
std::string DimTrans::to_string() { return std::string(""); }
InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
input_dim_ = -1;
all_dim_trans.emplace_back(this);
}
InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
input_dim_ = dim;
all_dim_trans.emplace_back(this);
}
InputDim::~InputDim() {}
int64_t InputDim::input_dim() const { return input_dim_; }
void InputDim::set_input_dim(int64_t dim) { input_dim_ = dim; }
std::string InputDim::to_string() {
return ("InputDim(" + std::to_string(input_dim_) + ")");
}
Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
all_dim_trans.emplace_back(this);
}
std::string Singleton::to_string() { return "Singleton()"; }
Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
all_dim_trans.emplace_back(this);
}
Flatten::Flatten(const std::vector<DimTrans*>& dims)
: DimTrans(DimTrans::Type::FLATTEN) {
input_dims_ = dims;
all_dim_trans.emplace_back(this);
}
Flatten::~Flatten() {
input_dims_.assign(input_dims_.size(), nullptr);
std::vector<DimTrans*>().swap(input_dims_);
}
const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }
void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
input_dims_.assign(dims.begin(), dims.end());
}
std::string Flatten::to_string() {
std::string ret_str("Flatten(");
for (int64_t i = 0, n = input_dims_.size(); i < n; ++i) {
ret_str += input_dims_[i]->to_string();
if (i < n - 1) {
ret_str += ",";
}
}
return ret_str + ")";
}
Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
input_dim_trans_ = nullptr;
all_dim_trans.emplace_back(this);
}
Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
: DimTrans(DimTrans::Type::SPLIT) {
input_dim_trans_ = dim;
split_id_ = id;
splitted_shape_.assign(shape.begin(), shape.end());
all_dim_trans.emplace_back(this);
}
Split::~Split() {
input_dim_trans_ = nullptr;
std::vector<int64_t>().swap(splitted_shape_);
}
DimTrans* Split::input() const { return input_dim_trans_; }
void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }
int64_t Split::split_id() const { return split_id_; }
int64_t Split::local_splitted_shape_value() {
return splitted_shape_[split_id_];
}
std::string Split::to_string() {
std::string ret_str("Split(");
ret_str += input_dim_trans_->to_string() + ", (";
for (int64_t i = 0, n = splitted_shape_.size(); i < n; ++i) {
ret_str += std::to_string(splitted_shape_[i]);
if (i < n - 1) {
ret_str += ",";
}
}
return ret_str + "), " + std::to_string(split_id_) + ")";
}
DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
DimTrans* ptr = nullptr;
if (dims.size() == 0) {
ptr = new Singleton();
} else if (dims.size() == 1) {
ptr = dims[0];
} else {
ptr = new Flatten(dims);
}
return ptr;
}
DimTrans* make_split(DimTrans* dim,
const std::vector<int64_t>& shape,
int64_t id) {
assert(shape.size() > 0);
DimTrans* ptr = nullptr;
if (shape.size() == 1) {
assert(id == 0);
ptr = dim;
} else if (shape[id] == 1) {
ptr = new Singleton();
} else {
// new shape that remove 1
std::vector<int64_t> new_shape;
// map between from idx in shape to new_shape
std::vector<int64_t> idx_map(shape.size(), -1);
for (int64_t i = 0, n = shape.size(); i < n; ++i) {
if (shape[id] != 1) {
idx_map[i] = new_shape.size();
new_shape.emplace_back(shape[i]);
}
}
ptr = new Split(dim, new_shape, idx_map[id]);
}
return ptr;
}
void CleanUp() {
for (int64_t i = 0, n = all_dim_trans.size(); i < n; i++) {
if (all_dim_trans[i]) {
delete all_dim_trans[i];
all_dim_trans[i] = nullptr;
}
}
std::vector<DimTrans*>().swap(all_dim_trans);
}
// Given a `dim_trans` of an output axis, get the input axis
// whose dim mapping should be propogated to it.
// If the returned input axis is none, the output axis's
// dim mapping should be set to -1 (replicated). For an axis
// that is flattened from input axes, return the leftmost
// flattened input axis. For the split transformation,
// only the leftmost split axis in output will return its input.
DimTrans* GetDimTrans(DimTrans* dim_trans,
std::vector<std::vector<bool>>* shardable,
std::set<int64_t>* seen_dims,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& mesh_shape,
const std::vector<int64_t>& input_dims_mapping,
const std::set<int64_t>& sharded_input_dims) {
DimTrans::Type type = dim_trans->type();
DimTrans* ret_dim_trans = nullptr;
if (type == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
int64_t dim = inputdim->input_dim();
seen_dims->insert(dim);
if (sharded_input_dims.count(dim) > 0) {
ret_dim_trans = dim_trans;
}
} else if (type == DimTrans::Type::FLATTEN) {
Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
const std::vector<DimTrans*>& inputs = flatten->inputs();
int64_t nmesh = (*shardable)[0].size();
for (int64_t i = 1, n = inputs.size(); i < n; i++) {
DimTrans* input = inputs[i];
if (input->type() == DimTrans::Type::INPUTDIM) {
(*shardable)[i].assign(nmesh, false);
}
GetDimTrans(input,
shardable,
seen_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
}
DimTrans* dim0 = inputs[0];
if (dim0->type() == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
ret_dim_trans = dim0;
}
}
} else if (type == DimTrans::Type::SPLIT) {
Split* split = dynamic_cast<Split*>(dim_trans);
DimTrans* dim = GetDimTrans(split->input(),
shardable,
seen_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
int64_t ret_size = split->local_splitted_shape_value();
if (split->split_id() == 0) {
if (dim != nullptr) {
PADDLE_ENFORCE_EQ(dim->type(),
DimTrans::Type::INPUTDIM,
phi::errors::InvalidArgument(
"The returned dim_trans must be INPUTDIM."));
InputDim* inputdim = dynamic_cast<InputDim*>(dim);
int64_t nmesh = mesh_shape.size();
int64_t input_axis = inputdim->input_dim();
// Check whether the sharded dim can be sharded on
// each mesh dimension. The dimension should be
// divisible by the mesh size that it is sharded on
for (int64_t imesh = 0; imesh < nmesh; imesh++) {
(*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
}
}
ret_dim_trans = dim;
}
} else if (type == DimTrans::Type::SINGLETON) {
ret_dim_trans = nullptr;
}
return ret_dim_trans;
}
void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
InputDim* input = dynamic_cast<InputDim*>(dim_trans);
seen_dims->insert(input->input_dim());
} else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
for (DimTrans* trans : flatten->inputs()) {
GetUsedInputDim(trans, seen_dims);
}
} else if (dim_trans->type() == DimTrans::Type::SPLIT) {
Split* split = dynamic_cast<Split*>(dim_trans);
GetUsedInputDim(split->input(), seen_dims);
} else {
return;
}
}
std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
const std::vector<int64_t>& input_shape = input_spec.shape();
const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();
std::set<int64_t> sharded_input_dims;
for (int64_t i = 0, n = input_dims_mapping.size(); i < n; ++i) {
if (input_dims_mapping[i] > -1) {
sharded_input_dims.insert(i);
}
}
int64_t ndim = input_shape.size();
int64_t nmesh = mesh_shape.size();
std::vector<std::vector<bool>> shardable(ndim,
std::vector<bool>(nmesh, true));
std::set<int64_t> seen_input_dims;
for (DimTrans* trans : dim_trans) {
GetUsedInputDim(trans, &seen_input_dims);
}
for (int64_t idim = 0; idim < ndim; idim++) {
bool seen = seen_input_dims.count(idim);
if (!seen) {
shardable[idim].assign(nmesh, seen);
}
}
// get the map from sharded input dimensions to output dimensions.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
for (int64_t i = 0, n = dim_trans.size(); i < n; i++) {
DimTrans* dim = GetDimTrans(dim_trans[i],
&shardable,
&seen_input_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim);
dim_map_src2tgt[inputdim->input_dim()] = i;
}
}
std::vector<int64_t> out_dims_mapping(dim_trans.size(), -1);
std::vector<int64_t> new_input_dims_mapping(input_dims_mapping);
// set output dims mapping with corresponding input dimensions.
// if one input dimension is sharded on a unshardable mesh after
// splitting, we need to make it replicated.
for (int64_t i = 0; i < ndim; i++) {
int64_t mesh_dim = input_dims_mapping[i];
if (mesh_dim > -1 && shardable[i][mesh_dim] && dim_map_src2tgt[i] > -1) {
out_dims_mapping[dim_map_src2tgt[i]] = input_dims_mapping[i];
} else {
new_input_dims_mapping[i] = -1;
}
}
return {new_input_dims_mapping, out_dims_mapping};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
// This is a base class to describe how each dimension in output tensor
// is transformed from input tensor's axes. The transformation includes
// Flatten, Split, etc. A vector<DimTrans*> whose size equals to the
// output tensor's rank can be used to describe how the output shape is
// transformed from the input shape. Each element in vector<DimTrans*>
// describes the transformation of one output axis. For example, when
// a reshape operator reshapes a tensor from the shape of (6, 12, 48)
// to (72, 6, 8), this transfromation can be described as:
// [Flatten(Dim(0), Dim(1)), Split(Dim(2), (6,8), 0), Split(Dim(2), (6,8), 1)]
// meaning that dim0 in output is flattened from dim0 and dim1 in input,
// dim1 and dim2 in output are obtained by splitting dim2 in input, the
// splitted shape is (6, 8), dim1 referes to the first shape value in (6, 8)
// and dim2 referes to the second shape value in (6, 8).
class DimTrans {
public:
enum class Type { INPUTDIM, SINGLETON, FLATTEN, SPLIT };
DimTrans() = default;
explicit DimTrans(Type type);
virtual ~DimTrans();
Type type() const;
void set_type(Type type);
virtual std::string to_string();
private:
Type type_;
};
// InputDim indicates that the output dimention
// is obtained directed from one input dimension.
class InputDim : public DimTrans {
public:
InputDim();
explicit InputDim(int64_t dim);
virtual ~InputDim();
int64_t input_dim() const;
void set_input_dim(int64_t dim);
std::string to_string() override;
private:
int64_t input_dim_;
};
// Singleton indicates that the shape of the
// corresponding output dimension is 1
class Singleton : public DimTrans {
public:
Singleton();
std::string to_string() override;
};
// Flatten indicates that the output dimension
// is obtained from flattening input dimensions.
class Flatten : public DimTrans {
public:
Flatten();
explicit Flatten(const std::vector<DimTrans*>& dims);
virtual ~Flatten();
const std::vector<DimTrans*>& inputs() const;
void set_inputs(const std::vector<DimTrans*>& dims);
std::string to_string() override;
private:
std::vector<DimTrans*> input_dims_;
};
// Split indicates that the output dimension
// is obtained by splitting input dimension.
class Split : public DimTrans {
public:
Split();
Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id);
virtual ~Split();
DimTrans* input() const;
void set_input(DimTrans* dim);
int64_t split_id() const;
// get the splitted shape value of the split_id_ dimension
int64_t local_splitted_shape_value();
std::string to_string() override;
private:
DimTrans* input_dim_trans_;
std::vector<int64_t> splitted_shape_;
int64_t split_id_;
};
void CleanUp();
DimTrans* make_flatten(const std::vector<DimTrans*>& dims = {});
DimTrans* make_split(DimTrans* dim,
const std::vector<int64_t>& shape = {},
int64_t id = 0);
// Infer the dims mapping of the output tensor according to the transformation
// `dim_trans`. Returns the dims mapping of the input tensor (the input dims
// mapping may be changed for resharding) and output tensor. The inferring
// follows the rules:
// 1. For Singleton, i.e., the shape of this output axis is 1, its dim mapping
// is -1, indicating that the output axis is replicated.
// 2. For InputDim, i.e., the output axis is transformed directly from an input
// axis, set its dim mapping equals to the corresponding input axis.
// 3. For Flatten, i.e., the output axis is flattened from some input axes, it
// can be sharded only if the leftmost flattened axes is sharded.
// 4. For Split, i.e., the output axes is splited from a input axis, only the
// leftmost output split axis can be sharded when its shape can be divisible
// by the mesh dimension.
std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include <numeric>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
// The target shape in reshape may contains a -1 dimension,
// this function is used to infer what the "-1" dimension is.
std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
int64_t len) {
int64_t infer_idx = -1;
for (int64_t i = 0, n = shape.size(); i < n; i++) {
if (shape[i] == -1) {
PADDLE_ENFORCE_EQ(
infer_idx,
-1,
phi::errors::InvalidArgument(
"There can't be more than one -1 dimension in target shape."));
infer_idx = i;
}
}
int64_t product = std::accumulate(
shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
if (product > 0) {
PADDLE_ENFORCE_EQ(
product,
len,
phi::errors::InvalidArgument("The total size are not matched"));
return std::vector<int64_t>(shape);
} else {
std::vector<int64_t> new_shape(shape);
product = -product;
int64_t infer_size = len / product;
PADDLE_ENFORCE_EQ(len % infer_size,
0,
phi::errors::InvalidArgument(
"The total is not diviable by infer_size"));
new_shape[infer_idx] = infer_size;
return new_shape;
}
}
// Compute how each dimension in target shape
// is obtained from the input dimensions
std::vector<DimTrans*> MakeReshapeDimTrans(
const std::vector<int64_t>& src_shape,
const std::vector<int64_t>& tgt_shape) {
std::vector<DimTrans*> ret;
int64_t total_elem_num_src = std::accumulate(
src_shape.begin(), src_shape.end(), 1, std::multiplies<int64_t>());
std::vector<int64_t> inferred_tgt_shape =
InferTargetShape(tgt_shape, total_elem_num_src);
int64_t src_idx = 0, tgt_idx = 0;
int64_t s, t;
int64_t src_len, tgt_len;
src_len = src_shape.size();
tgt_len = inferred_tgt_shape.size();
while (src_idx < src_len || tgt_idx < tgt_len) {
std::vector<int64_t> src_dims, tgt_splitted_shape;
if (src_idx >= src_len) {
s = 1;
} else {
s = src_shape[src_idx];
src_dims.emplace_back(src_idx);
src_idx++;
}
if (tgt_idx >= tgt_len) {
t = 1;
} else {
t = tgt_shape[tgt_idx];
tgt_splitted_shape.emplace_back(t);
tgt_idx++;
}
// deal with the singleton case
if (s == 1 && t != 1) {
// case [1] [a]
tgt_idx--;
tgt_splitted_shape.clear();
} else if (s != 1 && t == 1) {
src_idx--;
src_dims.clear();
} else {
while (s != t) {
if (s < t) {
src_dims.emplace_back(src_idx);
s *= src_shape[src_idx];
src_idx++;
} else {
tgt_splitted_shape.emplace_back(inferred_tgt_shape[tgt_idx]);
t *= inferred_tgt_shape[tgt_idx];
tgt_idx++;
}
}
}
if (tgt_splitted_shape.size() > 0) {
std::vector<DimTrans*> input_dims;
for (int64_t i = 0, n = src_dims.size(); i < n; i++) {
int64_t in_dim = src_dims[i];
if (src_shape[in_dim] > 1) {
input_dims.emplace_back(new InputDim(in_dim));
}
}
DimTrans* flatten = make_flatten(input_dims);
for (int64_t i = 0, n = tgt_splitted_shape.size(); i < n; i++) {
ret.emplace_back(make_split(flatten, tgt_splitted_shape, i));
}
}
}
return ret;
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reshape must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "reshape");
// step1: build the transformation from
// original shape to target shape
std::vector<int64_t> src_shape = input_specs[0].shape();
std::vector<int64_t> tgt_shape =
ExtractAttr<std::vector<int64_t>>("shape", attrs);
// handle the '0' values in target shape, '0' indicates
// that the target shape is equal to the source shape
for (int64_t i = 0, n = tgt_shape.size(); i < n; i++) {
if (tgt_shape[i] == 0) {
tgt_shape[i] = src_shape[i];
}
}
std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);
// step2: infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(input_specs[0], trans);
// step3: update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr new_input_dist_attr(input_specs[0].dist_attr());
new_input_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
<< "] output_shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tOutput axis " << i << ": " << t->to_string();
}
VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[0])
<< "] output_dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";
CleanUp();
return {{new_input_dist_attr}, {output_dist_attr}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of ReductionSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
class ReshapeSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
...@@ -159,6 +160,9 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule); ...@@ -159,6 +160,9 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
// transpose rule // transpose rule
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule); REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);
// reshape rule
REGISTER_SPMD_RULE(reshape, ReshapeSPMDRule);
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -10,6 +10,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -10,6 +10,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_matmul_rule MODULES test_softmax_rule) py_test_modules(test_matmul_rule MODULES test_softmax_rule)
py_test_modules(test_split_rule MODULES test_split_rule) py_test_modules(test_split_rule MODULES test_split_rule)
py_test_modules(test_transpose_rule MODULES test_transpose_rule) py_test_modules(test_transpose_rule MODULES test_transpose_rule)
py_test_modules(test_reshape_rule MODULES test_reshape_rule)
# End of unittests WITH single card WITHOUT timeout # End of unittests WITH single card WITHOUT timeout
endif() endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestReshapeSPMDRule(unittest.TestCase):
def setUp(self):
self.rule = get_spmd_rule("reshape")
x_shape = [6, 12, 48, 24]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [-1, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.attrs = {"shape": [1, 72, 48, 4, 6]}
def test_reshape_infer_forward(self):
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6]
# dims_mapping: [0, -1, 1, -1] --> [0, -1, 1, -1] [-1, 0, 1, -1, -1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6]
# dims_mapping: [-1, 0, -1, 1] --> [-1, -1, -1, -1] [-1, -1, -1, -1, -1]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6]
# dims_mapping: [1, -1, -1, 0] --> [1, -1, -1, 0] [-1, 1, -1, 0, -1]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0, -1]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24]
# dims_mapping: [0, 1, -1, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1, -1]
self.attrs["shape"] = [3, 24, 6, 8, 24]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24]
# dims_mapping: [1, -1, -1, 0] --> [1, -1, -1, 0] [1, -1, -1, -1, 0]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1, 0]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, 24]
# dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, 1], [-1, -1, 0, -1, 1]
self.attrs["shape"] = [3, 24, 6, -1, 24]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, -1, 1]
)
# shape: [6, 12, 48, 24] --> [1, 72, 0, 4, 6]
# dims_mapping: [1, -1, -1, 0] --> [1, -1, -1, 0] [-1, 1, -1, 0, -1]
self.attrs["shape"] = [1, 72, 0, 4, 6]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0, -1]
)
# shape: [6, 12, 48, 24] --> [6, 12, 48, 24]
# dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, 1], [-1, -1, 0, 1]
self.attrs["shape"] = [6, 12, 48, 24]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
# shape: [6, 12, 48, 24] --> [72, 3, 16, 24]
# dims_mapping: [0, -1, 1, -1] --> [0, -1, 1, -1], [0, 1, -1, -1]
self.attrs["shape"] = [72, 3, 16, 24]
self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [72, 3, 16, 24]
# dims_mapping: [1, -1, 0, -1] --> [1, -1, -1, -1], [1, -1, -1, -1]
self.attrs["shape"] = [72, 3, 16, 24]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
# raise error
self.attrs["shape"] = [3, 24, 6, -1, -1]
with self.assertRaises(BaseException):
self.rule.infer_forward([self.x_dist_tensor_spec], self.attrs)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册