// dist_attr.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {

// Forward declarations of the framework descriptor types used in this header.
// They appear here only by reference or pointer, so their full definitions
// are not required.
namespace framework {
class ProgramDesc;
class BlockDesc;
class VarDesc;
class OpDesc;
}  // namespace framework

namespace distributed {
namespace auto_parallel {

using framework::BlockDesc;
using framework::OpDesc;
using framework::ProgramDesc;
using framework::VarDesc;

// Literal "default"; used as the initial value of OperatorDistAttr's
// impl_type_ member.
constexpr const char* kDefault = "default";

class TensorDistAttr {
 public:
  TensorDistAttr() = default;

  explicit TensorDistAttr(const VarDesc& tensor);

  TensorDistAttr(const TensorDistAttr& tensor);

  TensorDistAttr& operator=(const TensorDistAttr& dist_attr);

61 62
  void copy_from(const TensorDistAttr& dist_attr);

63 64 65 66 67 68 69 70
  const ProcessMesh& process_mesh() const { return process_mesh_; }

  void set_process_mesh(const ProcessMesh& process_mesh);

  const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }

  void set_dims_mapping(const std::vector<int64_t>& dims_mapping);

71 72
  void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);

73 74 75 76 77 78 79 80
  int64_t batch_dim() const { return batch_dim_; }

  void set_batch_dim(int64_t batch_dim);

  const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }

  void set_dynamic_dims(const std::vector<bool>& dynamic_dims);

81 82
  void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);

83 84 85 86 87
  const std::map<std::string, bool>& annotated() const { return annotated_; }

  void set_annotated(const std::map<std::string, bool>& annotated);

  bool is_annotated(const std::string& name) const {
88
    return annotated_.count(name) == 1 && annotated_.at(name) == true;
89 90
  }

91 92 93
  void mark_annotated(const std::string& name);

  void clear_annotated() { annotated_.clear(); }
94 95 96

  bool verify_process_mesh(const ProcessMesh& process_mesh) const;

97 98
  bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
                           const std::vector<int64_t>& tensor_shape) const;
99

100 101
  bool verify_batch_dim(int64_t dim,
                        const std::vector<int64_t>& tensor_shape) const;
102

103 104
  bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
                           const std::vector<int64_t>& tensor_shape) const;
105 106 107

  bool verify_annotated(const std::map<std::string, bool>& annotated) const;

108
  bool verify(const VarDesc* tensor = nullptr) const;
109 110 111 112

  // TensorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

113
  void from_proto(const TensorDistAttrProto& proto);
114 115 116

  TensorDistAttrProto to_proto() const;

117 118 119 120
  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

121 122 123 124
 private:
  static std::vector<std::string> fields_;
  ProcessMesh process_mesh_;
  std::vector<int64_t> dims_mapping_;
125
  int64_t batch_dim_{0};
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
  std::vector<bool> dynamic_dims_;
  std::map<std::string, bool> annotated_;
};

// Equality of two tensor distributed attributes; defined out of line.
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);

// Inequality, expressed as the negation of operator==.
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
  return !(lhs == rhs);
}

// Streams the textual form produced by TensorDistAttr::to_string().
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
  return os << obj.to_string();
}

class OperatorDistAttr {
 public:
  OperatorDistAttr() = default;

  explicit OperatorDistAttr(const OpDesc& op);

  OperatorDistAttr(const OperatorDistAttr& dist_attr);

  OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);

151
  void initialize(const OpDesc* op = nullptr);
152 153 154

  void copy_from(const OperatorDistAttr& dist_attr);

155 156
  const std::map<std::string, TensorDistAttr>& input_dist_attrs() const {
    return input_dist_attrs_;
157 158
  }

159
  std::map<std::string, TensorDistAttr>& input_dist_attrs() {
160 161 162
    return input_dist_attrs_;
  }

163 164 165
  void set_input_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

166 167 168 169
  const std::map<std::string, TensorDistAttr>& output_dist_attrs() const {
    return output_dist_attrs_;
  }

170 171 172 173
  std::map<std::string, TensorDistAttr>& output_dist_attrs() {
    return output_dist_attrs_;
  }

174 175 176
  void set_output_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
  const TensorDistAttr& input_dist_attr(const std::string& name) const {
    return input_dist_attrs_.at(name);
  }

  TensorDistAttr& input_dist_attr(const std::string& name) {
    return input_dist_attrs_.at(name);
  }

  void set_input_dist_attr(const std::string& name,
                           const TensorDistAttr& dist_attr);

  const TensorDistAttr& output_dist_attr(const std::string& name) const {
    return output_dist_attrs_.at(name);
  }

  TensorDistAttr& output_dist_attr(const std::string& name) {
    return output_dist_attrs_.at(name);
  }

  void set_output_dist_attr(const std::string& name,
                            const TensorDistAttr& dist_attr);

  const ProcessMesh& process_mesh() const { return process_mesh_; }

  void set_process_mesh(const ProcessMesh& process_mesh);

203 204 205 206
  const std::string& op_type() const { return op_type_; }

  void set_op_type(const std::string& op_type) { op_type_ = op_type; }

207 208 209 210 211 212 213 214
  const std::string& impl_type() const { return impl_type_; }

  void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; }

  int64_t impl_idx() const { return impl_idx_; }

  void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }

215 216 217 218
  bool is_recompute() const { return is_recompute_; }

  void set_is_recompute(bool is_recompute) { is_recompute_ = is_recompute; }

219 220 221 222 223 224
  const std::string& execution_stream() const { return execution_stream_; }

  void set_execution_stream(const std::string& execution_stream) {
    execution_stream_ = execution_stream;
  }

225 226 227 228 229 230
  int64_t scheduling_priority() const { return scheduling_priority_; }

  void set_scheduling_priority(int64_t scheduling_priority) {
    scheduling_priority_ = scheduling_priority;
  }

231 232 233 234 235
  const std::map<std::string, bool>& annotated() const { return annotated_; }

  void set_annotated(const std::map<std::string, bool>& annotated);

  bool is_annotated(const std::string& name) const {
236
    return annotated_.count(name) == 1 && annotated_.at(name) == true;
237 238
  }

239 240 241
  void mark_annotated(const std::string& name);

  void clear_annotated();
242

243 244 245 246 247 248 249 250 251 252
  const std::vector<int64_t>& input_dims_mapping(const std::string& name) const;

  void set_input_dims_mapping(const std::string& name,
                              const std::vector<int64_t>& dims_mapping);

  const std::vector<int64_t>& output_dims_mapping(const std::string& name);

  void set_output_dims_mapping(const std::string& name,
                               const std::vector<int64_t>& dims_mapping);

253
  bool verify_input_dist_attr(const std::string& name,
254 255
                              const TensorDistAttr& dist_attr,
                              const VarDesc* tensor) const;
256 257

  bool verify_output_dist_attr(const std::string& name,
258 259
                               const TensorDistAttr& dist_attr,
                               const VarDesc* tensor) const;
260 261 262 263 264

  bool verify_process_mesh(const ProcessMesh& process_mesh) const;

  bool verify_annotated(const std::map<std::string, bool>& annotated) const;

265
  bool verify(const OpDesc* op = nullptr) const;
266

267 268 269 270
  void rename_input(const std::string& old_name, const std::string& new_name);

  void rename_output(const std::string& old_name, const std::string& new_name);

271 272 273
  // OperatorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

274
  void from_proto(const OperatorDistAttrProto& proto);
275 276 277

  OperatorDistAttrProto to_proto() const;

278 279 280 281
  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

282 283 284 285 286
 private:
  static std::vector<std::string> fields_;
  std::map<std::string, TensorDistAttr> input_dist_attrs_;
  std::map<std::string, TensorDistAttr> output_dist_attrs_;
  ProcessMesh process_mesh_;
287 288 289 290
  std::string op_type_;
  std::string impl_type_ = kDefault;
  int64_t impl_idx_ = 0;
  bool is_recompute_ = false;
291
  std::string execution_stream_;
292
  int64_t scheduling_priority_;  // lower value, higher priority, default to 0
293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
  std::map<std::string, bool> annotated_;
};

// Streams the textual form produced by OperatorDistAttr::to_string().
inline std::ostream& operator<<(std::ostream& os, const OperatorDistAttr& obj) {
  return os << obj.to_string();
}

// Equality of two operator distributed attributes; defined out of line.
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs);

// Inequality, expressed as the negation of operator==.
inline bool operator!=(const OperatorDistAttr& lhs,
                       const OperatorDistAttr& rhs) {
  return !(lhs == rhs);
}

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle