提交 74e4ea27 编写于 作者: X Xinqi Li

run fast


Former-commit-id: 1d8714eae417c30418e55d4922ea6e539e2c3e0e
上级 5371f089
...@@ -109,12 +109,24 @@ size_t DemoChainGraph::ChainNodeNum() const { ...@@ -109,12 +109,24 @@ size_t DemoChainGraph::ChainNodeNum() const {
} }
DemoChainGraph::DemoChainGraph( DemoChainGraph::DemoChainGraph(
int piece_num_in_batch,
const std::function<void(DemoChainGraphBuilder*)>& Build) const std::function<void(DemoChainGraphBuilder*)>& Build)
: chain_node_id_(-1), chain_regst_id_(-1) { : chain_node_id_(-1),
chain_regst_id_(-1),
piece_num_in_batch_(piece_num_in_batch) {
DemoChainGraphBuilder builder(this); DemoChainGraphBuilder builder(this);
Build(&builder); Build(&builder);
InitIsReachable(); InitIsReachable();
InitRegst2ChainNodeSubGraphs(); InitRegst2ChainNodeSubGraphs();
InitChainNodeId2FwChainNodeId();
InitChainRegstId2ProducerChainNodeId();
InitChainRegstId2PathChainNodeIds();
InitEdgeId2SrcChainNodeId();
InitEdgeId2DstChainNodeId();
InitEdgeId2ChainRegstId();
InitChainNodeId2ChainNodeName();
InitChainRegstId2IsCloned();
InitChainRegstId2IIScale();
} }
IsReachablePredicator DemoChainGraph::MakeIsReachablePredicator() const { IsReachablePredicator DemoChainGraph::MakeIsReachablePredicator() const {
...@@ -180,33 +192,31 @@ void DemoChainGraph::InitRegst2ChainNodeSubGraphs() { ...@@ -180,33 +192,31 @@ void DemoChainGraph::InitRegst2ChainNodeSubGraphs() {
} }
} }
std::vector<std::vector<int64_t>> void DemoChainGraph::InitChainNodeId2FwChainNodeId() {
DemoChainGraph::CalcChainNodeId2FwChainNodeId() const {
std::vector<std::vector<int64_t>> ret(ChainNodeNum()); std::vector<std::vector<int64_t>> ret(ChainNodeNum());
ForEachNode([&](const DemoChainNode* node) { ForEachNode([&](const DemoChainNode* node) {
CHECK_LT(node->chain_node_id(), ret.size()); CHECK_LT(node->chain_node_id(), ret.size());
ret.at(node->chain_node_id()).push_back(node->fw_chain_node_id()); ret.at(node->chain_node_id()).push_back(node->fw_chain_node_id());
}); });
return ret; chain_node_id2fw_chain_node_id_ = ret;
} }
std::vector<std::vector<int64_t>> void DemoChainGraph::InitChainRegstId2ProducerChainNodeId() {
DemoChainGraph::CalcChainRegstId2ProducerChainNodeId() const {
std::vector<std::vector<int64_t>> ret(regsts_.size()); std::vector<std::vector<int64_t>> ret(regsts_.size());
for (const auto& regst : regsts_) { for (const auto& regst : regsts_) {
int64_t chain_node_id = regst->producer()->chain_node_id(); int64_t chain_node_id = regst->producer()->chain_node_id();
CHECK_LT(regst->chain_regst_id(), ret.size()); CHECK_LT(regst->chain_regst_id(), ret.size());
ret.at(regst->chain_regst_id()).push_back(chain_node_id); ret.at(regst->chain_regst_id()).push_back(chain_node_id);
} }
return ret; chain_regst_id2producer_chain_node_id_ = ret;
} }
std::vector<std::string> DemoChainGraph::CalcChainNodeId2ChainNodeName() const { void DemoChainGraph::InitChainNodeId2ChainNodeName() {
std::vector<std::string> ret(node_num()); std::vector<std::string> ret(node_num());
ForEachNode([&](const DemoChainNode* node) { ForEachNode([&](const DemoChainNode* node) {
ret.at(node->chain_node_id()) = node->name(); ret.at(node->chain_node_id()) = node->name();
}); });
return ret; chain_node_id2chain_node_name_ = ret;
} }
void DemoChainNodeSubGraph::TopoForEachChainNode( void DemoChainNodeSubGraph::TopoForEachChainNode(
...@@ -299,54 +309,52 @@ DemoChainGraph::CalcChainRegstId2PathChainNodeIds( ...@@ -299,54 +309,52 @@ DemoChainGraph::CalcChainRegstId2PathChainNodeIds(
return ret; return ret;
} }
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2SrcChainNodeId() void DemoChainGraph::InitEdgeId2SrcChainNodeId() {
const {
std::vector<std::vector<int64_t>> ret(edge_num()); std::vector<std::vector<int64_t>> ret(edge_num());
int index = -1; int index = -1;
ForEachEdge([&](DemoChainEdge* edge) { ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->src_chain_node_id()}; ret.at(++index) = std::vector<int64_t>{edge->src_chain_node_id()};
}); });
return ret; edge_id2src_chain_node_id_ = ret;
} }
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2DstChainNodeId() void DemoChainGraph::InitEdgeId2DstChainNodeId() {
const {
std::vector<std::vector<int64_t>> ret(edge_num()); std::vector<std::vector<int64_t>> ret(edge_num());
int index = -1; int index = -1;
ForEachEdge([&](DemoChainEdge* edge) { ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->dst_chain_node_id()}; ret.at(++index) = std::vector<int64_t>{edge->dst_chain_node_id()};
}); });
return ret; edge_id2dst_chain_node_id_ = ret;
} }
std::vector<std::vector<int64_t>> DemoChainGraph::CalcEdgeId2RegstId() const { void DemoChainGraph::InitEdgeId2ChainRegstId() {
std::vector<std::vector<int64_t>> ret(edge_num()); std::vector<std::vector<int64_t>> ret(edge_num());
int index = -1; int index = -1;
ForEachEdge([&](DemoChainEdge* edge) { ForEachEdge([&](DemoChainEdge* edge) {
ret.at(++index) = std::vector<int64_t>{edge->chain_regst_id()}; ret.at(++index) = std::vector<int64_t>{edge->chain_regst_id()};
}); });
return ret; edge_id2chain_regst_id_ = ret;
} }
std::vector<double> DemoChainGraph::RegstId2IsCloned() const { void DemoChainGraph::InitChainRegstId2IsCloned() {
std::vector<double> ret(regsts_.size()); std::vector<double> ret(regsts_.size());
for (const auto& regst : regsts_) { for (const auto& regst : regsts_) {
ret.at(regst->chain_regst_id()) = (regst->IsRegstCloned() ? 1 : 0); ret.at(regst->chain_regst_id()) = (regst->IsRegstCloned() ? 1 : 0);
} }
return ret; chain_regst_id2is_cloned_ = ret;
} }
std::vector<double> DemoChainGraph::RegstIIRatio(int piece_num_in_batch) const { void DemoChainGraph::InitChainRegstId2IIScale() {
std::vector<double> ret(regsts_.size()); std::vector<double> ret(regsts_.size());
for (const auto& regst : regsts_) { for (const auto& regst : regsts_) {
double ii_ratio = 1; double ii_scale = 1;
if (regst->producer()->task_type() == TaskType::kMdDiffAcc if (regst->producer()->task_type() == TaskType::kMdDiffAcc
|| regst->producer()->task_type() == TaskType::kMdUpdt) { || regst->producer()->task_type() == TaskType::kMdUpdt) {
ii_ratio = piece_num_in_batch; ii_scale = piece_num_in_batch_;
} }
ret.at(regst->chain_regst_id()) = ii_ratio; ret.at(regst->chain_regst_id()) = ii_scale;
} }
return ret; chain_regst_id2ii_scale_ = ret;
} }
bool DemoChainRegst::IsRegstCloned() const { bool DemoChainRegst::IsRegstCloned() const {
......
...@@ -133,34 +133,48 @@ class DemoChainGraphBuilder; ...@@ -133,34 +133,48 @@ class DemoChainGraphBuilder;
class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> { class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
public: public:
OF_DISALLOW_COPY_AND_MOVE(DemoChainGraph); OF_DISALLOW_COPY_AND_MOVE(DemoChainGraph);
DemoChainGraph(const std::function<void(DemoChainGraphBuilder*)>& Build); DemoChainGraph(int piece_num_in_batch,
const std::function<void(DemoChainGraphBuilder*)>& Build);
virtual ~DemoChainGraph() = default; virtual ~DemoChainGraph() = default;
size_t FwChainNodeNum() const; size_t FwChainNodeNum() const;
size_t ChainNodeNum() const; size_t ChainNodeNum() const;
std::vector<std::vector<int64_t>> CalcChainNodeId2FwChainNodeId() const; const std::vector<std::vector<int64_t>>& chain_node_id2fw_chain_node_id()
const {
std::vector<std::vector<int64_t>> CalcChainRegstId2ProducerChainNodeId() return chain_node_id2fw_chain_node_id_;
const; }
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds(
const std::function<double(int64_t)>& GetTime) const;
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds() const { const std::vector<std::vector<int64_t>>&
return CalcChainRegstId2PathChainNodeIds( chain_regst_id2producer_chain_node_id() const {
[](int64_t) -> double { return 1; }); return chain_regst_id2producer_chain_node_id_;
} }
std::vector<std::vector<int64_t>> CalcEdgeId2SrcChainNodeId() const; const std::vector<std::vector<int64_t>>& chain_regst_id2path_chain_node_ids()
std::vector<std::vector<int64_t>> CalcEdgeId2DstChainNodeId() const; const {
std::vector<std::vector<int64_t>> CalcEdgeId2RegstId() const; return chain_regst_id2path_chain_node_ids_;
}
std::vector<std::string> CalcChainNodeId2ChainNodeName() const; const std::vector<std::vector<int64_t>>& edge_id2src_chain_node_id() const {
return edge_id2src_chain_node_id_;
}
const std::vector<std::vector<int64_t>>& edge_id2dst_chain_node_id() const {
return edge_id2dst_chain_node_id_;
}
const std::vector<std::vector<int64_t>>& edge_id2chain_regst_id() const {
return edge_id2chain_regst_id_;
}
const std::vector<std::string>& chain_node_id2chain_node_name() const {
return chain_node_id2chain_node_name_;
}
std::vector<double> RegstId2IsCloned() const; const std::vector<double>& chain_regst_id2is_cloned() const {
return chain_regst_id2is_cloned_;
}
std::vector<double> RegstIIRatio(int piece_num_in_batch) const; const std::vector<double>& chain_regst_id2ii_scale() const {
return chain_regst_id2ii_scale_;
}
private: private:
friend class DemoChainGraphBuilder; friend class DemoChainGraphBuilder;
...@@ -176,12 +190,44 @@ class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> { ...@@ -176,12 +190,44 @@ class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
const DemoChainNode* node, const DemoChainNode* node,
const std::function<void(const DemoChainNode*)>& Handler) const; const std::function<void(const DemoChainNode*)>& Handler) const;
void InitChainNodeId2FwChainNodeId();
void InitChainRegstId2ProducerChainNodeId();
std::vector<std::vector<int64_t>> CalcChainRegstId2PathChainNodeIds(
const std::function<double(int64_t)>& GetTime) const;
void InitChainRegstId2PathChainNodeIds() {
chain_regst_id2path_chain_node_ids_ =
CalcChainRegstId2PathChainNodeIds([](int64_t) -> double { return 1; });
}
void InitEdgeId2SrcChainNodeId();
void InitEdgeId2DstChainNodeId();
void InitEdgeId2ChainRegstId();
void InitChainNodeId2ChainNodeName();
void InitChainRegstId2IsCloned();
void InitChainRegstId2IIScale();
int64_t chain_node_id_; int64_t chain_node_id_;
int64_t chain_regst_id_; int64_t chain_regst_id_;
int64_t piece_num_in_batch_;
std::list<std::unique_ptr<DemoChainRegst>> regsts_; std::list<std::unique_ptr<DemoChainRegst>> regsts_;
IsReachablePredicator is_reachable_; IsReachablePredicator is_reachable_;
HashMap<const DemoChainRegst*, std::unique_ptr<DemoChainNodeSubGraph>> HashMap<const DemoChainRegst*, std::unique_ptr<DemoChainNodeSubGraph>>
regst2chain_node_sub_graph_; regst2chain_node_sub_graph_;
std::vector<std::vector<int64_t>> chain_node_id2fw_chain_node_id_;
std::vector<std::vector<int64_t>> chain_regst_id2producer_chain_node_id_;
std::vector<std::vector<int64_t>> chain_regst_id2path_chain_node_ids_;
std::vector<std::vector<int64_t>> edge_id2src_chain_node_id_;
std::vector<std::vector<int64_t>> edge_id2dst_chain_node_id_;
std::vector<std::vector<int64_t>> edge_id2chain_regst_id_;
std::vector<std::string> chain_node_id2chain_node_name_;
std::vector<double> chain_regst_id2is_cloned_;
std::vector<double> chain_regst_id2ii_scale_;
}; };
class DemoChainGraphBuilder final { class DemoChainGraphBuilder final {
......
...@@ -7,7 +7,7 @@ namespace df { ...@@ -7,7 +7,7 @@ namespace df {
namespace test { namespace test {
TEST(DemoChainGraph, simple_without_model) { TEST(DemoChainGraph, simple_without_model) {
DemoChainGraph graph([](DemoChainGraphBuilder* builder) { DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) {
builder->Backward(builder->Op( builder->Backward(builder->Op(
"soft_max", {builder->Op("feature"), builder->Op("label")})); "soft_max", {builder->Op("feature"), builder->Op("label")}));
}); });
...@@ -15,22 +15,22 @@ TEST(DemoChainGraph, simple_without_model) { ...@@ -15,22 +15,22 @@ TEST(DemoChainGraph, simple_without_model) {
ASSERT_EQ(graph.ChainNodeNum(), 3 * 2); ASSERT_EQ(graph.ChainNodeNum(), 3 * 2);
std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {2}, std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {2},
{2}, {1}, {0}}; {2}, {1}, {0}};
ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids); ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids);
std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {1}, {2}, std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {1}, {2},
{2}, {3}, {3}}; {2}, {3}, {3}};
ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId() ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id()
== expected_producer_ids); == expected_producer_ids);
std::vector<std::vector<int64_t>> expected_path{ std::vector<std::vector<int64_t>> expected_path{
{0, 2, 3, 5}, {1, 2, 3, 4}, {2, 3}, {2, 3}, {3, 4}, {3, 5}}; {0, 2, 3, 5}, {1, 2, 3, 4}, {2, 3}, {2, 3}, {3, 4}, {3, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path); ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path);
std::vector<double> expected_regst_id2is_cloned{0, 0, 0, 0, 0, 0}; std::vector<double> expected_regst_id2is_cloned{0, 0, 0, 0, 0, 0};
ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned); ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned);
} }
TEST(DemoChainGraph, simple_with_model) { TEST(DemoChainGraph, simple_with_model) {
DemoChainGraph graph([](DemoChainGraphBuilder* builder) { DemoChainGraph graph(1, [](DemoChainGraphBuilder* builder) {
builder->Backward( builder->Backward(
builder->Op("op0", {builder->Op("data"), builder->Model("model")})); builder->Op("op0", {builder->Op("data"), builder->Model("model")}));
}); });
...@@ -39,18 +39,18 @@ TEST(DemoChainGraph, simple_with_model) { ...@@ -39,18 +39,18 @@ TEST(DemoChainGraph, simple_with_model) {
std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {1}, std::vector<std::vector<int64_t>> expected_fw_ids{{0}, {1}, {1},
{1}, {1}, {0}}; {1}, {1}, {0}};
ASSERT_TRUE(graph.CalcChainNodeId2FwChainNodeId() == expected_fw_ids); ASSERT_TRUE(graph.chain_node_id2fw_chain_node_id() == expected_fw_ids);
std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {4}, {1}, {1}, std::vector<std::vector<int64_t>> expected_producer_ids{{0}, {4}, {1}, {1},
{2}, {3}, {2}}; {2}, {3}, {2}};
ASSERT_TRUE(graph.CalcChainRegstId2ProducerChainNodeId() ASSERT_TRUE(graph.chain_regst_id2producer_chain_node_id()
== expected_producer_ids); == expected_producer_ids);
std::vector<std::vector<int64_t>> expected_path{ std::vector<std::vector<int64_t>> expected_path{
{0, 1, 2, 5}, {4, 1, 2}, {1, 2}, {1, 2}, {2, 3}, {3, 4}, {2, 5}}; {0, 1, 2, 5}, {4, 1, 2}, {1, 2}, {1, 2}, {2, 3}, {3, 4}, {2, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path); ASSERT_TRUE(graph.chain_regst_id2path_chain_node_ids() == expected_path);
std::vector<double> expected_regst_id2is_cloned{0, 1, 0, 0, 1, 1, 0}; std::vector<double> expected_regst_id2is_cloned{0, 1, 0, 0, 1, 1, 0};
ASSERT_TRUE(graph.RegstId2IsCloned() == expected_regst_id2is_cloned); ASSERT_TRUE(graph.chain_regst_id2is_cloned() == expected_regst_id2is_cloned);
} }
} // namespace test } // namespace test
......
...@@ -19,13 +19,6 @@ Tensor CalcDeviceComputeTime(const Tensor& prob_matrix) { ...@@ -19,13 +19,6 @@ Tensor CalcDeviceComputeTime(const Tensor& prob_matrix) {
Tensor CalcTaskNodeTime(const Tensor& chain_node_placement) { Tensor CalcTaskNodeTime(const Tensor& chain_node_placement) {
return chain_node_placement; return chain_node_placement;
// Tensor compute_time = CalcTaskNodeComputeTime(chain_node_placement);
// Tensor col_ones(Shape({chain_node_placement.shape().At(1)}), 1);
// return TensorProduct(MatrixRowSum(compute_time), col_ones);
// auto compute_time_copies = Clone(compute_time, 2);
// Tensor row_sum =
// TensorProduct(MatrixRowSum(compute_time_copies.at(0)), col_ones);
// return Mul(Tensor(0.5), ADD(row_sum, compute_time_copies.at(1)));
} }
Tensor CalcRegstDuration(const Tensor& chain_node_placement, Tensor CalcRegstDuration(const Tensor& chain_node_placement,
...@@ -33,16 +26,14 @@ Tensor CalcRegstDuration(const Tensor& chain_node_placement, ...@@ -33,16 +26,14 @@ Tensor CalcRegstDuration(const Tensor& chain_node_placement,
Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1); Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
Tensor task_node_time = CalcTaskNodeTime(chain_node_placement); Tensor task_node_time = CalcTaskNodeTime(chain_node_placement);
Tensor chain_node_time = MatrixColSum(task_node_time); Tensor chain_node_time = MatrixColSum(task_node_time);
auto GetTime = [chain_node_time](int64_t chain_node_id) -> double { const auto& regst2path = chain_graph.chain_regst_id2path_chain_node_ids();
return chain_node_time.At(chain_node_id);
};
auto regst2path = chain_graph.CalcChainRegstId2PathChainNodeIds(GetTime);
return ColIndexReduce(TensorProduct(row_ones, chain_node_time), regst2path); return ColIndexReduce(TensorProduct(row_ones, chain_node_time), regst2path);
} }
Tensor CalcRegstMemory(const Tensor& chain_node_placement, Tensor CalcRegstMemory(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph) { const DemoChainGraph& chain_graph) {
auto regst2producer = chain_graph.CalcChainRegstId2ProducerChainNodeId(); const auto& regst2producer =
chain_graph.chain_regst_id2producer_chain_node_id();
int64_t regst_num = regst2producer.size(); int64_t regst_num = regst2producer.size();
Tensor regst_placement = ColIndexReduce(chain_node_placement, regst2producer); Tensor regst_placement = ColIndexReduce(chain_node_placement, regst2producer);
Tensor row_ones(Shape({regst_placement.shape().At(0)}), 1); Tensor row_ones(Shape({regst_placement.shape().At(0)}), 1);
...@@ -51,7 +42,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement, ...@@ -51,7 +42,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement,
Tensor split_workload_ratio = ElemWiseDiv(copies.at(1), col_sum); Tensor split_workload_ratio = ElemWiseDiv(copies.at(1), col_sum);
Tensor clone_workload_ratio = copies.at(2); Tensor clone_workload_ratio = copies.at(2);
Tensor clone_weight = TensorProduct( Tensor clone_weight = TensorProduct(
row_ones, Tensor(Shape({regst_num}), chain_graph.RegstId2IsCloned())); row_ones,
Tensor(Shape({regst_num}), chain_graph.chain_regst_id2is_cloned()));
auto clone_weight_copies = Clone(clone_weight, 2); auto clone_weight_copies = Clone(clone_weight, 2);
return ADD(ElemWiseMul(clone_workload_ratio, clone_weight_copies.at(0)), return ADD(ElemWiseMul(clone_workload_ratio, clone_weight_copies.at(0)),
ElemWiseMul(split_workload_ratio, ElemWiseMul(split_workload_ratio,
...@@ -59,8 +51,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement, ...@@ -59,8 +51,8 @@ Tensor CalcRegstMemory(const Tensor& chain_node_placement,
} }
Tensor CalcIIRatio(const Tensor& chain_node_placement, Tensor CalcIIRatio(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph, int piece_num_in_batch) { const DemoChainGraph& chain_graph) {
auto ii_ratios = chain_graph.RegstIIRatio(piece_num_in_batch); const auto& ii_ratios = chain_graph.chain_regst_id2ii_scale();
int64_t regst_num = ii_ratios.size(); int64_t regst_num = ii_ratios.size();
Tensor ii_ratio_tensor(Shape({regst_num}), ii_ratios); Tensor ii_ratio_tensor(Shape({regst_num}), ii_ratios);
Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1); Tensor row_ones(Shape({chain_node_placement.shape().At(0)}), 1);
...@@ -69,11 +61,9 @@ Tensor CalcIIRatio(const Tensor& chain_node_placement, ...@@ -69,11 +61,9 @@ Tensor CalcIIRatio(const Tensor& chain_node_placement,
Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement, Tensor CalcDeviceMemBasicConsumed(const Tensor& chain_node_placement,
Tensor regst_duration, Tensor regst_duration,
const DemoChainGraph& chain_graph, const DemoChainGraph& chain_graph) {
int piece_num_in_batch) {
Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph); Tensor regst_mem = CalcRegstMemory(chain_node_placement, chain_graph);
Tensor ii_ratio = Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch);
return MatrixRowSum( return MatrixRowSum(
ElemWiseMul(ElemWiseMul(ii_ratio, regst_duration), regst_mem)); ElemWiseMul(ElemWiseMul(ii_ratio, regst_duration), regst_mem));
} }
...@@ -83,12 +73,12 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob, ...@@ -83,12 +73,12 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
const DemoChainGraph& chain_graph) { const DemoChainGraph& chain_graph) {
auto chain_node_prob_copies = Clone(chain_node_prob, 2); auto chain_node_prob_copies = Clone(chain_node_prob, 2);
Tensor edge_src_prob = ColIndexReduce( Tensor edge_src_prob = ColIndexReduce(
chain_node_prob_copies.at(0), chain_graph.CalcEdgeId2SrcChainNodeId()); chain_node_prob_copies.at(0), chain_graph.edge_id2src_chain_node_id());
Tensor edge_dst_prob = ColIndexReduce( Tensor edge_dst_prob = ColIndexReduce(
chain_node_prob_copies.at(1), chain_graph.CalcEdgeId2DstChainNodeId()); chain_node_prob_copies.at(1), chain_graph.edge_id2dst_chain_node_id());
Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob))); Tensor edge_prob = Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
Tensor edge_regst_duration_prob = Tensor edge_regst_duration_prob =
ColIndexReduce(regst_duration, chain_graph.CalcEdgeId2RegstId()); ColIndexReduce(regst_duration, chain_graph.edge_id2chain_regst_id());
Tensor copied_task_regst_prob = Tensor copied_task_regst_prob =
ElemWiseMul(edge_prob, edge_regst_duration_prob); ElemWiseMul(edge_prob, edge_regst_duration_prob);
return MatrixRowSum(copied_task_regst_prob); return MatrixRowSum(copied_task_regst_prob);
...@@ -104,29 +94,26 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob, ...@@ -104,29 +94,26 @@ Tensor CalcDeviceCopiedRegstMem(const Tensor& chain_node_prob,
} }
Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob, Tensor CalcDeviceMemConsumed(const Tensor& chain_node_prob,
const DemoChainGraph& chain_graph, const DemoChainGraph& chain_graph) {
int piece_num_in_batch) {
auto chain_node_prob_copies = Clone(chain_node_prob, 3); auto chain_node_prob_copies = Clone(chain_node_prob, 3);
Tensor regst_duration = Tensor regst_duration =
CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph); CalcRegstDuration(chain_node_prob_copies.at(2), chain_graph);
auto regst_duration_copies = Clone(regst_duration, 2); auto regst_duration_copies = Clone(regst_duration, 2);
return ADD( return ADD(
CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0), CalcDeviceMemBasicConsumed(chain_node_prob_copies.at(0),
regst_duration_copies.at(0), chain_graph, regst_duration_copies.at(0), chain_graph),
piece_num_in_batch),
CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1), CalcDeviceCopiedRegstMem(chain_node_prob_copies.at(1),
regst_duration_copies.at(1), chain_graph)); regst_duration_copies.at(1), chain_graph));
} }
Tensor CalcDeviceMemII(const Tensor& chain_node_placement, Tensor CalcDeviceMemII(const Tensor& chain_node_placement,
const DemoChainGraph& chain_graph, const DemoChainGraph& chain_graph,
int piece_num_in_batch, double mem_size_per_device) { double mem_size_per_device) {
auto placement_copies = Clone(chain_node_placement, 2); auto placement_copies = Clone(chain_node_placement, 2);
Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph); Tensor regst_mem = CalcRegstMemory(placement_copies.at(0), chain_graph);
Tensor regst_duration = Tensor regst_duration =
CalcRegstDuration(placement_copies.at(1), chain_graph); CalcRegstDuration(placement_copies.at(1), chain_graph);
Tensor ii_ratio = Tensor ii_ratio = CalcIIRatio(chain_node_placement, chain_graph);
CalcIIRatio(chain_node_placement, chain_graph, piece_num_in_batch);
auto ii_ratio_copies = Clone(ii_ratio, 2); auto ii_ratio_copies = Clone(ii_ratio, 2);
auto regst_mem_copies = Clone(regst_mem, 2); auto regst_mem_copies = Clone(regst_mem, 2);
Tensor weighted_mem_time = Tensor weighted_mem_time =
...@@ -154,34 +141,45 @@ Tensor ProbabilityMatrix(Tensor* var, double lr) { ...@@ -154,34 +141,45 @@ Tensor ProbabilityMatrix(Tensor* var, double lr) {
return ElemWiseDiv(x_copies.at(1), x_col_sum); return ElemWiseDiv(x_copies.at(1), x_col_sum);
} }
std::function<double()> MakeFlation(int keep, double ratio) {
std::shared_ptr<int> exec_cnt(new int(-1));
return [=]() {
if (++(*exec_cnt) < keep) { return 1.0; }
return 1.0 / (((*exec_cnt) - keep) * ratio + 1.0);
};
}
std::function<double()> MakeFlation(int keep) {
return MakeFlation(keep, 0.005);
}
void AutoPlacementMemoryDemo() { void AutoPlacementMemoryDemo() {
std::random_device rd{}; std::random_device rd{};
std::mt19937 gen{rd()}; std::mt19937 gen{rd()};
std::normal_distribution<double> distr(1, 0.01); std::normal_distribution<double> distr(1, 0.01);
DemoChainGraph chain_graph([](DemoChainGraphBuilder* builder) { DemoChainGraph chain_graph(4, [](DemoChainGraphBuilder* builder) {
auto regst = builder->ModelOp("op0"); auto regst = builder->ModelOp("op0");
FOR_RANGE(int, i, 1, 63) { FOR_RANGE(int, i, 1, 63) {
regst = builder->ModelOp("op" + std::to_string(i), {regst}); regst = builder->ModelOp("op" + std::to_string(i), {regst});
} }
builder->Backward(builder->ModelOp("loss", {regst})); builder->Backward(builder->ModelOp("loss", {regst}));
}); });
auto chain_node_id2fw_id = chain_graph.CalcChainNodeId2FwChainNodeId(); const auto& chain_node_id2fw_id =
chain_graph.chain_node_id2fw_chain_node_id();
int64_t fw_node_num = chain_graph.FwChainNodeNum(); int64_t fw_node_num = chain_graph.FwChainNodeNum();
Shape shape({4, fw_node_num}); Shape shape({2, fw_node_num});
Tensor fw_var(shape, [&](size_t index) { return distr(gen); }); Tensor fw_var(shape, [&](size_t index) { return distr(gen); });
Tensor fw_prob; Tensor fw_prob;
auto chain_node_id2name = chain_graph.CalcChainNodeId2ChainNodeName(); const auto& chain_node_id2name = chain_graph.chain_node_id2chain_node_name();
double bugo = 2; double bugo = 2;
double rethink_threshold = 60; double rethink_threshold = 20;
int rethink_cnt = -1; Tensor decision_ratio(Shape({fw_node_num}), [&](int64_t index) {
int deep_think = 1; return 1 + fw_node_num * 0.5 / (index + 1);
Tensor decision_ratio(Shape({fw_node_num}), });
[&](int64_t index) { return 1 + 10.0 / (index + 1); }); std::function<double()> MemFlation = MakeFlation(100);
double mem_importance_flation = 0;
FOR_RANGE(int, step, 0, 100000) { FOR_RANGE(int, step, 0, 100000) {
double lr = 0.01; double lr = 0.01;
if (step % (static_cast<int>(bugo += 0.05))) { if (step % (static_cast<int>(bugo += 0.05))) {
double mem_importance = 1 / (1 + (mem_importance_flation += 0.005));
fw_prob = ProbabilityMatrix(&fw_var, lr); fw_prob = ProbabilityMatrix(&fw_var, lr);
auto fw_prob_copies = Clone(fw_prob, 2); auto fw_prob_copies = Clone(fw_prob, 2);
Tensor chain_node_prob = Tensor chain_node_prob =
...@@ -189,9 +187,9 @@ void AutoPlacementMemoryDemo() { ...@@ -189,9 +187,9 @@ void AutoPlacementMemoryDemo() {
auto chain_prob_copies = Clone(chain_node_prob, 2); auto chain_prob_copies = Clone(chain_node_prob, 2);
Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0)); Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
Tensor dev_mem = Tensor dev_mem =
CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph, 4); CalcDeviceMemConsumed(chain_prob_copies.at(1), chain_graph);
Tensor normalized_dev_mem = Tensor normalized_dev_mem =
Mul(Tensor(1.65 * mem_importance), Sqrt(dev_mem)); Mul(Tensor(2.5 * MemFlation()), Sqrt(dev_mem));
Tensor fw_indecision = Tensor fw_indecision =
Mul(Sub(MatrixColSum(Sqrt(fw_prob_copies.at(1))), Tensor(1)), Mul(Sub(MatrixColSum(Sqrt(fw_prob_copies.at(1))), Tensor(1)),
decision_ratio); decision_ratio);
...@@ -249,17 +247,13 @@ void AutoPlacementMemoryDemo() { ...@@ -249,17 +247,13 @@ void AutoPlacementMemoryDemo() {
} }
std::cout << std::endl; std::cout << std::endl;
} }
if ((indecision.At(0) < rethink_threshold) if (indecision.At(0) < rethink_threshold) {
&& (rethink_threshold > 4 || !((++rethink_cnt) % deep_think))) { MemFlation = MakeFlation(100);
mem_importance_flation = 0; rethink_threshold -= 1;
if (rethink_threshold > 4) { const auto& edge_id2src_id = chain_graph.edge_id2src_chain_node_id();
rethink_threshold -= 1; const auto& edge_id2dst_id = chain_graph.edge_id2dst_chain_node_id();
} else { auto old_fw_var = fw_var.buffer();
deep_think += 5; FOR_RANGE(int, conv_iter, 0, 1) {
}
auto edge_id2src_id = chain_graph.CalcEdgeId2SrcChainNodeId();
auto edge_id2dst_id = chain_graph.CalcEdgeId2DstChainNodeId();
FOR_RANGE(int, conv_iter, 0, 3) {
chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id); chain_node_prob = ColIndexReduce(fw_prob, chain_node_id2fw_id);
Tensor edge_src_prob = Tensor edge_src_prob =
ColIndexReduce(chain_node_prob, edge_id2src_id); ColIndexReduce(chain_node_prob, edge_id2src_id);
...@@ -269,13 +263,17 @@ void AutoPlacementMemoryDemo() { ...@@ -269,13 +263,17 @@ void AutoPlacementMemoryDemo() {
Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob))); Mul(Tensor(0.5), Abs(Sub(edge_src_prob, edge_dst_prob)));
FOR_RANGE(int, i, 0, edge_prob.shape().At(0)) { FOR_RANGE(int, i, 0, edge_prob.shape().At(0)) {
FOR_RANGE(int, j, 0, edge_prob.shape().At(1)) { FOR_RANGE(int, j, 0, edge_prob.shape().At(1)) {
int64_t src_fw_id = if (edge_prob.At(i, j) > 0.2) {
chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0); int64_t src_fw_id =
int64_t dst_fw_id = chain_node_id2fw_id.at(edge_id2src_id.at(j).at(0)).at(0);
chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0); int64_t dst_fw_id =
fw_var.At(i, src_fw_id) += fw_var.At(i, dst_fw_id); chain_node_id2fw_id.at(edge_id2dst_id.at(j).at(0)).at(0);
fw_var.At(i, src_fw_id) /= 2; double avg =
fw_var.At(i, dst_fw_id) = fw_var.At(i, src_fw_id); (old_fw_var.At(i, src_fw_id) + old_fw_var.At(i, dst_fw_id))
/ 2;
fw_var.At(i, src_fw_id) = avg;
fw_var.At(i, dst_fw_id) = avg;
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册