diff --git a/examples/stgcn/README.md b/examples/stgcn/README.md index a185210a1fff51dfceee12ec010eac164b2ed99f..ec3c8fd90690d071766054b042135a078ee353d7 100644 --- a/examples/stgcn/README.md +++ b/examples/stgcn/README.md @@ -23,7 +23,7 @@ You can make your customized dataset by the following format: For examples, use gpu to train STGCN on your dataset. ``` -python main.py --use_cuda --input_file dataset/input_csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv +python main.py --use_cuda --input_file dataset/input.csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv ``` #### Hyperparameters diff --git a/examples/stgcn/data_loader/data_utils.py b/examples/stgcn/data_loader/data_utils.py index f9c73c774d6b9b239fe8509024a52ee37a4890f9..c7e20864c77d275b21270ab0ed7c27af23a47471 100644 --- a/examples/stgcn/data_loader/data_utils.py +++ b/examples/stgcn/data_loader/data_utils.py @@ -167,9 +167,6 @@ def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config): x = x.drop(columns=['date']) y = y.drop(columns=['date']) - x = x.drop(columns=['武汉']) - y = y.drop(columns=['武汉']) - # param n_val, n_test = n_config n_train = len(y) - n_val - n_test - 2 diff --git a/examples/stgcn/dataset/W.csv b/examples/stgcn/dataset/W.csv new file mode 100644 index 0000000000000000000000000000000000000000..6d349eb0d3bbb14b5ad7eb37dc893c19c4cb5285 --- /dev/null +++ b/examples/stgcn/dataset/W.csv @@ -0,0 +1,5 @@ +0,3409,2025,509,13098 +2404,0,2207,3654,9485 +21926,18619,0,955,1308 +20160,12493,170,0,1906 +611,572,1204,1066,0 diff --git a/examples/stgcn/dataset/city.csv b/examples/stgcn/dataset/city.csv new file mode 100644 index 0000000000000000000000000000000000000000..4f82c7e2957ce771d4a6ddbe99baa26d99dc74aa --- /dev/null +++ b/examples/stgcn/dataset/city.csv @@ -0,0 +1,7 @@ +num,city +0,A +1,B +2,C +3,D +4,E + diff --git a/examples/stgcn/dataset/input.csv b/examples/stgcn/dataset/input.csv new file mode 100644 index 0000000000000000000000000000000000000000..8fb14964399ed97e7a95d97cfd44156ff014dd1d --- /dev/null +++ b/examples/stgcn/dataset/input.csv @@ -0,0 +1,47 @@ +date,A,B,C,D,E +2327/1/1,178,3907,2907,1170,832 +2327/1/2,220,2720,2548,1370,1039 +2327/1/3,222,5065,4286,2051,1582 +2327/1/4,183,5291,4626,2096,1614 +2327/1/5,172,3916,3538,1726,1349 +2327/1/6,219,4079,4110,2044,1701 +2327/1/7,220,4707,4673,2589,2177 +2327/1/8,222,5306,5512,3015,2463 +2327/1/9,215,5762,5802,3184,2558 +2327/1/10,217,4977,4641,2659,2185 +2327/1/11,186,6849,6106,3092,2310 +2327/1/12,175,5953,4986,2521,1769 +2327/1/13,215,5270,4983,2559,1818 +2327/1/14,213,5304,5307,2516,1707 +2327/1/15,205,5499,5684,2659,1633 +2327/1/16,205,5811,6531,2920,1793 +2327/1/17,222,6397,7745,3159,2036 +2327/1/18,253,7759,9681,4011,2331 +2327/1/19,859,8791,8215,4507,2480 +2327/1/20,837,10348,9960,5655,3167 +2327/1/21,931,12782,13621,7107,4291 +2327/1/22,1048,15298,16222,8206,4730 +2327/1/23,835,16287,14803,6504,3679 +2327/1/24,635,4806,3970,1551,816 +2327/1/25,511,1028,1023,401,205 +2327/1/26,387,483,632,249,111 +2327/1/27,459,457,591,209,126 +2327/1/28,1073,513,707,234,176 +2327/1/29,1301,651,932,276,264 +2327/1/30,1502,757,1266,369,302 +2327/1/31,1823,972,1286,490,487 +2327/2/1,2219,1113,1594,579,548 +2327/2/2,2719,1345,2172,695,703 +2327/2/3,3563,1556,2517,931,823 +2327/2/4,4335,1824,2837,1095,928 +2327/2/5,5568,2343,3323,1244,1043 +2327/2/6,6070,2917,3420,1295,1054 +2327/2/7,7169,3278,3758,1516,1185 +2327/2/8,8284,3616,3982,1639,1333 +2327/2/9,9229,3799,4200,1726,1418 +2327/2/10,10425,3876,4334,1750,1449 +2327/2/11,11213,3920,4522,1818,1484 +2327/2/12,11653,4106,4831,1881,1512 +2327/2/13,20427,4343,5413,2537,1570 +2327/2/14,24164,4666,5914,2636,1607 +2327/2/15,22608,4901,5546,2812,1557 diff --git a/examples/stgcn/dataset/output.csv b/examples/stgcn/dataset/output.csv new file mode 100644 index 0000000000000000000000000000000000000000..aac7afc1f5e3c0124bb0d583eb33a8fc6249c23f --- /dev/null +++ b/examples/stgcn/dataset/output.csv @@ -0,0 +1,25 @@ +date,A,B,C,D,E +2327/1/24,70,22,0,2,0 +2327/1/25,77,4,52,2,1 +2327/1/26,46,29,58,23,7 +2327/1/27,80,45,32,14,28 +2327/1/28,892,73,59,24,34 +2327/1/29,315,101,111,30,61 +2327/1/30,356,125,172,50,32 +2327/1/31,378,142,77,70,123 +2327/2/1,576,87,153,66,61 +2327/2/2,894,121,276,46,94 +2327/2/3,1033,169,244,166,107 +2327/2/4,1242,202,176,114,84 +2327/2/5,1967,342,223,100,103 +2327/2/6,1766,424,162,88,52 +2327/2/7,1501,255,90,84,51 +2327/2/8,1985,172,144,56,69 +2327/2/9,1379,123,100,56,81 +2327/2/10,1920,105,111,48,31 +2327/2/11,1552,101,80,30,44 +2327/2/12,1104,109,66,35,25 +2327/2/13,13436,123,264,321,13 +2327/2/14,2997,135,129,16,10 +2327/2/15,1923,105,26,31,17 +2327/2/16,1548,87,6,12,16 diff --git a/examples/stgcn/main.py b/examples/stgcn/main.py index 26adb6a4e6f3c81b4e4e2d35e3f049d0d80f03f5..a2d3b11454a95b2a25ab6dd8157630ad9f050ffd 100644 --- a/examples/stgcn/main.py +++ b/examples/stgcn/main.py @@ -124,7 +124,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--n_route', type=int, default=74) + parser.add_argument('--n_route', type=int, default=5) parser.add_argument('--n_his', type=int, default=23) parser.add_argument('--n_pred', type=int, default=3) parser.add_argument('--batch_size', type=int, default=10)