Commit 58666348 authored by Xinqi Li

It's weird, but it works.


Former-commit-id: a3d86fee0a7c39a062ef4b385e67e08db3d568e3
Parent: 449b90b3
@@ -319,6 +319,15 @@ std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2DstChainNodeId()
   return ret;
 }
+std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2RegstId() const {
+  std::vector<std::vector<int64_t>> ret(edge_num());
+  int index = -1;
+  ForEachEdge([&](DemoChainEdge* edge) {
+    ret.at(++index) = std::vector<int64_t>{edge->chain_regst_id()};
+  });
+  return ret;
+}
 std::vector<double> DemoChainGraph::RegstId2IsCloned() const {
   std::vector<double> ret(regsts_.size());
   for (const auto& regst : regsts_) {
......
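For readers outside this codebase: tables like the one `CalcEdgeId2RegstId` returns (edge id mapped to a singleton list of regst ids) are consumed by `ColIndexReduce` in the `.cpp` hunks below. `ColIndexReduce`'s implementation is not part of this diff; the following is a minimal stand-alone sketch, assuming it builds an output whose column `j` combines the input columns listed in `table[j]`. `Matrix` and `ColIndexReduceSketch` are hypothetical stand-ins, not the repo's `Tensor` API.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the Tensor type: a row-major matrix.
using Matrix = std::vector<std::vector<double>>;

// Assumed semantics of ColIndexReduce: output column j is the sum of the
// input columns named by table[j]. With singleton lists, as produced by
// CalcEdgeId2RegstId, this reduces to a pure column gather.
Matrix ColIndexReduceSketch(const Matrix& in,
                            const std::vector<std::vector<int64_t>>& table) {
  Matrix out(in.size(), std::vector<double>(table.size(), 0.0));
  for (size_t row = 0; row < in.size(); ++row) {
    for (size_t j = 0; j < table.size(); ++j) {
      for (int64_t src_col : table[j]) {
        out[row][j] += in[row][static_cast<size_t>(src_col)];
      }
    }
  }
  return out;
}
```

Under that assumption, `ColIndexReduce(regst_duration, chain_graph.CalcEdgeId2RegstId())` in the new `CalcDeviceCopiedRegstMem` simply re-keys a per-regst duration matrix to per-edge columns, so it can be multiplied elementwise with per-edge probabilities.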
@@ -119,7 +119,7 @@ class DemoChainEdge final : public Edge<DemoChainNode, DemoChainEdge> {
   explicit DemoChainEdge(const DemoChainRegst* regst) : regst_(regst) {}
   ~DemoChainEdge() = default;
   const DemoChainRegst& regst() const { return *regst_; }
+  int64_t chain_regst_id() const { return regst_->chain_regst_id(); }
   int64_t src_chain_node_id() const { return src_node()->chain_node_id(); }
   int64_t dst_chain_node_id() const { return dst_node()->chain_node_id(); }
@@ -154,6 +154,7 @@ class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
   std::vector<std::vector<int64_t>> CalcEdgeId2SrcChainNodeId() const;
   std::vector<std::vector<int64_t>> CalcEdgeId2DstChainNodeId() const;
+  std::vector<std::vector<int64_t>> CalcEdgeId2RegstId() const;
   std::vector<std::string> CalcChainNodeId2ChainNodeName() const;
......
@@ -68,12 +68,10 @@ Tensor CalcIIRatio(const Tensor& chain_node_placement,
 }
 Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement,
+                                  Tensor regst_duration,
                                   const DemoChainGraph& chain_graph,
                                   int piece_num_in_batch) {
-  auto placement_copies = Clone(chain_node_placement, 2);
-  Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph);
-  Tensor regst_duration =
-      CalcRegstDuration(placement_copies.at(1), chain_graph);
+  Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph);
   Tensor ii_ratio =
       CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch);
   return MatrixRowSum(
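The hunk above drops the `Clone(chain_node_placement, 2)` pair because `regst_duration` is now computed once by the caller and passed in. The `Clone(x, n)` idiom used throughout this file is worth spelling out: every consumer of a tensor gets its own copy, apparently so the reverse pass can attribute one gradient contribution per use. `Clone`'s real implementation is not shown in this diff; below is a minimal sketch of the convention, assuming clones share one gradient accumulator. `Value`, `CloneSketch`, and `AccumulateGrad` are hypothetical names, not the repo's API.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical model of the Clone(x, n) convention: each consumer gets its
// own handle, but all handles point at one shared gradient buffer, so the
// backward pass can sum the contributions of every use of the value.
struct Value {
  std::vector<double> data;
  std::shared_ptr<std::vector<double>> grad;  // shared by all clones
};

std::vector<Value> CloneSketch(const Value& v, size_t n) {
  return std::vector<Value>(n, v);  // same data, same grad pointer
}

void AccumulateGrad(Value* v, const std::vector<double>& g) {
  for (size_t i = 0; i < g.size(); ++i) { (*v->grad)[i] += g[i]; }
}
```

Read this way, the later change from `Clone(chain_node_prob, 2)` to `Clone(chain_node_prob, 3)` in `CalcDeviceMemConsumed` just accounts for the third consumer introduced by hoisting the `CalcRegstDuration` call.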
@@ -81,30 +79,43 @@ Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement,
 }
 Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
+                                Tensor regst_duration,
                                 const DemoChainGraph& chain_graph) {
   auto chain_node_prob_copies = Clone(chain_node_prob, 2);
   Tensor edge_src_prob = ColIndexReduce(
       chain_node_prob_copies.at(0), chain_graph.CalcEdgeId2SrcChainNodeId());
   Tensor edge_dst_prob = ColIndexReduce(
       chain_node_prob_copies.at(1), chain_graph.CalcEdgeId2DstChainNodeId());
-  auto edge_dst_prob_copies = Clone(edge_dst_prob, 2);
-  Tensor edge_prob = Abs(Sub(edge_src_prob, edge_dst_prob_copies.at(0)));
-  Tensor copied_chain_regst_prob = Mul(Tensor(0.5), MatrixColSum(edge_prob));
-  Tensor row_ones(Shape({chain_node_prob.shape().At(0)}), 1);
+  Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
+  Tensor edge_regst_duration_prob =
+      ColIndexReduce(regst_duration, chain_graph.CalcEdgeId2RegstId());
   Tensor copied_task_regst_prob =
-      ElemWiseMul(TensorProduct(row_ones, copied_chain_regst_prob),
-                  edge_dst_prob_copies.at(1));
+      ElemWiseMul(edge_prob, edge_regst_duration_prob);
   return MatrixRowSum(copied_task_regst_prob);
 }
+Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
+                                const DemoChainGraph& chain_graph) {
+  auto chain_node_prob_copies = Clone(chain_node_prob, 2);
+  Tensor regst_duration =
+      CalcRegstDuration(chain_node_prob_copies.at(0), chain_graph);
+  return CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1), regst_duration,
+                                  chain_graph);
+}
 Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob,
                              const DemoChainGraph& chain_graph,
                              int piece_num_in_batch) {
-  auto chain_node_prob_copies = Clone(chain_node_prob, 2);
+  auto chain_node_prob_copies = Clone(chain_node_prob, 3);
+  Tensor regst_duration =
+      CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph);
+  auto regst_duration_copies = Clone(regst_duration, 2);
   return ADD(
-      CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0), chain_graph,
+      CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0),
+                                 regst_duration_copies.at(0), chain_graph,
                                  piece_num_in_batch),
-      CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1), chain_graph));
+      CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1),
+                               regst_duration_copies.at(1), chain_graph));
 }
 Tensor CalcDeviceMemII(const Tensor& chain_node_placement,
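The rewritten `CalcDeviceCopiedRegstMem` charges copied-register memory per device instead of aggregating per chain. The key identity: for a near-0/1 placement, an edge whose producer and consumer share a device has `edge_src_prob` equal to `edge_dst_prob`, so `Mul(Tensor(0.5), Abs(Sub(...)))` is 0 everywhere; when they sit on different devices, the two one-hot columns differ in exactly two device entries, so each of the two devices involved in the copy is charged 0.5, i.e. one copied regst in total, which is then weighted by the regst's duration and row-summed into per-device memory. A small self-contained numeric check of that identity (plain C++, not the repo's `Tensor` API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Copied-regst share charged to one device for a single edge, mirroring
// Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob))) elementwise.
// src/dst are the one-hot device columns of producer and consumer.
double CopiedShare(const std::vector<double>& src,
                   const std::vector<double>& dst, int device) {
  return 0.5 * std::fabs(src[device] - dst[device]);
}

int main() {
  std::vector<double> src = {1.0, 0.0};    // producer on device 0
  std::vector<double> same = {1.0, 0.0};   // consumer colocated
  std::vector<double> other = {0.0, 1.0};  // consumer on device 1
  // Colocated: no copy charged to either device -> 0.0 0.0
  std::printf("%.1f %.1f\n", CopiedShare(src, same, 0), CopiedShare(src, same, 1));
  // Cross-device: half a regst charged to each device, one copy in total -> 0.5 0.5
  std::printf("%.1f %.1f\n", CopiedShare(src, other, 0), CopiedShare(src, other, 1));
  return 0;
}
```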
@@ -149,7 +160,7 @@ void AutoPlacementMemoryDemo() {
   std::normal_distribution<double> distr(1, 0.1);
   DemoChainGraph chain_graph([](DemoChainGraphBuilder* builder) {
     auto regst = builder->ModelOp("op0");
-    FOR_RANGE(int, i, 1, 19) {
+    FOR_RANGE(int, i, 1, 23) {
       regst = builder->ModelOp("op" + std::to_string(i), {regst});
     }
     builder->Backward(builder->ModelOp("loss", {regst}));
@@ -158,59 +169,57 @@ void AutoPlacementMemoryDemo() {
   int64_t fw_node_num = chain_graph.FwChainNodeNum();
   // std::cout << fw_node_num << std::endl;
   // return;
-  Shape shape({5, fw_node_num});
+  Shape shape({2, fw_node_num});
   Tensor fw_var(shape, [&](size_t index) { return distr(gen); });
   Tensor floor_tensor(shape, 0.000000001);
   Tensor fw_prob;
   auto chain_node_id2name = chain_graph.CalcChainNodeId2ChainNodeName();
-  FOR_RANGE(int, i, 0, 3000) {
+  double bugo = 2;
+  FOR_RANGE(int, step, 0, 5000) {
     double lr = 0.01;
     fw_prob = ProbabilityMatrix(&fw_var, lr);
     Tensor chain_node_prob = ColIndexReduce(fw_prob, chain_node2fw_id);
-    auto chain_prob_copies = Clone(chain_node_prob, 3);
-    Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
-    auto compo_ii_copies = Clone(computation_ii, 2);
-    Tensor dev_mem =
-        CalcDeviceMemConsumed(chain_prob_copies.at(2), chain_graph, 4);
-    Tensor ii = MaxElem(compo_ii_copies.at(1));
-    // Tensor copied_mem =
-    //     Sum(CalcDeviceCopiedRegstMem(chain_prob_copies.at(3),
-    //     chain_graph));
-    Tensor penalty = Mul(ADD(Sum(Sqrt(chain_prob_copies.at(1))),
-                             ADD(Mul(Variance(dev_mem), Tensor(1)),
-                                 Variance(compo_ii_copies.at(0)))),
-                         Tensor(1));
-    BackwardRun(ADD(ii, penalty));
-    // std::cout << "copied_mem: " << copied_mem.At(0) << std::endl;
-    std::cout << "fw_prob: " << std::endl;
-    FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
-      FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
-        double x = fw_prob.At(i, j);
-        if (x < 0.01) { x = 0; }
-        if (x > 0.99) { x = 1; }
-        std::cout << std::setprecision(3) << x << "\t";
-      }
-      std::cout << std::endl;
-    }
-    std::cout << "computation_ii: ";
-    for (double i : computation_ii.buffer().data()) { std::cout << i << " "; }
-    std::cout << std::endl;
-    std::cout << "dev_mem: ";
-    for (double i : dev_mem.buffer().data()) { std::cout << i << " "; }
-    std::cout << std::endl;
-    FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
-      std::cout << "device " << i << ": ";
-      FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
-        if (fw_prob.At(i, j) >= 0.5) {
-          std::cout << chain_node_id2name.at(j) << " ";
-        }
-      }
-      std::cout << std::endl;
-    }
-    std::cout << std::endl;
+    if (step % (static_cast<int>(bugo += 0.01))) {
+      auto chain_prob_copies = Clone(chain_node_prob, 3);
+      Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
+      auto compo_ii_copies = Clone(computation_ii, 2);
+      Tensor dev_mem =
+          CalcDeviceMemConsumed(chain_prob_copies.at(2), chain_graph, 4);
+      Tensor ii = MaxElem(compo_ii_copies.at(1));
+      Tensor penalty = ADD(Sum(Sqrt(chain_prob_copies.at(1))),
+                           ADD(AvgAbsDeviation(dev_mem),
+                               AvgAbsDeviation(compo_ii_copies.at(0))));
+      BackwardRun(ADD(ii, penalty));
+      std::cout << "fw_prob: " << std::endl;
+      FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
+        FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
+          double x = fw_prob.At(i, j);
+          if (x < 0.01) { x = 0; }
+          if (x > 0.99) { x = 1; }
+          std::cout << std::setprecision(3) << x << "\t";
+        }
+        std::cout << std::endl;
+      }
+      std::cout << "computation_ii: ";
+      for (double i : computation_ii.buffer().data()) { std::cout << i << " "; }
+      std::cout << std::endl;
+      std::cout << "dev_mem: ";
+      for (double i : dev_mem.buffer().data()) { std::cout << i << " "; }
+      std::cout << std::endl;
+      FOR_RANGE(int, i, 0, fw_prob.shape().At(0)) {
+        std::cout << "device " << i << ": ";
+        FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
+          if (fw_prob.At(i, j) >= 0.5) {
+            std::cout << chain_node_id2name.at(j) << " ";
+          }
+        }
+        std::cout << std::endl;
+      }
+      std::cout << std::endl;
+    } else {
+      BackwardRun(Sum(CalcDeviceCopiedRegstMem(chain_node_prob, chain_graph)));
    }
  }
 }
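Besides swapping the `Variance`-based balance penalty for `AvgAbsDeviation`, the loop restructuring alternates between two objectives: on most steps it minimizes the initiation interval `ii` plus the sparsity and balance penalties, and whenever `step` is a multiple of the slowly growing period `int(bugo)` it takes a pure copied-register-memory step instead. A minimal sketch of the schedule in isolation, with hypothetical stubs standing in for the two descent steps (the real code calls `BackwardRun` on the corresponding tensors):

```cpp
#include <cstdio>

// Hypothetical stubs for the two objectives optimized in the demo loop.
void StepMainObjective(int step) { std::printf("step %d: ii + penalty\n", step); }
void StepCopiedMemObjective(int step) { std::printf("step %d: copied regst mem\n", step); }

int main() {
  double bugo = 2;  // the period grows by 0.01 per step, so copied-mem
                    // steps become rarer as training progresses
  for (int step = 0; step < 50; ++step) {
    if (step % static_cast<int>(bugo += 0.01)) {
      StepMainObjective(step);       // remainder nonzero: usual objective
    } else {
      StepCopiedMemObjective(step);  // step divides the period: memory step
    }
  }
  return 0;
}
```

Because `bugo` grows by 0.01 per step, the copied-memory steps start out every other step (period 2) and thin out as the placement hardens.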
......