提交 2e9f2e3d 编写于 作者: X Xinqi Li

print indecision


Former-commit-id: 35f038955856ba4282a2f20aff9626a44e648517
上级 e9c96eb5
......@@ -147,7 +147,7 @@ Tensor CalcDeviceMemII(const Tensor& chain_node_placement,
Tensor ProbabilityMatrix(Tensor* var, double lr) {
Tensor row_ones(Shape({var->shape().At(0)}), 1);
Tensor epsilon(0.000000001);
Tensor epsilon(0.000000000000000001);
Tensor x = ADD(Square(FixedExpectation(Update(var, lr), 1)), epsilon);
auto x_copies = Clone(x, 2);
Tensor x_col_sum = TensorProduct(row_ones, MatrixColSum(x_copies.at(0)));
......@@ -167,11 +167,8 @@ void AutoPlacementMemoryDemo() {
});
auto chain_node2fw_id = chain_graph.CalcChainNodeId2FwChainNodeId();
int64_t fw_node_num = chain_graph.FwChainNodeNum();
// std::cout << fw_node_num << std::endl;
// return;
Shape shape({2, fw_node_num});
Tensor fw_var(shape, [&](size_t index) { return distr(gen); });
Tensor floor_tensor(shape, 0.000000001);
Tensor fw_prob;
auto chain_node_id2name = chain_graph.CalcChainNodeId2ChainNodeName();
double bugo = 2;
......@@ -179,16 +176,18 @@ void AutoPlacementMemoryDemo() {
double lr = 0.01;
fw_prob = ProbabilityMatrix(&fw_var, lr);
Tensor chain_node_prob = ColIndexReduce(fw_prob, chain_node2fw_id);
if (step % (static_cast<int>(bugo += 0.01))) {
if (step % (static_cast<int>(bugo += 0.05))) {
auto chain_prob_copies = Clone(chain_node_prob, 3);
Tensor computation_ii = MatrixRowSum(chain_prob_copies.at(0));
auto compo_ii_copies = Clone(computation_ii, 2);
Tensor dev_mem =
CalcDeviceMemConsumed(chain_prob_copies.at(2), chain_graph, 4);
Tensor ii = MaxElem(compo_ii_copies.at(1));
Tensor penalty = ADD(Sum(Sqrt(chain_prob_copies.at(1))),
ADD(AvgAbsDeviation(dev_mem),
AvgAbsDeviation(compo_ii_copies.at(0))));
Tensor indecision = Sub(Sum(Sqrt(chain_prob_copies.at(1))),
Tensor(chain_node_prob.shape().At(1)));
Tensor penalty =
ADD(indecision, ADD(AvgAbsDeviation(dev_mem),
AvgAbsDeviation(compo_ii_copies.at(0))));
BackwardRun(ADD(ii, penalty));
std::cout << "fw_prob: " << std::endl;
FOR_RANGE(int, j, 0, fw_prob.shape().At(1)) {
......@@ -196,10 +195,11 @@ void AutoPlacementMemoryDemo() {
double x = fw_prob.At(i, j);
if (x < 0.01) { x = 0; }
if (x > 0.99) { x = 1; }
std::cout << std::setprecision(3) << x << "\t";
std::cout << x << "\t";
}
std::cout << std::endl;
}
std::cout << "indecision: " << indecision.At(0) << std::endl;
std::cout << "computation_ii: ";
for (double i : computation_ii.buffer().data()) { std::cout << i << " "; }
std::cout << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册