dist_attr.h 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {

// Forward declarations of the framework descriptor types used below. This
// header only names them through pointers, references and using-declarations,
// so their full definitions are not required here.
namespace framework {
class ProgramDesc;
class BlockDesc;
class OpDesc;
class VarDesc;
}  // namespace framework

namespace distributed {
namespace auto_parallel {

// Bring the framework descriptor types into this namespace so the
// declarations below can use them unqualified.
using framework::VarDesc;
using framework::OpDesc;
using framework::BlockDesc;
using framework::ProgramDesc;

// String constant "default"; its consumers live outside this header.
// As a namespace-scope constexpr (hence const) object it has internal
// linkage, so every translation unit including this header gets its own copy.
constexpr const char* kDefault = "default";
class TensorDistAttr {
 public:
  TensorDistAttr() = default;

  explicit TensorDistAttr(const VarDesc& tensor);

  TensorDistAttr(const TensorDistAttr& tensor);

  TensorDistAttr& operator=(const TensorDistAttr& dist_attr);

61 62
  void copy_from(const TensorDistAttr& dist_attr);

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
  const VarDesc* tensor() const { return tensor_; }

  const ProcessMesh& process_mesh() const { return process_mesh_; }

  void set_process_mesh(const ProcessMesh& process_mesh);

  const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }

  void set_dims_mapping(const std::vector<int64_t>& dims_mapping);

  int64_t batch_dim() const { return batch_dim_; }

  void set_batch_dim(int64_t batch_dim);

  const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }

  void set_dynamic_dims(const std::vector<bool>& dynamic_dims);

  const std::map<std::string, bool>& annotated() const { return annotated_; }

  void set_annotated(const std::map<std::string, bool>& annotated);

  void set_default_dims_mapping();

  bool is_annotated(const std::string& name) const {
    return annotated_.count(name) == 1;
  }

  void annotate(const std::string& name);

  bool verify_process_mesh(const ProcessMesh& process_mesh) const;

  bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping) const;

  bool verify_batch_dim(int64_t dim) const;

  bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims) const;

  bool verify_annotated(const std::map<std::string, bool>& annotated) const;

  bool verify() const;

  // TensorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

108
  void from_proto(const TensorDistAttrProto& proto);
109 110 111

  TensorDistAttrProto to_proto() const;

112 113 114 115
  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

116 117 118
 private:
  static std::vector<std::string> fields_;
  const VarDesc* tensor_{nullptr};
119
  std::vector<int64_t> tensor_shape_;
120 121
  ProcessMesh process_mesh_;
  std::vector<int64_t> dims_mapping_;
122
  int64_t batch_dim_{0};
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
  std::vector<bool> dynamic_dims_;
  std::map<std::string, bool> annotated_;
};

// Writes the textual form produced by TensorDistAttr::to_string() to `os`.
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
  return os << obj.to_string();
}

// Equality is defined out of line; inequality is simply its negation.
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);

inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}

class OperatorDistAttr {
 public:
  OperatorDistAttr() = default;

  explicit OperatorDistAttr(const OpDesc& op);

  OperatorDistAttr(const OperatorDistAttr& dist_attr);

  OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);

148 149 150 151
  void initialize();

  void copy_from(const OperatorDistAttr& dist_attr);

152 153 154 155 156 157 158 159 160 161 162 163 164 165
  const OpDesc* op() const { return op_; }

  const VarDesc& input(const std::string& name) const {
    return *inputs_.at(name);
  }

  const VarDesc& output(const std::string& name) const {
    return *outputs_.at(name);
  }

  const std::map<std::string, TensorDistAttr>& input_dist_attrs() const {
    return input_dist_attrs_;
  }

166 167 168
  void set_input_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

169 170 171 172
  const std::map<std::string, TensorDistAttr>& output_dist_attrs() const {
    return output_dist_attrs_;
  }

173 174 175
  void set_output_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
  const TensorDistAttr& input_dist_attr(const std::string& name) const {
    return input_dist_attrs_.at(name);
  }

  TensorDistAttr& input_dist_attr(const std::string& name) {
    return input_dist_attrs_.at(name);
  }

  void set_input_dist_attr(const std::string& name,
                           const TensorDistAttr& dist_attr);

  const TensorDistAttr& output_dist_attr(const std::string& name) const {
    return output_dist_attrs_.at(name);
  }

  TensorDistAttr& output_dist_attr(const std::string& name) {
    return output_dist_attrs_.at(name);
  }

  void set_output_dist_attr(const std::string& name,
                            const TensorDistAttr& dist_attr);

  const ProcessMesh& process_mesh() const { return process_mesh_; }

  void set_process_mesh(const ProcessMesh& process_mesh);

  const std::string& impl_type() const { return impl_type_; }

  void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; }

  int64_t impl_idx() const { return impl_idx_; }

  void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }

210 211 212 213 214 215
  const std::string& execution_stream() const { return execution_stream_; }

  void set_execution_stream(const std::string& execution_stream) {
    execution_stream_ = execution_stream;
  }

216 217 218 219 220 221 222 223 224 225
  const std::map<std::string, bool>& annotated() const { return annotated_; }

  void set_annotated(const std::map<std::string, bool>& annotated);

  bool is_annotated(const std::string& name) const {
    return annotated_.count(name) == 1;
  }

  void annotate(const std::string& name);

226 227 228 229 230 231 232 233 234 235
  const std::vector<int64_t>& input_dims_mapping(const std::string& name) const;

  void set_input_dims_mapping(const std::string& name,
                              const std::vector<int64_t>& dims_mapping);

  const std::vector<int64_t>& output_dims_mapping(const std::string& name);

  void set_output_dims_mapping(const std::string& name,
                               const std::vector<int64_t>& dims_mapping);

236 237 238 239 240 241 242 243 244 245 246 247
  bool verify_input_dist_attr(const std::string& name,
                              const TensorDistAttr& dist_attr) const;

  bool verify_output_dist_attr(const std::string& name,
                               const TensorDistAttr& dist_attr) const;

  bool verify_process_mesh(const ProcessMesh& process_mesh) const;

  bool verify_annotated(const std::map<std::string, bool>& annotated) const;

  bool verify() const;

248 249 250 251
  void rename_input(const std::string& old_name, const std::string& new_name);

  void rename_output(const std::string& old_name, const std::string& new_name);

252 253 254
  // OperatorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

255
  void from_proto(const OperatorDistAttrProto& proto);
256 257 258

  OperatorDistAttrProto to_proto() const;

259 260 261 262
  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

263 264 265 266 267 268 269 270 271 272
 private:
  static std::vector<std::string> fields_;
  const OpDesc* op_{nullptr};
  std::map<std::string, VarDesc*> inputs_;
  std::map<std::string, VarDesc*> outputs_;
  std::map<std::string, TensorDistAttr> input_dist_attrs_;
  std::map<std::string, TensorDistAttr> output_dist_attrs_;
  ProcessMesh process_mesh_;
  std::string impl_type_;
  int64_t impl_idx_ = -1;
273
  std::string execution_stream_;
274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
  std::map<std::string, bool> annotated_;
};

// Writes the textual form produced by OperatorDistAttr::to_string() to `os`.
inline std::ostream& operator<<(std::ostream& os, const OperatorDistAttr& obj) {
  return os << obj.to_string();
}

// Equality is defined out of line; inequality is simply its negation.
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs);

inline bool operator!=(const OperatorDistAttr& lhs,
                       const OperatorDistAttr& rhs) {
  const bool equal = (lhs == rhs);
  return !equal;
}

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle