提交 74e4ea27 编写于 作者: X Xinqi Li

run fast


Former-commit-id: 1d8714eae417c30418e55d4922ea6e539e2c3e0e
上级 5371f089
......@@ -109,12 +109,24 @@ size_t DemoChainGraph::ChainNodeNum() const {
}
// Constructs the graph and eagerly populates its cached lookup tables.
// The caller-supplied Build callback receives a DemoChainGraphBuilder bound
// to this graph; once it returns, each Init* method below is called in turn.
// NOTE(review): two member-initializer lists appear here (one without and one
// with piece_num_in_batch_); this looks like a rendered diff with the removed
// and added lines merged -- confirm against the repository.
DemoChainGraph::DemoChainGraph(
int piece_num_in_batch,
const std::function<void(DemoChainGraphBuilder*)>& Build)
: chain_node_id_(-1), chain_regst_id_(-1) {
: chain_node_id_(-1),
chain_regst_id_(-1),
piece_num_in_batch_(piece_num_in_batch) {
DemoChainGraphBuilder builder(this);
Build(&builder);
// Derived state is computed only after Build() has returned.
InitIsReachable();
InitRegst2ChainNodeSubGraphs();
InitChainNodeId2FwChainNodeId();
InitChainRegstId2ProducerChainNodeId();
InitChainRegstId2PathChainNodeIds();
InitEdgeId2SrcChainNodeId();
InitEdgeId2DstChainNodeId();
InitEdgeId2ChainRegstId();
InitChainNodeId2ChainNodeName();
InitChainRegstId2IsCloned();
InitChainRegstId2IIScale();
}
IsReachablePredicator DemoChainGraph::MakeIsReachablePredicator() const {
......@@ -180,33 +192,31 @@ void DemoChainGraph::InitRegst2ChainNodeSubGraphs() {
}
}
// Builds a table with ChainNodeNum() rows; for every node, its
// fw_chain_node_id() is appended to the row at index chain_node_id(). The
// result is stored in chain_node_id2fw_chain_node_id_.
// NOTE(review): the `Calc...() const` signature and `return ret;` appear next
// to the `Init...()` signature and the member assignment -- this looks like
// merged removed/added diff lines; confirm against the repository.
std::vector<std::vector<int64_t>>
DemoChainGraph::CalcChainNodeId2FwChainNodeId() const {
void DemoChainGraph::InitChainNodeId2FwChainNodeId() {
std::vector<std::vector<int64_t>> ret(ChainNodeNum());
ForEachNode([&](const DemoChainNode* node) {
// Guard the row index before the (also bounds-checked) at() access.
CHECK_LT(node->chain_node_id(), ret.size());
ret.at(node->chain_node_id()).push_back(node->fw_chain_node_id());
});
return ret;
chain_node_id2fw_chain_node_id_ = ret;
}
// Builds a table with regsts_.size() rows; for every regst, the chain_node_id
// of its producer is appended to the row at index chain_regst_id(). The
// result is stored in chain_regst_id2producer_chain_node_id_.
// NOTE(review): the `Calc...() const` signature and `return ret;` appear next
// to the `Init...()` signature and the member assignment -- this looks like
// merged removed/added diff lines; confirm against the repository.
std::vector<std::vector<int64_t>>
DemoChainGraph::CalcChainRegstId2ProducerChainNodeId() const {
void DemoChainGraph::InitChainRegstId2ProducerChainNodeId() {
std::vector<std::vector<int64_t>> ret(regsts_.size());
for (const auto& regst : regsts_) {
int64_t chain_node_id = regst->producer()->chain_node_id();
CHECK_LT(regst->chain_regst_id(), ret.size());
ret.at(regst->chain_regst_id()).push_back(chain_node_id);
}
return ret;
chain_regst_id2producer_chain_node_id_ = ret;
}
// Builds a vector of node_num() strings where the entry at index
// chain_node_id() is that node's name(), and stores it in
// chain_node_id2chain_node_name_.
// NOTE(review): the `Calc...() const` signature and `return ret;` appear next
// to the `Init...()` signature and the member assignment -- this looks like
// merged removed/added diff lines; confirm against the repository.
// NOTE(review): this table is sized with node_num() while the fw-id table is
// sized with ChainNodeNum(); presumably the two agree -- verify.
std::vector<std::string> DemoChainGraph::CalcChainNodeId2ChainNodeName() const {
void DemoChainGraph::InitChainNodeId2ChainNodeName() {
std::vector<std::string> ret(node_num());
ForEachNode([&](const DemoChainNode* node) {
ret.at(node->chain_node_id()) = node->name();
});
return ret;
chain_node_id2chain_node_name_ = ret;
}
void DemoChainNodeSubGraph::TopoForEachChainNode(
......@@ -299,54 +309,52 @@ DemoChainGraph::CalcChainRegstId2PathChainNodeIds(
return ret;
}
// Builds a table with edge_num() rows; row k is the one-element vector
// {src_chain_node_id()} of the k-th edge visited by ForEachEdge. The result
// is stored in edge_id2src_chain_node_id_.
// NOTE(review): the `Calc...() const` signature and `return ret;` appear next
// to the `Init...()` signature and the member assignment -- this looks like
// merged removed/added diff lines; confirm against the repository.
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2SrcChainNodeId()
const {
void DemoChainGraph::InitEdgeId2SrcChainNodeId() {
std::vector<std::vector<int64_t>> ret(edge_num());
// Starts at -1 so the pre-increment below yields 0 for the first edge.
int index = -1;
ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->src_chain_node_id()};
});
return ret;
edge_id2src_chain_node_id_ = ret;
}
// Builds a table with edge_num() rows; row k is the one-element vector
// {dst_chain_node_id()} of the k-th edge visited by ForEachEdge. The result
// is stored in edge_id2dst_chain_node_id_.
// NOTE(review): the `Calc...() const` signature and `return ret;` appear next
// to the `Init...()` signature and the member assignment -- this looks like
// merged removed/added diff lines; confirm against the repository.
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2DstChainNodeId()
const {
void DemoChainGraph::InitEdgeId2DstChainNodeId() {
std::vector<std::vector<int64_t>> ret(edge_num());
// Starts at -1 so the pre-increment below yields 0 for the first edge.
int index = -1;
ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->dst_chain_node_id()};
});
return ret;
edge_id2dst_chain_node_id_ = ret;
}
// Builds a table with edge_num() rows; row k is the one-element vector
// {chain_regst_id()} of the k-th edge visited by ForEachEdge. The result is
// stored in edge_id2chain_regst_id_.
// NOTE(review): the `CalcEdgeId2RegstId() const` signature and `return ret;`
// appear next to the `Init...()` signature and the member assignment -- this
// looks like merged removed/added diff lines; confirm against the repository.
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2RegstId() const {
void DemoChainGraph::InitEdgeId2ChainRegstId() {
std::vector<std::vector<int64_t>> ret(edge_num());
// Starts at -1 so the pre-increment below yields 0 for the first edge.
int index = -1;
ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->chain_regst_id()};
});
return ret;
edge_id2chain_regst_id_ = ret;
}
// Builds a vector of regsts_.size() doubles where the entry at index
// chain_regst_id() is 1 when that regst reports IsRegstCloned() and 0
// otherwise, and stores it in chain_regst_id2is_cloned_.
// NOTE(review): the `RegstId2IsCloned() const` signature and `return ret;`
// appear next to the `Init...()` signature and the member assignment -- this
// looks like merged removed/added diff lines; confirm against the repository.
std::vector<double> DemoChainGraph::RegstId2IsCloned() const {
void DemoChainGraph::InitChainRegstId2IsCloned() {
std::vector<double> ret(regsts_.size());
for (const auto& regst : regsts_) {
ret.at(regst->chain_regst_id()) = (regst->IsRegstCloned() ? 1 : 0);
}
return ret;
chain_regst_id2is_cloned_ = ret;
}
// Builds a vector of regsts_.size() doubles indexed by chain_regst_id(). The
// entry is 1 by default and the piece-num-in-batch value when the regst's
// producer has task type kMdDiffAcc or kMdUpdt. The result is stored in
// chain_regst_id2ii_scale_.
// NOTE(review): both the `RegstIIRatio(int) const` form (local `ii_ratio`,
// parameter `piece_num_in_batch`, `return ret;`) and the `Init...()` form
// (local `ii_scale`, member `piece_num_in_batch_`, member assignment) appear
// here -- this looks like merged removed/added diff lines; confirm against
// the repository.
std::vector<double> DemoChainGraph::RegstIIRatio(int piece_num_in_batch) const {
void DemoChainGraph::InitChainRegstId2IIScale() {
std::vector<double> ret(regsts_.size());
for (const auto& regst : regsts_) {
double ii_ratio = 1;
double ii_scale = 1;
if (regst->producer()->task_type() == TaskType::kMdDiffAcc
|| regst->producer()->task_type() == TaskType::kMdUpdt) {
ii_ratio = piece_num_in_batch;
ii_scale = piece_num_in_batch_;
}
ret.at(regst->chain_regst_id()) = ii_ratio;
ret.at(regst->chain_regst_id()) = ii_scale;
}
return ret;
chain_regst_id2ii_scale_ = ret;
}
bool DemoChainRegst::IsRegstCloned() const {
......
......@@ -133,34 +133,48 @@ class DemoChainGraphBuilder;
class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(DemoChainGraph);
DemoChainGraph(const std::function<void(DemoChainGraphBuilder*)>& Build);
DemoChainGraph(int piece_num_in_batch,
const std::function<void(DemoChainGraphBuilder*)>& Build);
virtual ~DemoChainGraph() = default;
size_t FwChainNodeNum() const;
size_t ChainNodeNum() const;
std::vector<std::vector<int64_t>> CalcChainNodeId2FwChainNodeId() const;
std::vector<std::vector<int64_t>> CalcChainRegstId2ProducerChainNodeId()
const;
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds(
const std::function<double(int64_t)>& GetTime) const;
const std::vector<std::vector<int64_t>>& chain_node_id2fw_chain_node_id()
const {
return chain_node_id2fw_chain_node_id_;
}
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds() const {
return CalcChainRegstId2PathChainNodeIds(
[](int64_t) -> double { return 1; });
const std::vector<std::vector<int64_t>>&
chain_regst_id2producer_chain_node_id() const {
return chain_regst_id2producer_chain_node_id_;
}
std::vector<std::vector<int64_t>> CalcEdgeId2SrcChainNodeId() const;
std::vector<std::vector<int64_t>> CalcEdgeId2DstChainNodeId() const;
std::vector<std::vector<int64_t>> CalcEdgeId2RegstId() const;
const std::vector<std::vector<int64_t>>& chain_regst_id2path_chain_node_ids()
const {
return chain_regst_id2path_chain_node_ids_;
}
std::vector<std::string> CalcChainNodeId2ChainNodeName() const;
const std::vector<std::vector<int64_t>>& edge_id2src_chain_node_id() const {
return edge_id2src_chain_node_id_;
}
const std::vector<std::vector<int64_t>>& edge_id2dst_chain_node_id() const {
return edge_id2dst_chain_node_id_;
}
const std::vector<std::vector<int64_t>>& edge_id2chain_regst_id() const {
return edge_id2chain_regst_id_;
}
const std::vector<std::string>& chain_node_id2chain_node_name() const {
return chain_node_id2chain_node_name_;
}
std::vector<double> RegstId2IsCloned() const;
const std::vector<double>& chain_regst_id2is_cloned() const {
return chain_regst_id2is_cloned_;
}
std::vector<double> RegstIIRatio(int piece_num_in_batch) const;
const std::vector<double>& chain_regst_id2ii_scale() const {
return chain_regst_id2ii_scale_;
}
private:
friend class DemoChainGraphBuilder;
......@@ -176,12 +190,44 @@ class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
const DemoChainNode* node,
const std::function<void(const DemoChainNode*)>& Handler) const;
void InitChainNodeId2FwChainNodeId();
void InitChainRegstId2ProducerChainNodeId();
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds(
const std::function<double(int64_t)>& GetTime) const;
void InitChainRegstId2PathChainNodeIds() {
chain_regst_id2path_chain_node_ids_ =
CalcChainRegstId2PathChainNodeIds([](int64_t) -> double { return 1; });
}
void InitEdgeId2SrcChainNodeId();
void InitEdgeId2DstChainNodeId();
void InitEdgeId2ChainRegstId();
void InitChainNodeId2ChainNodeName();
void InitChainRegstId2IsCloned();
void InitChainRegstId2IIScale();
int64_t chain_node_id_;
int64_t chain_regst_id_;
int64_t piece_num_in_batch_;
std::list<std::unique_ptr<DemoChainRegst>> regsts_;
IsReachablePredicator is_reachable_;
HashMap<const DemoChainRegst*, std::unique_ptr<DemoChainNodeSubGraph>>
regst2chain_node_sub_graph_;
std::vector<std::vector<int64_t>> chain_node_id2fw_chain_node_id_;
std::vector<std::vector<int64_t>> chain_regst_id2producer_chain_node_id_;
std::vector<std::vector<int64_t>> chain_regst_id2path_chain_node_ids_;
std::vector<std::vector<int64_t>> edge_id2src_chain_node_id_;
std::vector<std::vector<int64_t>> edge_id2dst_chain_node_id_;
std::vector<std::vector<int64_t>> edge_id2chain_regst_id_;
std::vector<std::string> chain_node_id2chain_node_name_;
std::vector<double> chain_regst_id2is_cloned_;
std::vector<double> chain_regst_id2ii_scale_;
};
class DemoChainGraphBuilder final {
......
......@@ -7,7 +7,7 @@ namespace df {
namespace test {
TEST(DemoChainGraph, simple_without_model) {
DemoChainGraph graph([](DemoChainGraphBuilder* builder) {
DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) {
builder->Backward(builder->Op(
"soft_max", {builder->Op("feature"), builder->Op("label")}));
});
......@@ -15,22 +15,22 @@ TEST(DemoChainGraph, simple_without_model) {
ASSERT_EQ(graph.ChainNodeNum(), 3 * 2);
std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {2},
{2}, {1}, {0}};
ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids);
ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids);
std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {1}, {2},
{2}, {3}, {3}};
ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId()
ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id()
== expected_producer_ids);
std::vector<std::vector<int64_t>> expected_path{
{0, 2, 3, 5}, {1, 2, 3, 4}, {2, 3}, {2, 3}, {3, 4}, {3, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path);
ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path);
std::vector<double> expected_regst_id2is_cloned{0, 0, 0, 0, 0, 0};
ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned);
ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned);
}
TEST(DemoChainGraph, simple_with_model) {
DemoChainGraph graph([](DemoChainGraphBuilder* builder) {
DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) {
builder->Backward(
builder->Op("op0", {builder->Op("data"), builder->Model("model")}));
});
......@@ -39,18 +39,18 @@ TEST(DemoChainGraph, simple_with_model) {
std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {1},
{1}, {1}, {0}};
ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids);
ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids);
std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {4}, {1}, {1},
{2}, {3}, {2}};
ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId()
ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id()
== expected_producer_ids);
std::vector<std::vector<int64_t>> expected_path{
{0, 1, 2, 5}, {4, 1, 2}, {1, 2}, {1, 2}, {2, 3}, {3, 4}, {2, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path);
ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path);
std::vector<double> expected_regst_id2is_cloned{0, 1, 0, 0, 1, 1, 0};
ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned);
ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned);
}
} // namespace test
......
......@@ -19,13 +19,6 @@ Tensor CalcDeviceComputeTime(const Tensor& prob_matrix) {
// Returns the per-task-node time tensor for a placement. As written, this is
// the placement tensor itself, returned unchanged; the commented-out lines
// after the return sketch alternative formulations based on
// CalcTaskNodeComputeTime and are not executed.
Tensor CalcTaskNodeTime(const Tensor& chain_node_placement) {
return chain_node_placement;
// Tensor compute_time = CalcTaskNodeComputeTime(chain_node_placement);
// Tensor col_ones(Shape({chain_node_placement.shape().At(1)}), 1);
// return TensorProduct(MatrixRowSum(compute_time), col_ones);
// auto compute_time_copies = Clone(compute_time, 2);
// Tensor row_sum =
// TensorProduct(MatrixRowSum(compute_time_copies.at(0)), col_ones);
// return Mul(Tensor(0.5), ADD(row_sum, compute_time_copies.at(1)));
}
Tensor CalcRegstDuration(const Tensor& chain_node_placement,
......@@ -33,16 +26,14 @@ Tensor CalcRegstDuration(const Tensor& chain_node_placement,
Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
Tensor task_node_time = CalcTaskNodeTime(chain_node_placement);
Tensor chain_node_time = MatrixColSum(task_node_time);
auto GetTime = [chain_node_time](int64_t chain_node_id) -> double {
return chain_node_time.At(chain_node_id);
};
auto regst2path = chain_graph.CalcChainRegstId2PathChainNodeIds(GetTime);
const auto& regst2path = chain_graph.chain_regst_id2path_chain_node_ids();
return ColIndexReduce(TensorProduct(row_ones, chain_node_time), regst2path);
}
Tensor CalcRegstMemory(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph) {
auto regst2producer = chain_graph.CalcChainRegstId2ProducerChainNodeId();
const auto& regst2producer =
chain_graph.chain_regst_id2producer_chain_node_id();
int64_t regst_num = regst2producer.size();
Tensor regst_placement = ColIndexReduce(chain_node_placement, regst2producer);
Tensor row_ones(Shape({regst_placement.shape().At(0)}), 1);
......@@ -51,7 +42,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement,
Tensor split_workload_ratio = ElemWiseDiv(copies.at(1), col_sum);
Tensor clone_workload_ratio = copies.at(2);
Tensor clone_weight = TensorProduct(
row_ones, Tensor(Shape({regst_num}), chain_graph.RegstId2IsCloned()));
row_ones,
Tensor(Shape({regst_num}), chain_graph.chain_regst_id2is_cloned()));
auto clone_weight_copies = Clone(clone_weight, 2);
return ADD(ElemWiseMul(clone_workload_ratio, clone_weight_copies.at(0)),
ElemWiseMul(split_workload_ratio,
......@@ -59,8 +51,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement,
}
Tensor CalcIIRatio(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph, int piece_num_in_batch) {
auto ii_ratios = chain_graph.RegstIIRatio(piece_num_in_batch);
const DemoChainGraph& chain_graph) {
const auto& ii_ratios = chain_graph.chain_regst_id2ii_scale();
int64_t regst_num = ii_ratios.size();
Tensor ii_ratio_tensor(Shape({regst_num}), ii_ratios);
Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
......@@ -69,11 +61,9 @@ Tensor CalcIIRatio(const Tensor& chain_node_placement,
// Returns MatrixRowSum of the element-wise product
// ii_ratio * regst_duration * regst_mem, where regst_mem comes from
// CalcRegstMemory and ii_ratio from CalcIIRatio for the given placement.
// NOTE(review): two parameter-list endings and two `ii_ratio` definitions
// (with and without piece_num_in_batch) appear here -- this looks like merged
// removed/added diff lines; confirm against the repository.
Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement,
Tensor regst_duration,
const DemoChainGraph& chain_graph,
int piece_num_in_batch) {
const DemoChainGraph& chain_graph) {
Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph);
Tensor ii_ratio =
CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch);
Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
return MatrixRowSum(
ElemWiseMul(ElemWiseMul(ii_ratio, regst_duration), regst_mem));
}
......@@ -83,12 +73,12 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
const DemoChainGraph& chain_graph) {
auto chain_node_prob_copies = Clone(chain_node_prob, 2);
Tensor edge_src_prob = ColIndexReduce(
chain_node_prob_copies.at(0), chain_graph.CalcEdgeId2SrcChainNodeId());
chain_node_prob_copies.at(0), chain_graph.edge_id2src_chain_node_id());
Tensor edge_dst_prob = ColIndexReduce(
chain_node_prob_copies.at(1), chain_graph.CalcEdgeId2DstChainNodeId());
chain_node_prob_copies.at(1), chain_graph.edge_id2dst_chain_node_id());
Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
Tensor edge_regst_duration_prob =
ColIndexReduce(regst_duration, chain_graph.CalcEdgeId2RegstId());
ColIndexReduce(regst_duration, chain_graph.edge_id2chain_regst_id());
Tensor copied_task_regst_prob =
ElemWiseMul(edge_prob, edge_regst_duration_prob);
return MatrixRowSum(copied_task_regst_prob);
......@@ -104,29 +94,26 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
}
// Returns the sum (ADD) of CalcDeviceMemBasicConsumed and
// CalcDeviceCopiedRegstMem, both evaluated with the same regst duration
// derived from chain_node_prob via CalcRegstDuration.
// NOTE(review): two parameter-list endings and two argument-list endings for
// the CalcDeviceMemBasicConsumed call (with and without piece_num_in_batch)
// appear here -- this looks like merged removed/added diff lines; confirm
// against the repository.
Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob,
const DemoChainGraph& chain_graph,
int piece_num_in_batch) {
const DemoChainGraph& chain_graph) {
// One copy per consumer below: .at(0) basic mem, .at(1) copied mem,
// .at(2) duration. Presumably a Tensor value may feed only one consumer --
// TODO confirm against the Tensor/Clone implementation.
auto chain_node_prob_copies = Clone(chain_node_prob, 3);
Tensor regst_duration =
CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph);
auto regst_duration_copies = Clone(regst_duration, 2);
return ADD(
CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0),
regst_duration_copies.at(0), chain_graph,
piece_num_in_batch),
regst_duration_copies.at(0), chain_graph),
CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1),
regst_duration_copies.at(1), chain_graph));
}
Tensor CalcDeviceMemII(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph,
int piece_num_in_batch, double mem_size_per_device) {
double mem_size_per_device) {
auto placement_copies = Clone(chain_node_placement, 2);
Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph);
Tensor regst_duration =
CalcRegstDuration(placement_copies.at(1), chain_graph);
Tensor ii_ratio =
CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch);
Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
auto ii_ratio_copies = Clone(ii_ratio, 2);
auto regst_mem_copies = Clone(regst_mem, 2);
Tensor weighted_mem_time =
......@@ -154,34 +141,45 @@ Tensor ProbabilityMatrix(Tensor* var, double lr) {
return ElemWiseDiv(x_copies.at(1), x_col_sum);
}
// Builds a stateful decay schedule.
//
// The returned callable counts its own invocations starting from 0. For the
// first `keep` invocations it returns 1.0; from then on it returns
// 1 / ((call_index - keep) * ratio + 1), i.e. a value that starts at 1.0 and
// shrinks as the call index grows (for positive `ratio`).
//
// @param keep   number of initial calls that yield 1.0
// @param ratio  per-call growth of the denominator once `keep` is exceeded
// @return callable producing the next factor on each invocation
std::function<double()> MakeFlation(int keep, double ratio) {
  // The counter lives behind a shared_ptr so that every copy of the returned
  // std::function advances the same counter. make_shared avoids a naked `new`.
  auto exec_cnt = std::make_shared<int>(-1);
  return [exec_cnt, keep, ratio]() -> double {
    const int call_index = ++(*exec_cnt);
    if (call_index < keep) { return 1.0; }
    return 1.0 / ((call_index - keep) * ratio + 1.0);
  };
}
// Convenience overload: same schedule as MakeFlation(keep, ratio) with the
// ratio fixed at 0.005.
std::function<double()> MakeFlation(int keep) {
  const double default_ratio = 0.005;
  return MakeFlation(keep, default_ratio);
}
void AutoPlacementMemoryDemo() {
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<double> distr(1, 0.01);
DemoChainGraph chain_graph([](DemoChainGraphBuilder* builder) {
DemoChainGraph chain_graph(4, [](DemoChainGraphBuilder* builder) {
auto regst = builder->ModelOp("op0");
FOR_RANGE(int, i, 1, 63) {
regst = builder->ModelOp("op" + std::to_string(i), {regst});
}
builder->Backward(builder->ModelOp("loss", {regst}));
});
auto chain_node_id2fw_id = chain_graph.CalcChainNodeId2FwChainNodeId();
const auto& chain_node_id2fw_id =
chain_graph.chain_node_id2fw_chain_node_id();
int64_t fw_node_num = chain_graph.FwChainNodeNum();
Shape shape({4, fw_node_num});
Shape shape({2, fw_node_num});
Tensor fw_var(shape, [&](size_t index) { return distr(gen); });
Tensor fw_prob;
auto chain_node_id2name = chain_graph.CalcChainNodeId2ChainNodeName();
const auto& chain_node_id2name = chain_graph.chain_node_id2chain_node_name();
double bugo = 2;
double rethink_threshold = 60;
int rethink_cnt = -1;
int deep_think = 1;
Tensor decision_ratio(Shape({fw_node_num}),
[&](int64_t index) { return 1 + 10.0 / (index + 1); });
double mem_importance_flation = 0;
double rethink_threshold = 20;
Tensor decision_ratio(Shape({fw_node_num}), [&](int64_t index) {
return 1 + fw_node_num * 0.5 / (index + 1);
});
std::function<double()> MemFlation = MakeFlation(100);
FOR_RANGE(int, step, 0, 100000) {
double lr = 0.01;
if (step % (static_cast<int>(bugo += 0.05))) {
double mem_importance = 1 / (1 + (mem_importance_flation += 0.005));
fw_prob = ProbabilityMatrix(&fw_var, lr);
auto fw_prob_copies = Clone(fw_prob, 2);
Tensor chain_node_prob =
......@@ -189,9 +187,9 @@ void AutoPlacementMemoryDemo() {
auto chain_prob_copies = Clone(chain_node_prob, 2);
Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
Tensor dev_mem =
CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph, 4);
CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph);
Tensor normalized_dev_mem =
Mul(Tensor(1.65 * mem_importance), Sqrt(dev_mem));
Mul(Tensor(2.5 * MemFlation()), Sqrt(dev_mem));
Tensor fw_indecision =
Mul(Sub(MatrixColSum(Sqrt(fw_prob_copies.at(1))), Tensor(1)),
decision_ratio);
......@@ -249,17 +247,13 @@ void AutoPlacementMemoryDemo() {
}
std::cout << std::endl;
}
if ((indecision.At(0) < rethink_threshold)
&& (rethink_threshold > 4 || !((++rethink_cnt) % deep_think))) {
mem_importance_flation = 0;
if (rethink_threshold > 4) {
rethink_threshold -= 1;
} else {
deep_think += 5;
}
auto edge_id2src_id = chain_graph.CalcEdgeId2SrcChainNodeId();
auto edge_id2dst_id = chain_graph.CalcEdgeId2DstChainNodeId();
FOR_RANGE(int, conv_iter, 0, 3) {
if (indecision.At(0) < rethink_threshold) {
MemFlation = MakeFlation(100);
rethink_threshold -= 1;
const auto& edge_id2src_id = chain_graph.edge_id2src_chain_node_id();
const auto& edge_id2dst_id = chain_graph.edge_id2dst_chain_node_id();
auto old_fw_var = fw_var.buffer();
FOR_RANGE(int, conv_iter, 0, 1) {
chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id);
Tensor edge_src_prob =
ColIndexReduce(chain_node_prob, edge_id2src_id);
......@@ -269,13 +263,17 @@ void AutoPlacementMemoryDemo() {
Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
FOR_RANGE(int, i, 0, edge_prob.shape().At(0)) {
FOR_RANGE(int, j, 0, edge_prob.shape().At(1)) {
int64_t src_fw_id =
chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0);
int64_t dst_fw_id =
chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0);
fw_var.At(i, src_fw_id) += fw_var.At(i, dst_fw_id);
fw_var.At(i, src_fw_id) /= 2;
fw_var.At(i, dst_fw_id) = fw_var.At(i, src_fw_id);
if (edge_prob.At(i, j) > 0.2) {
int64_t src_fw_id =
chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0);
int64_t dst_fw_id =
chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0);
double avg =
(old_fw_var.At(i, src_fw_id) + old_fw_var.At(i, dst_fw_id))
/ 2;
fw_var.At(i, src_fw_id) = avg;
fw_var.At(i, dst_fw_id) = avg;
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册