#include <random>
#include <cmath>
#include "oneflow/core/auto_placement/df_func.h"
#include "oneflow/core/auto_placement/demo_chain_graph.h"

namespace oneflow {

namespace df {

namespace {

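// Demo cost model: each task node's compute time is just its placement
// probability (unit work per node).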
Tensor CalcTaskNodeComputeTime(const Tensor& chain_node_placement) {
  return chain_node_placement;
}

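// Per-device compute time: row sum of the probability matrix
// (rows index devices, columns index chain nodes).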
Tensor CalcDeviceComputeTime(const Tensor& prob_matrix) {
  return MatrixRowSum(prob_matrix);
}

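// Same placeholder as above: a task node's time equals its placement
// probability.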
Tensor CalcTaskNodeTime(const Tensor& chain_node_placement) {
  return chain_node_placement;
}

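// How long each register stays alive: the summed time of the chain nodes on
// the path that consumes it, broadcast to every device row.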
Tensor CalcRegstDuration(const Tensor& chain_node_placement,
                         const DemoChainGraph& chain_graph) {
  Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
  Tensor task_node_time = CalcTaskNodeTime(chain_node_placement);
  Tensor chain_node_time = MatrixColSum(task_node_time);
  const auto& regst2path = chain_graph.chain_regst_id2path_chain_node_ids();
  return ColIndexReduce(TensorProduct(row_ones, chain_node_time), regst2path);
}

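// Expected memory of each register per device: cloned registers weigh in at
// their full placement probability, split registers at their share of the
// per-register column sum.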
Tensor CalcRegstMemory(const Tensor& chain_node_placement,
                       const DemoChainGraph& chain_graph) {
  const auto& regst2producer =
      chain_graph.chain_regst_id2producer_chain_node_id();
  int64_t regst_num = regst2producer.size();
  Tensor regst_placement = ColIndexReduce(chain_node_placement, regst2producer);
  Tensor row_ones(Shape({regst_placement.shape().At(0)}), 1);
  auto copies = Clone(regst_placement, 3);
  Tensor col_sum = TensorProduct(row_ones, MatrixColSum(copies.at(0)));
  Tensor split_workload_ratio = ElemWiseDiv(copies.at(1), col_sum);
  Tensor clone_workload_ratio = copies.at(2);
  Tensor clone_weight = TensorProduct(
      row_ones,
      Tensor(Shape({regst_num}), chain_graph.chain_regst_id2is_cloned()));
  auto clone_weight_copies = Clone(clone_weight, 2);
  return ADD(ElemWiseMul(clone_workload_ratio, clone_weight_copies.at(0)),
             ElemWiseMul(split_workload_ratio,
                         Sub(Tensor(1), clone_weight_copies.at(1))));
}

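// Reciprocal of each register's initiation-interval scale, broadcast to every
// device row.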
Tensor CalcIIRatio(const Tensor& chain_node_placement,
                   const DemoChainGraph& chain_graph) {
  const auto& ii_ratios = chain_graph.chain_regst_id2ii_scale();
  int64_t regst_num = ii_ratios.size();
  Tensor ii_ratio_tensor(Shape({regst_num}), ii_ratios);
  Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
  return Reciprocal(TensorProduct(row_ones, ii_ratio_tensor));
}

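// Basic per-device memory: sum over registers of
// ii_ratio * regst_duration * regst_mem.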
Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement,
                                  Tensor regst_duration,
                                  const DemoChainGraph& chain_graph) {
  Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph);
  Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
  return MatrixRowSum(
      ElemWiseMul(ElemWiseMul(ii_ratio, regst_duration), regst_mem));
}

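// Memory spent on copied registers: an edge whose endpoints are placed on
// different devices (probability 0.5 * |src - dst|) needs a copy of the
// register it carries, weighted by that register's duration.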
Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
                                Tensor regst_duration,
                                const DemoChainGraph& chain_graph) {
  auto chain_node_prob_copies = Clone(chain_node_prob, 2);
  Tensor edge_src_prob = ColIndexReduce(
      chain_node_prob_copies.at(0), chain_graph.edge_id2src_chain_node_id());
  Tensor edge_dst_prob = ColIndexReduce(
      chain_node_prob_copies.at(1), chain_graph.edge_id2dst_chain_node_id());
  Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
  Tensor edge_regst_duration_prob =
      ColIndexReduce(regst_duration, chain_graph.edge_id2chain_regst_id());
  Tensor copied_task_regst_prob =
      ElemWiseMul(edge_prob, edge_regst_duration_prob);
  return MatrixRowSum(copied_task_regst_prob);
}

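// Convenience overload that derives the register durations itself.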
Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
                                const DemoChainGraph& chain_graph) {
  auto chain_node_prob_copies = Clone(chain_node_prob, 2);
  Tensor regst_duration =
      CalcRegstDuration(chain_node_prob_copies.at(0), chain_graph);
  return CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1), regst_duration,
                                  chain_graph);
}

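// Total expected memory per device: basic consumption plus copied registers.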
Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob,
                             const DemoChainGraph& chain_graph) {
  auto chain_node_prob_copies = Clone(chain_node_prob, 3);
  Tensor regst_duration =
      CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph);
  auto regst_duration_copies = Clone(regst_duration, 2);
  return ADD(
      CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0),
                                 regst_duration_copies.at(0), chain_graph),
      CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1),
                               regst_duration_copies.at(1), chain_graph));
}

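// Expected cross-device traffic per device: each edge contributes
// 0.5 * |src_prob - dst_prob|.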
Tensor CalcTransportation(const Tensor& chain_node_prob,
                          const DemoChainGraph& chain_graph) {
  auto chain_node_prob_copies = Clone(chain_node_prob, 2);
  Tensor edge_src_prob = ColIndexReduce(
      chain_node_prob_copies.at(0), chain_graph.edge_id2src_chain_node_id());
  Tensor edge_dst_prob = ColIndexReduce(
      chain_node_prob_copies.at(1), chain_graph.edge_id2dst_chain_node_id());
  Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
  return MatrixRowSum(edge_prob);
}

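// Memory-bound initiation interval per device: weighted memory-time divided
// by the memory still available, clipped away from zero so the division
// stays well-defined.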
Tensor CalcDeviceMemII(const Tensor& chain_node_placement,
                       const DemoChainGraph& chain_graph,
                       double mem_size_per_device) {
  auto placement_copies = Clone(chain_node_placement, 2);
  Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph);
  Tensor regst_duration =
      CalcRegstDuration(placement_copies.at(1), chain_graph);
  Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
  auto ii_ratio_copies = Clone(ii_ratio, 2);
  auto regst_mem_copies = Clone(regst_mem, 2);
  Tensor weighted_mem_time =
      ElemWiseMul(ElemWiseMul(ii_ratio_copies.at(0), regst_duration),
                  regst_mem_copies.at(0));
  Tensor weighted_mem_ceil_diff = ElemWiseMul(
      Sub(Tensor(1.5), ii_ratio_copies.at(1)), regst_mem_copies.at(1));
  Tensor device_mem_time = MatrixRowSum(weighted_mem_time);
  Tensor device_mem =
      Sub(Tensor(mem_size_per_device), MatrixRowSum(weighted_mem_ceil_diff));
  int64_t dev_num = chain_node_placement.shape().At(0);
  Tensor row_ones(Shape({dev_num}), 1);
  Tensor epsilon = Reshape(TensorProduct(row_ones, Tensor(1e-12)),
                           Shape({dev_num}));
  Tensor clipped_device_mem = Max(device_mem, epsilon);
  return ElemWiseDiv(device_mem_time, clipped_device_mem);
}

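// Map a free variable to a column-stochastic matrix: square to stay positive,
// add epsilon, then normalize each column to sum to 1.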
Tensor ProbabilityMatrix(Tensor* var, double lr) {
  Tensor row_ones(Shape({var->shape().At(0)}), 1);
  Tensor epsilon(1e-18);
  Tensor x = ADD(Square(FixedExpectation(Update(var, lr), 1)), epsilon);
  auto x_copies = Clone(x, 2);
  Tensor x_col_sum = TensorProduct(row_ones, MatrixColSum(x_copies.at(0)));
  return ElemWiseDiv(x_copies.at(1), x_col_sum);
}

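// Returns a decay schedule: 1.0 for the first `keep` calls, then
// 1 / ((n - keep) * ratio + 1) on call n.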
std::function<double()> MakeFlation(int keep, double ratio) {
  auto exec_cnt = std::make_shared<int>(-1);
  return [=]() {
    if (++(*exec_cnt) < keep) { return 1.0; }
    return 1.0 / (((*exec_cnt) - keep) * ratio + 1.0);
  };
}

std::function<double()> MakeFlation(int keep) {
  return MakeFlation(keep, 0.005);
}

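// Indecision measures: all three are zero when every column of the
// probability matrix is one-hot and positive while placement is still soft.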
Tensor SqrtIndecision(const Tensor& input) {
  return Sub(MatrixColSum(Sqrt(input)), Tensor(1));
}

Tensor EntropyIndecision(const Tensor& input) {
  const auto& input_copies = Clone(input, 2);
  return MatrixColSum(Minus(Mul(input_copies.at(0), Log(input_copies.at(1)))));
}

Tensor SquareIndecision(const Tensor& input) {
  const auto& input_copies = Clone(input, 2);
  return MatrixColSum(
      Mul(input_copies.at(0), Sub(Tensor(1), input_copies.at(1))));
}

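// Gradient-descent demo: softly place a 64-op chain (forward plus backward)
// on two devices by minimizing indecision plus the imbalance of memory and
// compute across devices.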
void AutoPlacementMemoryDemo() {
  std::random_device rd{};
  std::mt19937 gen{rd()};
  std::normal_distribution<double> distr(1, 0.01);
  DemoChainGraph chain_graph(4, [](DemoChainGraphBuilder* builder) {
    auto regst = builder->ModelOp("op0");
    FOR_RANGE(int, i, 1, 63) {
      regst = builder->ModelOp("op" + std::to_string(i), {regst});
    }
    builder->Backward(builder->ModelOp("loss", {regst}));
  });
  const auto& chain_node_id2fw_id =
      chain_graph.chain_node_id2fw_chain_node_id();
  int64_t fw_node_num = chain_graph.FwChainNodeNum();
  Shape shape({2, fw_node_num});
  Tensor fw_var(shape, [&](size_t index) { return distr(gen); });
  Tensor fw_prob;
  const auto& chain_node_id2name = chain_graph.chain_node_id2chain_node_name();
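  // Every int(bugo) steps the loop below optimizes copy cost alone; bugo
  // grows slowly, so those steps become rarer over time.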
  double bugo = 2;
  double rethink_threshold = 10;
  Tensor decision_ratio(Shape({fw_node_num}), [&](int64_t index) {
    return 1 + fw_node_num * 0.5 / (index + 1);
  });
  int64_t mem_keep = 100;
  std::function<double()> MemFlation = MakeFlation(mem_keep);
  FOR_RANGE(int, step, 0, 100000) {
    double lr = 0.01;
    if (step % (static_cast<int>(bugo += 0.05))) {
      fw_prob = ProbabilityMatrix(&fw_var, lr);
      auto fw_prob_copies = Clone(fw_prob, 2);
      Tensor chain_node_prob =
          ColIndexReduce(fw_prob_copies.at(0), chain_node_id2fw_id);
      auto chain_prob_copies = Clone(chain_node_prob, 2);
      Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
      Tensor dev_mem =
          CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph);
      Tensor normalized_dev_mem =
          Mul(Tensor(2.5 * MemFlation()), Sqrt(dev_mem));
      Tensor fw_indecision =
          Mul(SqrtIndecision(fw_prob_copies.at(1)), decision_ratio);
      Tensor indecision = Sum(fw_indecision);
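      // Objective: shrink indecision while balancing normalized memory and
      // per-device compute.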
      Tensor balance = ADD(indecision, ADD(AvgAbsDeviation(normalized_dev_mem),
                                           AvgAbsDeviation(computation_ii)));
      BackwardRun(balance);

      if (step % 10 == 0) {
        std::cout << "fw_prob: " << std::endl;
        FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
          FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
            double x = fw_prob.At(i, j);
            if (x < 0.01) { x = 0; }
            if (x > 0.99) { x = 1; }
            std::cout << x << "\t";
          }
          std::cout << std::endl;
        }
        std::cout << "indecision: " << indecision.At(0) << std::endl;
        std::cout << "computation_ii: ";
        for (double i : computation_ii.buffer().data()) {
          std::cout << i << " ";
        }
        std::cout << std::endl;
        std::cout << "normalized_dev_mem: ";
        for (double i : normalized_dev_mem.buffer().data()) {
          std::cout << i << " ";
        }
        std::cout << std::endl;

        std::vector<int64_t> fw_id2dev_id(fw_prob.shape().At(1));
        FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
          double max_val = 0;
          int max_index = 0;
          FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
            if (max_val < fw_prob.At(i, j)) {
              max_val = fw_prob.At(i, j);
              max_index = i;
            }
          }
          fw_id2dev_id.at(j) = max_index;
        }
        std::vector<std::list<int64_t>> dev_id2fw_ids(fw_prob.shape().At(0));
        FOR_RANGE(int, fw_id, 0, fw_id2dev_id.size()) {
          dev_id2fw_ids.at(fw_id2dev_id.at(fw_id)).push_back(fw_id);
        }

        FOR_RANGE(int, dev_id, 0, dev_id2fw_ids.size()) {
          std::cout << "device " << dev_id << ": ";
          for (int64_t fw_id : dev_id2fw_ids.at(dev_id)) {
            std::cout << chain_node_id2name.at(fw_id) << " ";
          }
          std::cout << std::endl;
        }
        std::cout << std::endl;
      }
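      // Once the placement is nearly decided, "rethink": restart the memory
      // flation schedule, tighten the threshold, and average the variables of
      // any edge whose endpoints still disagree.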
      if (indecision.At(0) < rethink_threshold) {
        MemFlation = MakeFlation(mem_keep);
        rethink_threshold -= (rethink_threshold > 2) ? 1 : 0.01;
        const auto& edge_id2src_id = chain_graph.edge_id2src_chain_node_id();
        const auto& edge_id2dst_id = chain_graph.edge_id2dst_chain_node_id();
        auto old_fw_var = fw_var.buffer();
        FOR_RANGE(int, conv_iter, 0, 1) {
          chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id);
          Tensor edge_src_prob =
              ColIndexReduce(chain_node_prob, edge_id2src_id);
          Tensor edge_dst_prob =
              ColIndexReduce(chain_node_prob, edge_id2dst_id);
          Tensor edge_prob =
              Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
          FOR_RANGE(int, i, 0, edge_prob.shape().At(0)) {
            FOR_RANGE(int, j, 0, edge_prob.shape().At(1)) {
              if (edge_prob.At(i, j) > 0.2) {
                int64_t src_fw_id =
                    chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0);
                int64_t dst_fw_id =
                    chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0);
                double avg =
                    (old_fw_var.At(i, src_fw_id) + old_fw_var.At(i, dst_fw_id))
                    / 2;
                fw_var.At(i, src_fw_id) = avg;
                fw_var.At(i, dst_fw_id) = avg;
              }
            }
          }
        }
        bugo = 20;
      }
    } else {
      FOR_RANGE(int, i, 0, 3) {
        fw_prob = ProbabilityMatrix(&fw_var, lr);
        Tensor chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id);
        Tensor copied_mem =
            Sum(CalcDeviceCopiedRegstMem(chain_node_prob, chain_graph));
        BackwardRun(copied_mem);
      }
    }
  }
}

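// Simpler demo: balance a 4x5 load table so that the maximum per-device load
// (the computation initiation interval) is minimized.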
void AutoPlacementComputationDemo() {
  Tensor var(Shape({4, 5}), [](size_t index) { return index % 2 ? 0 : 1; });
  Tensor row_ones(Shape({var.shape().At(0)}), 1);
  Tensor col_ones(Shape({var.shape().At(1)}), 1);
  Tensor epsilon(1e-9);
  FOR_RANGE(int, i, 0, 10000) {
    double lr = 0.001;

    Tensor x = ADD(Square(FixedExpectation(Update(&var, lr), 1)), epsilon);
    const auto& x_copies = Clone(x, 4);
    Tensor row = MatrixRowSum(x_copies.at(0));
    Tensor col = MatrixColSum(x_copies.at(1));
    Tensor load = ElemWiseDiv(x_copies.at(2), TensorProduct(row_ones, col));
    Tensor table = ElemWiseMul(TensorProduct(row, col_ones), load);
    Tensor ii = MaxElem(table);
    BackwardRun(ADD(ii, Variance(MatrixColMax(x_copies.at(3)))));

    std::cout << "x: ";
    for (double i : x.buffer().data()) { std::cout << i << " "; }
    std::cout << std::endl;
    std::cout << "row: ";
    for (double i : row.buffer().data()) { std::cout << i << " "; }
    std::cout << std::endl;
    std::cout << "col: ";
    for (double i : col.buffer().data()) { std::cout << i << " "; }
    std::cout << std::endl;
    std::cout << "table: ";
    for (double i : table.buffer().data()) { std::cout << i << " "; }
    std::cout << std::endl << std::endl;
  }
}

void DifferentialDemo() {
  // AutoPlacementComputationDemo();
  AutoPlacementMemoryDemo();
}

}  // namespace

}  // namespace df

}  // namespace oneflow

int main(int argc, char** argv) {
  oneflow::df::DifferentialDemo();
  return 0;
}