diff --git a/oneflow/core/auto_placement/demo_chain_graph.cpp b/oneflow/core/auto_placement/demo_chain_graph.cpp index 7bde91368649b1be1a6de78e425d8cf4556e018a..9b08d46cb57a4113b20274273e87b064bf8db875 100644 --- a/oneflow/core/auto_placement/demo_chain_graph.cpp +++ b/oneflow/core/auto_placement/demo_chain_graph.cpp @@ -109,12 +109,24 @@ size_t DemoChainGraph::ChainNodeNum() const { } DemoChainGraph::DemoChainGraph( + int piece_num_in_batch, const std::function& Build) - : chain_node_id_(-1), chain_regst_id_(-1) { + : chain_node_id_(-1), + chain_regst_id_(-1), + piece_num_in_batch_(piece_num_in_batch) { DemoChainGraphBuilder builder(this); Build(&builder); InitIsReachable(); InitRegst2ChainNodeSubGraphs(); + InitChainNodeId2FwChainNodeId(); + InitChainRegstId2ProducerChainNodeId(); + InitChainRegstId2PathChainNodeIds(); + InitEdgeId2SrcChainNodeId(); + InitEdgeId2DstChainNodeId(); + InitEdgeId2ChainRegstId(); + InitChainNodeId2ChainNodeName(); + InitChainRegstId2IsCloned(); + InitChainRegstId2IIScale(); } IsReachablePredicator DemoChainGraph::MakeIsReachablePredicator() const { @@ -180,33 +192,31 @@ void DemoChainGraph::InitRegst2ChainNodeSubGraphs() { } } -std::vector> -DemoChainGraph::CalcChainNodeId2FwChainNodeId() const { +void DemoChainGraph::InitChainNodeId2FwChainNodeId() { std::vector> ret(ChainNodeNum()); ForEachNode([&](const DemoChainNode* node) { CHECK_LT(node->chain_node_id(), ret.size()); ret.at(node->chain_node_id()).push_back(node->fw_chain_node_id()); }); - return ret; + chain_node_id2fw_chain_node_id_ = ret; } -std::vector> -DemoChainGraph::CalcChainRegstId2ProducerChainNodeId() const { +void DemoChainGraph::InitChainRegstId2ProducerChainNodeId() { std::vector> ret(regsts_.size()); for (const auto& regst : regsts_) { int64_t chain_node_id = regst->producer()->chain_node_id(); CHECK_LT(regst->chain_regst_id(), ret.size()); ret.at(regst->chain_regst_id()).push_back(chain_node_id); } - return ret; + chain_regst_id2producer_chain_node_id_ = ret; } -std::vector DemoChainGraph::CalcChainNodeId2ChainNodeName() const { +void DemoChainGraph::InitChainNodeId2ChainNodeName() { std::vector ret(node_num()); ForEachNode([&](const DemoChainNode* node) { ret.at(node->chain_node_id()) = node->name(); }); - return ret; + chain_node_id2chain_node_name_ = ret; } void DemoChainNodeSubGraph::TopoForEachChainNode( @@ -299,54 +309,52 @@ DemoChainGraph::CalcChainRegstId2PathChainNodeIds( return ret; } -std::vector> DemoChainGraph::CalcEdgeId2SrcChainNodeId() - const { +void DemoChainGraph::InitEdgeId2SrcChainNodeId() { std::vector> ret(edge_num()); int index = -1; ForEachEdge([&](DemoChainEdge* edge) { ret.at(++index) = std::vector{edge->src_chain_node_id()}; }); - return ret; + edge_id2src_chain_node_id_ = ret; } -std::vector> DemoChainGraph::CalcEdgeId2DstChainNodeId() - const { +void DemoChainGraph::InitEdgeId2DstChainNodeId() { std::vector> ret(edge_num()); int index = -1; ForEachEdge([&](DemoChainEdge* edge) { ret.at(++index) = std::vector{edge->dst_chain_node_id()}; }); - return ret; + edge_id2dst_chain_node_id_ = ret; } -std::vector> DemoChainGraph::CalcEdgeId2RegstId() const { +void DemoChainGraph::InitEdgeId2ChainRegstId() { std::vector> ret(edge_num()); int index = -1; ForEachEdge([&](DemoChainEdge* edge) { ret.at(++index) = std::vector{edge->chain_regst_id()}; }); - return ret; + edge_id2chain_regst_id_ = ret; } -std::vector DemoChainGraph::RegstId2IsCloned() const { +void DemoChainGraph::InitChainRegstId2IsCloned() { std::vector ret(regsts_.size()); for (const auto& regst : regsts_) { ret.at(regst->chain_regst_id()) = (regst->IsRegstCloned() ? 1 : 0); } - return ret; + chain_regst_id2is_cloned_ = ret; } -std::vector DemoChainGraph::RegstIIRatio(int piece_num_in_batch) const { +void DemoChainGraph::InitChainRegstId2IIScale() { std::vector ret(regsts_.size()); for (const auto& regst : regsts_) { - double ii_ratio = 1; + double ii_scale = 1; if (regst->producer()->task_type() == TaskType::kMdDiffAcc || regst->producer()->task_type() == TaskType::kMdUpdt) { - ii_ratio = piece_num_in_batch; + ii_scale = piece_num_in_batch_; } - ret.at(regst->chain_regst_id()) = ii_ratio; + ret.at(regst->chain_regst_id()) = ii_scale; } - return ret; + chain_regst_id2ii_scale_ = ret; } bool DemoChainRegst::IsRegstCloned() const { diff --git a/oneflow/core/auto_placement/demo_chain_graph.h b/oneflow/core/auto_placement/demo_chain_graph.h index f72b8a1fc81ddd9e1ed8ccccc9211f2cf2b0fbb7..3c27ff0b438cbcc1d68e1ea94e8d022ec35de01d 100644 --- a/oneflow/core/auto_placement/demo_chain_graph.h +++ b/oneflow/core/auto_placement/demo_chain_graph.h @@ -133,34 +133,48 @@ class DemoChainGraphBuilder; class DemoChainGraph final : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(DemoChainGraph); - DemoChainGraph(const std::function& Build); + DemoChainGraph(int piece_num_in_batch, + const std::function& Build); virtual ~DemoChainGraph() = default; size_t FwChainNodeNum() const; size_t ChainNodeNum() const; - std::vector> CalcChainNodeId2FwChainNodeId() const; - - std::vector> CalcChainRegstId2ProducerChainNodeId() - const; - - std::vector> CalcChainRegstId2PathChainNodeIds( - const std::function& GetTime) const; + const std::vector>& chain_node_id2fw_chain_node_id() + const { + return chain_node_id2fw_chain_node_id_; + } - std::vector> CalcChainRegstId2PathChainNodeIds() const { - return CalcChainRegstId2PathChainNodeIds( - [](int64_t) -> double { return 1; }); + const std::vector>& + chain_regst_id2producer_chain_node_id() const { + return chain_regst_id2producer_chain_node_id_; } - std::vector> CalcEdgeId2SrcChainNodeId() const; - std::vector> CalcEdgeId2DstChainNodeId() const; - std::vector> CalcEdgeId2RegstId() const; + const std::vector>& chain_regst_id2path_chain_node_ids() + const { + return chain_regst_id2path_chain_node_ids_; + } - std::vector CalcChainNodeId2ChainNodeName() const; + const std::vector>& edge_id2src_chain_node_id() const { + return edge_id2src_chain_node_id_; + } + const std::vector>& edge_id2dst_chain_node_id() const { + return edge_id2dst_chain_node_id_; + } + const std::vector>& edge_id2chain_regst_id() const { + return edge_id2chain_regst_id_; + } + const std::vector& chain_node_id2chain_node_name() const { + return chain_node_id2chain_node_name_; + } - std::vector RegstId2IsCloned() const; + const std::vector& chain_regst_id2is_cloned() const { + return chain_regst_id2is_cloned_; + } - std::vector RegstIIRatio(int piece_num_in_batch) const; + const std::vector& chain_regst_id2ii_scale() const { + return chain_regst_id2ii_scale_; + } private: friend class DemoChainGraphBuilder; @@ -176,12 +190,44 @@ class DemoChainGraph final : public Graph { const DemoChainNode* node, const std::function& Handler) const; + void InitChainNodeId2FwChainNodeId(); + + void InitChainRegstId2ProducerChainNodeId(); + + std::vector> CalcChainRegstId2PathChainNodeIds( + const std::function& GetTime) const; + + void InitChainRegstId2PathChainNodeIds() { + chain_regst_id2path_chain_node_ids_ = + CalcChainRegstId2PathChainNodeIds([](int64_t) -> double { return 1; }); + } + + void InitEdgeId2SrcChainNodeId(); + void InitEdgeId2DstChainNodeId(); + void InitEdgeId2ChainRegstId(); + + void InitChainNodeId2ChainNodeName(); + + void InitChainRegstId2IsCloned(); + + void InitChainRegstId2IIScale(); + int64_t chain_node_id_; int64_t chain_regst_id_; + int64_t piece_num_in_batch_; std::list> regsts_; IsReachablePredicator is_reachable_; HashMap> regst2chain_node_sub_graph_; + std::vector> chain_node_id2fw_chain_node_id_; + std::vector> chain_regst_id2producer_chain_node_id_; + std::vector> chain_regst_id2path_chain_node_ids_; + std::vector> edge_id2src_chain_node_id_; + std::vector> edge_id2dst_chain_node_id_; + std::vector> edge_id2chain_regst_id_; + std::vector chain_node_id2chain_node_name_; + std::vector chain_regst_id2is_cloned_; + std::vector chain_regst_id2ii_scale_; }; class DemoChainGraphBuilder final { diff --git a/oneflow/core/auto_placement/demo_chain_graph_test.cpp b/oneflow/core/auto_placement/demo_chain_graph_test.cpp index 03aee39bd73a8384ba71d5b929bddad9972ca612..c112d54ff67d56a96157da5a64ff8f0a30a14f1f 100644 --- a/oneflow/core/auto_placement/demo_chain_graph_test.cpp +++ b/oneflow/core/auto_placement/demo_chain_graph_test.cpp @@ -7,7 +7,7 @@ namespace df { namespace test { TEST(DemoChainGraph, simple_without_model) { - DemoChainGraph graph([](DemoChainGraphBuilder* builder) { + DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) { builder->Backward(builder->Op( "soft_max", {builder->Op("feature"), builder->Op("label")})); }); @@ -15,22 +15,22 @@ TEST(DemoChainGraph, simple_without_model) { ASSERT_EQ(graph.ChainNodeNum(), 3 * 2); std::vector> expected_fw_ids{{0}, {1}, {2}, {2}, {1}, {0}}; - ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids); + ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids); std::vector> expected_producer_ids{{0}, {1}, {2}, {2}, {3}, {3}}; - ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId() + ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id() == expected_producer_ids); std::vector> expected_path{ {0, 2, 3, 5}, {1, 2, 3, 4}, {2, 3}, {2, 3}, {3, 4}, {3, 5}}; - ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path); + ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path); std::vector expected_regst_id2is_cloned{0, 0, 0, 0, 0, 0}; - ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned); + ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned); } TEST(DemoChainGraph, simple_with_model) { - DemoChainGraph graph([](DemoChainGraphBuilder* builder) { + DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) { builder->Backward( builder->Op("op0", {builder->Op("data"), builder->Model("model")})); }); @@ -39,18 +39,18 @@ TEST(DemoChainGraph, simple_with_model) { std::vector> expected_fw_ids{{0}, {1}, {1}, {1}, {1}, {0}}; - ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids); + ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids); std::vector> expected_producer_ids{{0}, {4}, {1}, {1}, {2}, {3}, {2}}; - ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId() + ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id() == expected_producer_ids); std::vector> expected_path{ {0, 1, 2, 5}, {4, 1, 2}, {1, 2}, {1, 2}, {2, 3}, {3, 4}, {2, 5}}; - ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path); + ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path); std::vector expected_regst_id2is_cloned{0, 1, 0, 0, 1, 1, 0}; - ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned); + ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned); } } // namespace test diff --git a/oneflow/core/auto_placement/df_demo.cpp b/oneflow/core/auto_placement/df_demo.cpp index f915d5edae52849706f66c8e4626cbaf990cb93f..f7b22a3aea8620d6cc463a95a8d91a419533b7ec 100644 --- a/oneflow/core/auto_placement/df_demo.cpp +++ b/oneflow/core/auto_placement/df_demo.cpp @@ -19,13 +19,6 @@ Tensor CalcDeviceComputeTime(const Tensor& prob_matrix) { Tensor CalcTaskNodeTime(const Tensor& chain_node_placement) { return chain_node_placement; - // Tensor compute_time = CalcTaskNodeComputeTime(chain_node_placement); - // Tensor col_ones(Shape({chain_node_placement.shape().At(1)}), 1); - // return TensorProduct(MatrixRowSum(compute_time), col_ones); - // auto compute_time_copies = Clone(compute_time, 2); - // Tensor row_sum = - // TensorProduct(MatrixRowSum(compute_time_copies.at(0)), col_ones); - // return Mul(Tensor(0.5), ADD(row_sum, compute_time_copies.at(1))); } Tensor CalcRegstDuration(const Tensor& chain_node_placement, @@ -33,16 +26,14 @@ Tensor CalcRegstDuration(const Tensor& chain_node_placement, Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1); Tensor task_node_time = CalcTaskNodeTime(chain_node_placement); Tensor chain_node_time = MatrixColSum(task_node_time); - auto GetTime = [chain_node_time](int64_t chain_node_id) -> double { - return chain_node_time.At(chain_node_id); - }; - auto regst2path = chain_graph.CalcChainRegstId2PathChainNodeIds(GetTime); + const auto& regst2path = chain_graph.chain_regst_id2path_chain_node_ids(); return ColIndexReduce(TensorProduct(row_ones, chain_node_time), regst2path); } Tensor CalcRegstMemory(const Tensor& chain_node_placement, const DemoChainGraph& chain_graph) { - auto regst2producer = chain_graph.CalcChainRegstId2ProducerChainNodeId(); + const auto& regst2producer = + chain_graph.chain_regst_id2producer_chain_node_id(); int64_t regst_num = regst2producer.size(); Tensor regst_placement = ColIndexReduce(chain_node_placement, regst2producer); Tensor row_ones(Shape({regst_placement.shape().At(0)}), 1); @@ -51,7 +42,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement, Tensor split_workload_ratio = ElemWiseDiv(copies.at(1), col_sum); Tensor clone_workload_ratio = copies.at(2); Tensor clone_weight = TensorProduct( - row_ones, Tensor(Shape({regst_num}), chain_graph.RegstId2IsCloned())); + row_ones, + Tensor(Shape({regst_num}), chain_graph.chain_regst_id2is_cloned())); auto clone_weight_copies = Clone(clone_weight, 2); return ADD(ElemWiseMul(clone_workload_ratio, clone_weight_copies.at(0)), ElemWiseMul(split_workload_ratio, @@ -59,8 +51,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement, } Tensor CalcIIRatio(const Tensor& chain_node_placement, - const DemoChainGraph& chain_graph, int piece_num_in_batch) { - auto ii_ratios = chain_graph.RegstIIRatio(piece_num_in_batch); + const DemoChainGraph& chain_graph) { + const auto& ii_ratios = chain_graph.chain_regst_id2ii_scale(); int64_t regst_num = ii_ratios.size(); Tensor ii_ratio_tensor(Shape({regst_num}), ii_ratios); Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1); @@ -69,11 +61,9 @@ Tensor CalcIIRatio(const Tensor& chain_node_placement, Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement, Tensor regst_duration, - const DemoChainGraph& chain_graph, - int piece_num_in_batch) { + const DemoChainGraph& chain_graph) { Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph); - Tensor ii_ratio = - CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch); + Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph); return MatrixRowSum( ElemWiseMul(ElemWiseMul(ii_ratio, regst_duration), regst_mem)); } @@ -83,12 +73,12 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob, const DemoChainGraph& chain_graph) { auto chain_node_prob_copies = Clone(chain_node_prob, 2); Tensor edge_src_prob = ColIndexReduce( - chain_node_prob_copies.at(0), chain_graph.CalcEdgeId2SrcChainNodeId()); + chain_node_prob_copies.at(0), chain_graph.edge_id2src_chain_node_id()); Tensor edge_dst_prob = ColIndexReduce( - chain_node_prob_copies.at(1), chain_graph.CalcEdgeId2DstChainNodeId()); + chain_node_prob_copies.at(1), chain_graph.edge_id2dst_chain_node_id()); Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob))); Tensor edge_regst_duration_prob = - ColIndexReduce(regst_duration, chain_graph.CalcEdgeId2RegstId()); + ColIndexReduce(regst_duration, chain_graph.edge_id2chain_regst_id()); Tensor copied_task_regst_prob = ElemWiseMul(edge_prob, edge_regst_duration_prob); return MatrixRowSum(copied_task_regst_prob); @@ -104,29 +94,26 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob, } Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob, - const DemoChainGraph& chain_graph, - int piece_num_in_batch) { + const DemoChainGraph& chain_graph) { auto chain_node_prob_copies = Clone(chain_node_prob, 3); Tensor regst_duration = CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph); auto regst_duration_copies = Clone(regst_duration, 2); return ADD( CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0), - regst_duration_copies.at(0), chain_graph, - piece_num_in_batch), + regst_duration_copies.at(0), chain_graph), CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1), regst_duration_copies.at(1), chain_graph)); } Tensor CalcDeviceMemII(const Tensor& chain_node_placement, const DemoChainGraph& chain_graph, - int piece_num_in_batch, double mem_size_per_device) { + double mem_size_per_device) { auto placement_copies = Clone(chain_node_placement, 2); Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph); Tensor regst_duration = CalcRegstDuration(placement_copies.at(1), chain_graph); - Tensor ii_ratio = - CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch); + Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph); auto ii_ratio_copies = Clone(ii_ratio, 2); auto regst_mem_copies = Clone(regst_mem, 2); Tensor weighted_mem_time = @@ -154,34 +141,45 @@ Tensor ProbabilityMatrix(Tensor* var, double lr) { return ElemWiseDiv(x_copies.at(1), x_col_sum); } +std::function MakeFlation(int keep, double ratio) { + std::shared_ptr exec_cnt(new int(-1)); + return [=]() { + if (++(*exec_cnt) < keep) { return 1.0; } + return 1.0 / (((*exec_cnt) - keep) * ratio + 1.0); + }; +} + +std::function MakeFlation(int keep) { + return MakeFlation(keep, 0.005); +} + void AutoPlacementMemoryDemo() { std::random_device rd{}; std::mt19937 gen{rd()}; std::normal_distribution distr(1, 0.01); - DemoChainGraph chain_graph([](DemoChainGraphBuilder* builder) { + DemoChainGraph chain_graph(4, [](DemoChainGraphBuilder* builder) { auto regst = builder->ModelOp("op0"); FOR_RANGE(int, i, 1, 63) { regst = builder->ModelOp("op" + std::to_string(i), {regst}); } builder->Backward(builder->ModelOp("loss", {regst})); }); - auto chain_node_id2fw_id = chain_graph.CalcChainNodeId2FwChainNodeId(); + const auto& chain_node_id2fw_id = + chain_graph.chain_node_id2fw_chain_node_id(); int64_t fw_node_num = chain_graph.FwChainNodeNum(); - Shape shape({4, fw_node_num}); + Shape shape({2, fw_node_num}); Tensor fw_var(shape, [&](size_t index) { return distr(gen); }); Tensor fw_prob; - auto chain_node_id2name = chain_graph.CalcChainNodeId2ChainNodeName(); + const auto& chain_node_id2name = chain_graph.chain_node_id2chain_node_name(); double bugo = 2; - double rethink_threshold = 60; - int rethink_cnt = -1; - int deep_think = 1; - Tensor decision_ratio(Shape({fw_node_num}), - [&](int64_t index) { return 1 + 10.0 / (index + 1); }); - double mem_importance_flation = 0; + double rethink_threshold = 20; + Tensor decision_ratio(Shape({fw_node_num}), [&](int64_t index) { + return 1 + fw_node_num * 0.5 / (index + 1); + }); + std::function MemFlation = MakeFlation(100); FOR_RANGE(int, step, 0, 100000) { double lr = 0.01; if (step % (static_cast(bugo += 0.05))) { - double mem_importance = 1 / (1 + (mem_importance_flation += 0.005)); fw_prob = ProbabilityMatrix(&fw_var, lr); auto fw_prob_copies = Clone(fw_prob, 2); Tensor chain_node_prob = @@ -189,9 +187,9 @@ void AutoPlacementMemoryDemo() { auto chain_prob_copies = Clone(chain_node_prob, 2); Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0)); Tensor dev_mem = - CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph, 4); + CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph); Tensor normalized_dev_mem = - Mul(Tensor(1.65 * mem_importance), Sqrt(dev_mem)); + Mul(Tensor(2.5 * MemFlation()), Sqrt(dev_mem)); Tensor fw_indecision = Mul(Sub(MatrixColSum(Sqrt(fw_prob_copies.at(1))), Tensor(1)), decision_ratio); @@ -249,17 +247,13 @@ void AutoPlacementMemoryDemo() { } std::cout << std::endl; } - if ((indecision.At(0) < rethink_threshold) - && (rethink_threshold > 4 || !((++rethink_cnt) % deep_think))) { - mem_importance_flation = 0; - if (rethink_threshold > 4) { - rethink_threshold -= 1; - } else { - deep_think += 5; - } - auto edge_id2src_id = chain_graph.CalcEdgeId2SrcChainNodeId(); - auto edge_id2dst_id = chain_graph.CalcEdgeId2DstChainNodeId(); - FOR_RANGE(int, conv_iter, 0, 3) { + if (indecision.At(0) < rethink_threshold) { + MemFlation = MakeFlation(100); + rethink_threshold -= 1; + const auto& edge_id2src_id = chain_graph.edge_id2src_chain_node_id(); + const auto& edge_id2dst_id = chain_graph.edge_id2dst_chain_node_id(); + auto old_fw_var = fw_var.buffer(); + FOR_RANGE(int, conv_iter, 0, 1) { chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id); Tensor edge_src_prob = ColIndexReduce(chain_node_prob, edge_id2src_id); @@ -269,13 +263,17 @@ void AutoPlacementMemoryDemo() { Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob))); FOR_RANGE(int, i, 0, edge_prob.shape().At(0)) { FOR_RANGE(int, j, 0, edge_prob.shape().At(1)) { - int64_t src_fw_id = - chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0); - int64_t dst_fw_id = - chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0); - fw_var.At(i, src_fw_id) += fw_var.At(i, dst_fw_id); - fw_var.At(i, src_fw_id) /= 2; - fw_var.At(i, dst_fw_id) = fw_var.At(i, src_fw_id); + if (edge_prob.At(i, j) > 0.2) { + int64_t src_fw_id = + chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0); + int64_t dst_fw_id = + chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0); + double avg = + (old_fw_var.At(i, src_fw_id) + old_fw_var.At(i, dst_fw_id)) + / 2; + fw_var.At(i, src_fw_id) = avg; + fw_var.At(i, dst_fw_id) = avg; + } } } }