提交 701d3fa0 编写于 作者: L liangjianzhong

broadcast func

上级 ed0c31e6
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
......@@ -42,7 +44,7 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
int64_t merge_dim;
for (auto& pair : tensor_axes_to_dim_pairs) {
for (int i = 0; i < pair.second.size(); i++) {
for (size_t i = 0; i < pair.second.size(); ++i) {
auto tensor_axis = pair.first.substr(i, 1);
auto mesh_dim = pair.second[i];
......@@ -71,7 +73,7 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (int i = 1; i < it.second.size(); i++) {
for (size_t i = 1; i < it.second.size(); ++i) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
}
}
......@@ -113,7 +115,7 @@ TensorDistAttr CopyTensorDistAttrForOutput(
new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
new_dist_attr.set_annotated(false);
// new_dist_attr.set_annotated(false); TODO unset field is false by default.
return new_dist_attr;
}
......@@ -122,8 +124,8 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
const std::string& tensor_axes) {
std::vector<int64_t> partial_on_dims;
for (auto& it : in_axis_to_dim_map) {
if (out_axis.find(it.first) != std::string::npos) {
for (auto& it : axis_to_dim_map) {
if (tensor_axes.find(it.first) == std::string::npos) {
if (it.second > -1) {
partial_on_dims.push_back(it.second);
}
......@@ -132,6 +134,20 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
return partial_on_dims;
}
std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
const int64_t& broadcast_ndim,
const std::string& alphabet) {
PADDLE_ENFORCE_GE(
alphabet.size(),
broadcast_ndim,
phi::errors::InvalidArgument(
"size of alphabet [%d] is less than broadcast ndim [%d]",
alphabet.size(),
broadcast_ndim));
return alphabet.substr(0, broadcast_ndim)
.substr(broadcast_ndim - tenosr_ndim, tenosr_ndim);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
......@@ -19,14 +19,18 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using paddle::framework::Attribute;
class SPMDRuleBase {
public:
virtual ~SPMDRuleBase() {}
......@@ -64,10 +68,10 @@ class SPMDRuleBase {
const Attribute& GetAttr(const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto iter = attrs.find(name);
PADDLE_ENFORCE_NE(
iter,
PADDLE_ENFORCE_NE(iter,
attrs.end(),
platform::errors::NotFound("(%s) is not found in AttributeMap."));
paddle::platform::errors::NotFound(
"(%s) is not found in AttributeMap."));
return iter->second;
}
};
......@@ -96,6 +100,15 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes);
// Generate the axis notation of tensor for the einsum notation of a broadcast
// operation(alignment star from the rightmost axis). tenosr_ndim: the size of
// the tensor. broadcast_ndim: the maxium size of tensors in this broadcast
// operation. alphabet: the characters used to represent the axes of tensor.
// length of alphabet should >= broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
const int64_t& broadcast_ndim,
const std::string& alphabet);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册