ShortestPathTraversal.java 3.2 KB
Newer Older
W
wumingfang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
package com.example.bootdemo;

import java.util.*;

/**
 * Computes the shortest open path (no return to the start) that visits each point of a chosen
 * subset of {@link #points} exactly once.
 *
 * <p>Points are identified by their index into {@link #points}; distances are Euclidean.
 */
public class ShortestPathTraversal {
    /** Coordinate table of {x, y} pairs; selections passed to the solver are indices into it. */
    static int[][] points = {{1, 1}, {1, 3}, {2, 2}, {4, 4}, {2, -2},{3, -1}, {-2, 2}, {-3, 4}, {-1, -2}, {-3, -3}};

    /**
     * Returns the Euclidean distance between two 2-D integer points.
     *
     * @param p1 first point as {x, y}
     * @param p2 second point as {x, y}
     * @return straight-line distance between {@code p1} and {@code p2}
     */
    static double euclideanDistance(int[] p1, int[] p2) {
        double x1 = p1[0], y1 = p1[1];
        double x2 = p2[0], y2 = p2[1];
        return Math.sqrt((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2));
    }

    /**
     * Finds the shortest open path visiting every selected point exactly once, using bitmask
     * dynamic programming (Held-Karp style) over the selected subset.
     *
     * <p>Runs in O(2^m * m^2) time and O(2^m * m) memory.
     *
     * @param m number of selected points; must equal {@code selected.size()}
     * @param selected indices into {@link #points} to visit
     * @return the selected indices in visiting order; an empty list if {@code selected} is empty
     * @throws NullPointerException if {@code selected} is null
     * @throws IllegalArgumentException if {@code m != selected.size()}, or an index is null or
     *     outside {@code [0, points.length)}
     */
    static List<Integer> getShortestPathTraversal(int m, Set<Integer> selected) {
        Objects.requireNonNull(selected, "selected");
        if (m != selected.size()) {
            throw new IllegalArgumentException(
                    "m (" + m + ") must equal the number of selected points (" + selected.size() + ")");
        }
        // Snapshot the set once: DP node i must map to the same point index for both the rows and
        // the columns of the distance matrix, and Set iteration order is not guaranteed sorted.
        Integer[] nodes = selected.toArray(new Integer[0]);
        for (Integer idx : nodes) {
            if (idx == null || idx < 0 || idx >= points.length) {
                throw new IllegalArgumentException(
                        "Point index out of range [0, " + points.length + "): " + idx);
            }
        }
        if (m == 0) {
            return new ArrayList<>();
        }
        // Indices are distinct and in range, so m <= points.length and the shift below is safe.

        // Pairwise distances between the selected points, indexed by DP node number.
        double[][] edges = new double[m][m];
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m; ++j) {
                edges[i][j] = euclideanDistance(points[nodes[i]], points[nodes[j]]);
            }
        }

        int full = (1 << m) - 1;
        // dp[mask][j]: length of the shortest path visiting exactly the nodes in mask, ending at j.
        double[][] dp = new double[full + 1][m];
        // prev[mask][j]: node visited just before j on that path, or -1 when j is the start.
        int[][] prev = new int[full + 1][m];
        for (double[] row : dp) {
            Arrays.fill(row, Double.POSITIVE_INFINITY);
        }
        for (int i = 0; i < m; ++i) {
            dp[1 << i][i] = 0;
            prev[1 << i][i] = -1;
        }
        // Every sub-mask is numerically smaller than its mask, so ascending order is a valid DP order.
        for (int mask = 1; mask <= full; ++mask) {
            for (int j = 0; j < m; ++j) {
                int bit = 1 << j;
                if ((mask & bit) == 0 || mask == bit) {
                    continue;
                }
                int pre = mask ^ bit;
                for (int k = 0; k < m; ++k) {
                    if ((pre & (1 << k)) == 0) {
                        continue;
                    }
                    double cand = dp[pre][k] + edges[k][j];
                    // Strict '<' keeps the lowest-index predecessor among ties.
                    if (cand < dp[mask][j]) {
                        dp[mask][j] = cand;
                        prev[mask][j] = k;
                    }
                }
            }
        }

        // Pick the cheapest end node (lowest index among ties).
        int end = 0;
        for (int i = 1; i < m; ++i) {
            if (dp[full][i] < dp[full][end]) {
                end = i;
            }
        }

        // Walk parent pointers back from the end node, then reverse to get start-to-end order.
        List<Integer> path = new ArrayList<>(m);
        int mask = full;
        int node = end;
        while (node >= 0) {
            path.add(nodes[node]);
            int before = prev[mask][node];
            mask ^= 1 << node;
            node = before;
        }
        Collections.reverse(path);
        return path;
    }

    /**
     * Reads a count followed by that many point indices from stdin, then prints the shortest
     * visiting order of the distinct indices entered.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner input = new Scanner(System.in)) {
            System.out.print("请输入数字的总数:"); // Ask for the total count first.
            int n = input.nextInt();
            // LinkedHashSet drops duplicate inputs while preserving entry order.
            Set<Integer> selected = new LinkedHashSet<>();
            System.out.println("请依次输入 " + n + " 个数字:");
            for (int i = 0; i < n; i++) { // Read each number in turn.
                selected.add(input.nextInt());
            }
            // Pass the de-duplicated size: n overstates it when the user repeats a number.
            System.out.println(getShortestPathTraversal(selected.size(), selected));
        }
    }
}