提交 fb30c9fa 编写于 作者: W Webbley

add a simple dataset for stgcn

上级 299328a0
...@@ -23,7 +23,7 @@ You can make your customized dataset by the following format: ...@@ -23,7 +23,7 @@ You can make your customized dataset by the following format:
For examples, use gpu to train STGCN on your dataset. For examples, use gpu to train STGCN on your dataset.
``` ```
python main.py --use_cuda --input_file dataset/input_csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv python main.py --use_cuda --input_file dataset/input.csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv
``` ```
#### Hyperparameters #### Hyperparameters
......
...@@ -167,9 +167,6 @@ def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config): ...@@ -167,9 +167,6 @@ def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config):
x = x.drop(columns=['date']) x = x.drop(columns=['date'])
y = y.drop(columns=['date']) y = y.drop(columns=['date'])
x = x.drop(columns=['武汉'])
y = y.drop(columns=['武汉'])
# param # param
n_val, n_test = n_config n_val, n_test = n_config
n_train = len(y) - n_val - n_test - 2 n_train = len(y) - n_val - n_test - 2
......
0,3409,2025,509,13098
2404,0,2207,3654,9485
21926,18619,0,955,1308
20160,12493,170,0,1906
611,572,1204,1066,0
num,city
0,A
1,B
2,C
3,D
4,E
date,A,B,C,D,E
2327/1/1,178,3907,2907,1170,832
2327/1/2,220,2720,2548,1370,1039
2327/1/3,222,5065,4286,2051,1582
2327/1/4,183,5291,4626,2096,1614
2327/1/5,172,3916,3538,1726,1349
2327/1/6,219,4079,4110,2044,1701
2327/1/7,220,4707,4673,2589,2177
2327/1/8,222,5306,5512,3015,2463
2327/1/9,215,5762,5802,3184,2558
2327/1/10,217,4977,4641,2659,2185
2327/1/11,186,6849,6106,3092,2310
2327/1/12,175,5953,4986,2521,1769
2327/1/13,215,5270,4983,2559,1818
2327/1/14,213,5304,5307,2516,1707
2327/1/15,205,5499,5684,2659,1633
2327/1/16,205,5811,6531,2920,1793
2327/1/17,222,6397,7745,3159,2036
2327/1/18,253,7759,9681,4011,2331
2327/1/19,859,8791,8215,4507,2480
2327/1/20,837,10348,9960,5655,3167
2327/1/21,931,12782,13621,7107,4291
2327/1/22,1048,15298,16222,8206,4730
2327/1/23,835,16287,14803,6504,3679
2327/1/24,635,4806,3970,1551,816
2327/1/25,511,1028,1023,401,205
2327/1/26,387,483,632,249,111
2327/1/27,459,457,591,209,126
2327/1/28,1073,513,707,234,176
2327/1/29,1301,651,932,276,264
2327/1/30,1502,757,1266,369,302
2327/1/31,1823,972,1286,490,487
2327/2/1,2219,1113,1594,579,548
2327/2/2,2719,1345,2172,695,703
2327/2/3,3563,1556,2517,931,823
2327/2/4,4335,1824,2837,1095,928
2327/2/5,5568,2343,3323,1244,1043
2327/2/6,6070,2917,3420,1295,1054
2327/2/7,7169,3278,3758,1516,1185
2327/2/8,8284,3616,3982,1639,1333
2327/2/9,9229,3799,4200,1726,1418
2327/2/10,10425,3876,4334,1750,1449
2327/2/11,11213,3920,4522,1818,1484
2327/2/12,11653,4106,4831,1881,1512
2327/2/13,20427,4343,5413,2537,1570
2327/2/14,24164,4666,5914,2636,1607
2327/2/15,22608,4901,5546,2812,1557
date,A,B,C,D,E
2327/1/24,70,22,0,2,0
2327/1/25,77,4,52,2,1
2327/1/26,46,29,58,23,7
2327/1/27,80,45,32,14,28
2327/1/28,892,73,59,24,34
2327/1/29,315,101,111,30,61
2327/1/30,356,125,172,50,32
2327/1/31,378,142,77,70,123
2327/2/1,576,87,153,66,61
2327/2/2,894,121,276,46,94
2327/2/3,1033,169,244,166,107
2327/2/4,1242,202,176,114,84
2327/2/5,1967,342,223,100,103
2327/2/6,1766,424,162,88,52
2327/2/7,1501,255,90,84,51
2327/2/8,1985,172,144,56,69
2327/2/9,1379,123,100,56,81
2327/2/10,1920,105,111,48,31
2327/2/11,1552,101,80,30,44
2327/2/12,1104,109,66,35,25
2327/2/13,13436,123,264,321,13
2327/2/14,2997,135,129,16,10
2327/2/15,1923,105,26,31,17
2327/2/16,1548,87,6,12,16
...@@ -124,7 +124,7 @@ def main(args): ...@@ -124,7 +124,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--n_route', type=int, default=74) parser.add_argument('--n_route', type=int, default=5)
parser.add_argument('--n_his', type=int, default=23) parser.add_argument('--n_his', type=int, default=23)
parser.add_argument('--n_pred', type=int, default=3) parser.add_argument('--n_pred', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=10) parser.add_argument('--batch_size', type=int, default=10)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册