From 504e79814e8730958fe1e553d775caa672fdfecb Mon Sep 17 00:00:00 2001
From: liweibin
Date: Sun, 1 Mar 2020 16:29:17 +0800
Subject: [PATCH] add stgcn model

---
 examples/stgcn/README.md                 |  36 +++
 examples/stgcn/data_loader/__init__.py   |  14 ++
 examples/stgcn/data_loader/data_utils.py | 251 +++++++++++++++++++++
 examples/stgcn/data_loader/graph.py      |  92 ++++++++
 examples/stgcn/main.py                   | 157 +++++++++++++
 examples/stgcn/models/model.py           | 271 +++++++++++++++++++++++
 examples/stgcn/models/tester.py          | 134 +++++++++++
 examples/stgcn/utils/math_utils.py       |  65 ++++++
 8 files changed, 1020 insertions(+)
 create mode 100644 examples/stgcn/README.md
 create mode 100644 examples/stgcn/data_loader/__init__.py
 create mode 100644 examples/stgcn/data_loader/data_utils.py
 create mode 100644 examples/stgcn/data_loader/graph.py
 create mode 100644 examples/stgcn/main.py
 create mode 100644 examples/stgcn/models/model.py
 create mode 100644 examples/stgcn/models/tester.py
 create mode 100644 examples/stgcn/utils/math_utils.py

diff --git a/examples/stgcn/README.md b/examples/stgcn/README.md
new file mode 100644
index 0000000..a185210
--- /dev/null
+++ b/examples/stgcn/README.md
@@ -0,0 +1,36 @@
+# STGCN: Spatio-Temporal Graph Convolutional Network
+
+[Spatio-Temporal Graph Convolutional Network \(STGCN\)](https://arxiv.org/pdf/1709.04875.pdf) is a novel deep learning framework for time series prediction problems. Based on PGL, we reproduce the STGCN algorithm to predict the number of newly confirmed patients in several cities from historical migration records.
+
+### Datasets
+
+You can build a customized dataset in the following format:
+
+* input.csv: Historical migration records with shape [num\_time\_steps * num\_cities].
+
+* output.csv: Newly confirmed patient records with shape [num\_time\_steps * num\_cities].
+
+* W.csv: Weighted adjacency matrix with shape [num\_cities * num\_cities].
+
+* city.csv: Each line contains a city index and the corresponding city name.
+
+### Dependencies
+
+- paddlepaddle 1.6
+- pgl 1.0.0
+
+### How to run
+
+For example, to train STGCN on your dataset with GPU:
+```
+python main.py --use_cuda --input_file dataset/input.csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv
+```
+
+#### Hyperparameters
+
+- n\_route: Number of cities.
+- n\_his: Number of previous time steps of migration records used as input.
+- n\_pred: Number of future time steps of newly confirmed patient records to predict.
+- Ks: Number of GCN layers.
+- Kt: Kernel size of the temporal convolution.
+- use\_cuda: Train on GPU if set.
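+
+To sanity-check a customized dataset before training, a minimal sketch (illustrative only; it assumes the file names above and a leading `date` column in input.csv and output.csv, which the data loader drops):
+
+```
+import pandas as pd
+
+x = pd.read_csv("dataset/input.csv")           # migration records
+y = pd.read_csv("dataset/output.csv")          # newly confirmed patients
+w = pd.read_csv("dataset/W.csv", header=None)  # weighted adjacency matrix
+
+assert x.shape == y.shape        # [num_time_steps, num_cities + 1 date column]
+assert w.shape[0] == w.shape[1]  # [num_cities, num_cities]
+```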
diff --git a/examples/stgcn/data_loader/__init__.py b/examples/stgcn/data_loader/__init__.py
new file mode 100644
index 0000000..28650aa
--- /dev/null
+++ b/examples/stgcn/data_loader/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""__init__"""
diff --git a/examples/stgcn/data_loader/data_utils.py b/examples/stgcn/data_loader/data_utils.py
new file mode 100644
index 0000000..f9c73c7
--- /dev/null
+++ b/examples/stgcn/data_loader/data_utils.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""data processing
+"""
+import numpy as np
+import pandas as pd
+
+from utils.math_utils import z_score
+
+
+class Dataset(object):
+    """Dataset
+    """
+
+    def __init__(self, data, stats):
+        self.__data = data
+        self.mean = stats['mean']
+        self.std = stats['std']
+
+    def get_data(self, type):  # type: train, val or test
+        return self.__data[type]
+
+    def get_stats(self):
+        return {'mean': self.mean, 'std': self.std}
+
+    def get_len(self, type):
+        return len(self.__data[type])
+
+    def z_inverse(self, type):
+        return self.__data[type] * self.std + self.mean
+
+
+def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
+    """Generate data in the form of standard sequence units."""
+    n_slot = day_slot - n_frame + 1
+
+    tmp_seq = np.zeros((len_seq * n_slot, n_frame, n_route, C_0))
+    for i in range(len_seq):
+        for j in range(n_slot):
+            sta = (i + offset) * day_slot + j
+            end = sta + n_frame
+            tmp_seq[i * n_slot + j, :, :, :] = np.reshape(
+                data_seq[sta:end, :], [n_frame, n_route, C_0])
+    return tmp_seq
+
+
+def adj_matrx_gen_custom(input_file, city_file):
+    """Generate the adjacency matrix from file
+    """
+    print("generate adj_matrix data (take long time)...")
+    # data
+    df = pd.read_csv(
+        input_file,
+        sep='\t',
+        names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
+    # keep the records of 2020 only
+    df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
+    df = df.set_index('date')
+    df = df['2020']
+    city_df = pd.read_csv(city_file)
+    # drop Wuhan
+    city_df = city_df.drop(0)
+    num = len(city_df)
+    matrix = np.zeros([num, num])
+    for i in city_df['city']:
+        for j in city_df['city']:
+            if i == j:
+                continue
+            # select the daily counts of people moving from city i to city j
+            cut = df[df['迁出城市'].str.contains(i)]
+            cut = cut[cut['迁入城市'].str.contains(j)]
+            # use the mean as the edge weight
+            average = cut['人数'].mean()
+            # write it into the matrix
+            i_index = int(city_df[city_df['city'] == i]['num']) - 1
+            j_index = int(city_df[city_df['city'] == j]['num']) - 1
+            matrix[i_index, j_index] = average
+
+    np.savetxt("dataset/W_74.csv", matrix, delimiter=",")
+
+
+def data_gen_custom(input_file, output_file, city_file, n, n_his, n_pred,
+                    n_config):
+    """data_gen_custom"""
+    print("generate training data...")
+    # data
+    df = pd.read_csv(
+        input_file,
+        sep='\t',
+        names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
+    # keep the records of 2020 only
+    df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
+    df = df.set_index('date')
+    df = df['2020']
+    city_df = pd.read_csv(city_file)
+    input_df = pd.DataFrame()
+
+    out_df_wuhan = df[df['迁出城市'].str.contains('武汉')]
+    for i in city_df['city']:
+        # filter by destination city
+        in_df_i = out_df_wuhan[out_df_wuhan['迁入城市'].str.contains(i)]
+        # make sure the records are in ascending time order
+        # in_df_i.sort_values("date", inplace=True)
+        # insert one column per city, indexed by time
+        in_df_i.reset_index(drop=True, inplace=True)
+        input_df[i] = in_df_i['人数']
+
+    # replace NaN values
+    input_df = input_df.replace(np.nan, 0)
+
+    x = input_df
+    y = pd.read_csv(output_file)
+    # drop the unnamed first column
+    x.drop(
+        x.columns[x.columns.str.contains(
+            'unnamed', case=False)],
+        axis=1,
+        inplace=True)
+    y = y.drop(columns=['date'])
+
+    # drop the records of people moving into Wuhan
+    x = x.drop(columns=['武汉'])
+    y = y.drop(columns=['武汉'])
+
+    # param
+    n_val, n_test = n_config
+    n_train = len(y) - n_val - n_test - 2
+
+    # (?, 26, 74, 1)
+    df = pd.DataFrame(columns=x.columns)
+    for i in range(len(y) - n_pred + 1):
+        df = df.append(x[i:i + n_his])
+        df = df.append(y[i:i + n_pred])
+    data = df.values.reshape(-1, n_his + n_pred, n,
+                             1)  # n == num_nodes == city num
+
+    x_stats = {'mean': np.mean(data), 'std': np.std(data)}
+
+    x_train = data[:n_train]
+    x_val = data[n_train:n_train + n_val]
+    x_test = data[n_train + n_val:]
+
+    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
+    dataset = Dataset(x_data, x_stats)
+    print("generate successfully!")
+
+    return dataset
+
+
+def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config):
+    """data processing
+    """
+    # data
+    x = pd.read_csv(input_file)
+    y = pd.read_csv(label_file)
+    x = x.drop(columns=['date'])
+    y = y.drop(columns=['date'])
+
+    x = x.drop(columns=['武汉'])
+    y = y.drop(columns=['武汉'])
+
+    # param
+    n_val, n_test = n_config
+    n_train = len(y) - n_val - n_test - 2
+
+    # (?, 26, 74, 1)
+    df = pd.DataFrame(columns=x.columns)
+    for i in range(len(y) - n_pred + 1):
+        df = df.append(x[i:i + n_his])
+        df = df.append(y[i:i + n_pred])
+
+    data = df.values.reshape(-1, n_his + n_pred, n, 1)
+
+    x_stats = {'mean': np.mean(data), 'std': np.std(data)}
+
+    x_train = data[:n_train]
+    x_val = data[n_train:n_train + n_val]
+    x_test = data[n_train + n_val:]
+
+    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
+    dataset = Dataset(x_data, x_stats)
+    return dataset
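+
+
+# Illustrative shape note (assuming the defaults in main.py: n_his=23,
+# n_pred=3, n_route=74): each sample built above stacks 23 input frames and
+# 3 label frames, so dataset.get_data('train') has shape [n_train, 26, 74, 1].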
+
+
+def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
+    """Source file load and dataset generation."""
+    n_train, n_val, n_test = data_config
+    # generate training, validation and test data
+    try:
+        data_seq = pd.read_csv(file_path, header=None).values
+    except FileNotFoundError:
+        print(f'ERROR: input file was not found in {file_path}.')
+        raise
+
+    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
+    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
+    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route,
+                       day_slot)
+
+    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
+    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}
+
+    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
+    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
+    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
+    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])
+
+    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
+    dataset = Dataset(x_data, x_stats)
+    return dataset
+
+
+def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
+    """Data iterator in batch.
+
+    Args:
+        inputs: np.ndarray, [len_seq, n_frame, n_route, C_0], standard sequence units.
+        batch_size: int, size of batch.
+        dynamic_batch: bool, whether to shrink the last batch
+            if its length is less than the default.
+        shuffle: bool, whether to shuffle the batches.
+    """
+    len_inputs = len(inputs)
+
+    if shuffle:
+        idx = np.arange(len_inputs)
+        np.random.shuffle(idx)
+
+    for start_idx in range(0, len_inputs, batch_size):
+        end_idx = start_idx + batch_size
+        if end_idx > len_inputs:
+            if dynamic_batch:
+                end_idx = len_inputs
+            else:
+                break
+        if shuffle:
+            slide = idx[start_idx:end_idx]
+        else:
+            slide = slice(start_idx, end_idx)
+
+        yield inputs[slide]
diff --git a/examples/stgcn/data_loader/graph.py b/examples/stgcn/data_loader/graph.py
new file mode 100644
index 0000000..fc04eec
--- /dev/null
+++ b/examples/stgcn/data_loader/graph.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PGL Graph
+"""
+import sys
+import os
+import numpy as np
+import pandas as pd
+
+from pgl.graph import Graph
+
+
+def weight_matrix(file_path, sigma2=0.1, epsilon=0.5, scaling=True):
+    """Load weight matrix function."""
+    try:
+        W = pd.read_csv(file_path, header=None).values
+    except FileNotFoundError:
+        print(f'ERROR: input file was not found in {file_path}.')
+        raise
+
+    # check whether W is a 0/1 matrix.
+    if set(np.unique(W)) == {0, 1}:
+        print('The input graph is a 0/1 matrix; set "scaling" to False.')
+        scaling = False
+
+    if scaling:
+        n = W.shape[0]
+        W = W / 10000.
+        W2, W_mask = W * W, np.ones([n, n]) - np.identity(n)
+        # refer to Eq.10
+        return np.exp(-W2 / sigma2) * (
+            np.exp(-W2 / sigma2) >= epsilon) * W_mask
+    else:
+        return W
+
+
+class GraphFactory(object):
+    """GraphFactory"""
+
+    def __init__(self, args):
+        self.args = args
+        self.adj_matrix = weight_matrix(self.args.adj_mat_file)
+
+        # add self-loops: L = I + A
+        L = np.eye(self.adj_matrix.shape[0]) + self.adj_matrix
+        D = np.sum(self.adj_matrix, axis=1)
+
+        edges = []
+        weights = []
+        for i in range(self.adj_matrix.shape[0]):
+            for j in range(self.adj_matrix.shape[1]):
+                edges.append([i, j])
+                weights.append(L[i][j])
+
+        self.edges = np.array(edges, dtype=np.int64)
+        self.weights = np.array(weights, dtype=np.float32).reshape(-1, 1)
+
+        # symmetric normalization term D^{-1/2}
+        self.norm = np.zeros_like(D, dtype=np.float32)
+        self.norm[D > 0] = np.power(D[D > 0], -0.5)
+        self.norm = self.norm.reshape(-1, 1)
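+
+    # Illustrative note: for a batch of B samples with T time steps each,
+    # build_graph below stitches B * T copies of the city graph into one big
+    # disconnected PGL graph, offsetting the node ids of the i-th
+    # (sample, time-step) slice by i * n, so a single graph feeds the batch.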
+
+    def build_graph(self, x_batch):
+        """build graph"""
+        B, T, n, _ = x_batch.shape
+        batch = B * T
+
+        batch_edges = []
+        for i in range(batch):
+            batch_edges.append(self.edges + (i * n))
+        batch_edges = np.vstack(batch_edges)
+
+        num_nodes = B * T * n
+        node_feat = {'norm': np.tile(self.norm, [batch, 1])}
+        edge_feat = {'weights': np.tile(self.weights, [batch, 1])}
+        graph = Graph(
+            num_nodes=num_nodes,
+            edges=batch_edges,
+            node_feat=node_feat,
+            edge_feat=edge_feat)
+
+        return graph
diff --git a/examples/stgcn/main.py b/examples/stgcn/main.py
new file mode 100644
index 0000000..6be8df9
--- /dev/null
+++ b/examples/stgcn/main.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file implements the training process of the STGCN model.
+"""
+
+import os
+import sys
+import time
+import argparse
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as fl
+import pgl
+from pgl.utils.logger import log
+
+from data_loader.data_utils import data_gen_mydata, gen_batch
+from data_loader.graph import GraphFactory
+from models.model import STGCNModel
+from models.tester import model_inference, model_test
+
+
+def main(args):
+    """main"""
+    PeMS = data_gen_mydata(args.input_file, args.label_file, args.n_route,
+                           args.n_his, args.n_pred, (args.n_val, args.n_test))
+
+    log.info(PeMS.get_stats())
+    log.info(PeMS.get_len('train'))
+
+    gf = GraphFactory(args)
+
+    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
+    train_program = fluid.Program()
+    startup_program = fluid.Program()
+
+    with fluid.program_guard(train_program, startup_program):
+        gw = pgl.graph_wrapper.GraphWrapper(
+            "gw",
+            place,
+            node_feat=[('norm', [None, 1], "float32")],
+            edge_feat=[('weights', [None, 1], "float32")])
+
+        model = STGCNModel(args, gw)
+        train_loss, y_pred = model.forward()
+
+    infer_program = train_program.clone(for_test=True)
+
+    with fluid.program_guard(train_program, startup_program):
+        epoch_step = int(PeMS.get_len('train') / args.batch_size) + 1
+        lr = fl.exponential_decay(
+            learning_rate=args.lr,
+            decay_steps=5 * epoch_step,
+            decay_rate=0.7,
+            staircase=True)
+        if args.opt == 'RMSProp':
+            train_op = fluid.optimizer.RMSPropOptimizer(lr).minimize(
+                train_loss)
+        elif args.opt == 'ADAM':
+            train_op = fluid.optimizer.Adam(lr).minimize(train_loss)
+
+    exe = fluid.Executor(place)
+    exe.run(startup_program)
+
+    if args.inf_mode == 'sep':
+        # for inference mode 'sep', the type of step index is int.
+        step_idx = args.n_pred - 1
+        tmp_idx = [step_idx]
+        min_val = min_va_val = np.array([4e1, 1e5, 1e5])
+    elif args.inf_mode == 'merge':
+        # for inference mode 'merge', the type of step index is np.ndarray.
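+        # e.g. with n_pred = 9 this gives step_idx = [2, 5, 8], i.e. the
+        # 3rd, 6th and 9th predicted steps are evaluated (0-based indices).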
+        step_idx = tmp_idx = np.arange(3, args.n_pred + 1, 3) - 1
+        # keep one (MAPE, MAE, RMSE) triple per evaluated step.
+        min_val = min_va_val = np.array([4e1, 1e5, 1e5] * len(step_idx))
+    else:
+        raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')
+
+    step = 0
+    for epoch in range(1, args.epochs + 1):
+        for idx, x_batch in enumerate(
+                gen_batch(
+                    PeMS.get_data('train'),
+                    args.batch_size,
+                    dynamic_batch=True,
+                    shuffle=True)):
+
+            x = np.array(x_batch[:, 0:args.n_his, :, :], dtype=np.float32)
+            graph = gf.build_graph(x)
+            feed = gw.to_feed(graph)
+            feed['input'] = np.array(
+                x_batch[:, 0:args.n_his + 1, :, :], dtype=np.float32)
+            b_loss, b_lr = exe.run(train_program,
+                                   feed=feed,
+                                   fetch_list=[train_loss, lr])
+
+            if idx % 5 == 0:
+                log.info("epoch %d | step %d | lr %.6f | loss %.6f" %
+                         (epoch, idx, b_lr[0], b_loss[0]))
+
+        min_va_val, min_val = \
+            model_inference(exe, gw, gf, infer_program, y_pred, PeMS, args, \
+                            step_idx, min_va_val, min_val)
+
+        for ix in tmp_idx:
+            va, te = min_va_val[ix - 2:ix + 1], min_val[ix - 2:ix + 1]
+            print(f'Time Step {ix + 1}: '
+                  f'MAPE {va[0]:7.3%}, {te[0]:7.3%}; '
+                  f'MAE {va[1]:4.3f}, {te[1]:4.3f}; '
+                  f'RMSE {va[2]:6.3f}, {te[2]:6.3f}.')
+
+        if epoch % 5 == 0:
+            model_test(exe, gw, gf, infer_program, y_pred, PeMS, args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--n_route', type=int, default=74)
+    parser.add_argument('--n_his', type=int, default=23)
+    parser.add_argument('--n_pred', type=int, default=3)
+    parser.add_argument('--batch_size', type=int, default=10)
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--save', type=int, default=10)
+    parser.add_argument('--Ks', type=int, default=3)  # equal to the number of GCN layers
+    parser.add_argument('--Kt', type=int, default=3)
+    parser.add_argument('--lr', type=float, default=1e-2)
+    parser.add_argument('--keep_prob', type=float, default=1.0)
+    parser.add_argument('--opt', type=str, default='RMSProp')
+    parser.add_argument('--inf_mode', type=str, default='sep')
+    parser.add_argument('--input_file', type=str, default='dataset/input.csv')
+    parser.add_argument('--label_file', type=str, default='dataset/output.csv')
+    parser.add_argument(
+        '--city_file', type=str, default='dataset/crawl_list.csv')
+    parser.add_argument('--adj_mat_file', type=str, default='dataset/W_74.csv')
+    parser.add_argument('--output_path', type=str, default='./outputs/')
+    parser.add_argument('--n_val', type=int, default=1)
+    parser.add_argument('--n_test', type=int, default=1)
+    parser.add_argument('--use_cuda', action='store_true')
+    args = parser.parse_args()
+
+    blocks = [[1, 32, 64], [64, 32, 128]]
+    args.blocks = blocks
+    log.info(args)
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
+
+    main(args)
diff --git a/examples/stgcn/models/model.py b/examples/stgcn/models/model.py
new file mode 100644
index 0000000..1e8f496
--- /dev/null
+++ b/examples/stgcn/models/model.py
@@ -0,0 +1,271 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file implements the STGCN model.
+"""
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as fl
+import pgl
+
+
+class STGCNModel(object):
+    """Implementation of Spatio-Temporal Graph Convolutional Networks"""
+
+    def __init__(self, args, gw):
+        self.args = args
+        self.gw = gw
+
+        self.input = fl.data(
+            name="input",
+            shape=[None, args.n_his + 1, args.n_route, 1],
+            dtype="float32")
+
+    def forward(self):
+        """forward"""
+        x = self.input[:, 0:self.args.n_his, :, :]
+        # Ko>0: kernel size of temporal convolution in the output layer.
+        Ko = self.args.n_his
+        # ST-Block
+        for i, channels in enumerate(self.args.blocks):
+            x = self.st_conv_block(
+                x,
+                self.args.Ks,
+                self.args.Kt,
+                channels,
+                "st_conv_%d" % i,
+                self.args.keep_prob,
+                act_func='GLU')
+
+        # output layer
+        if Ko > 1:
+            y = self.output_layer(x, Ko, 'output_layer')
+        else:
+            raise ValueError(f'ERROR: kernel size Ko must be greater than 1, '
+                             f'but received "{Ko}".')
+
+        label = self.input[:, self.args.n_his:self.args.n_his + 1, :, :]
+        train_loss = fl.reduce_sum((y - label) * (y - label))
+        single_pred = y[:, 0, :, :]  # shape: [batch, n, 1]
+
+        return train_loss, single_pred
+
+    def st_conv_block(self,
+                      x,
+                      Ks,
+                      Kt,
+                      channels,
+                      name,
+                      keep_prob,
+                      act_func='GLU'):
+        """Spatio-temporal convolution block: temporal conv -> spatial GCN -> temporal conv"""
+        c_si, c_t, c_oo = channels
+
+        x_s = self.temporal_conv_layer(
+            x, Kt, c_si, c_t, "%s_tconv_in" % name, act_func=act_func)
+        x_t = self.spatio_conv_layer(x_s, Ks, c_t, c_t, "%s_sconv" % name)
+        x_o = self.temporal_conv_layer(x_t, Kt, c_t, c_oo,
+                                       "%s_tconv_out" % name)
+
+        x_ln = fl.layer_norm(x_o)
+        return fl.dropout(x_ln, dropout_prob=(1.0 - keep_prob))
+
+    def temporal_conv_layer(self, x, Kt, c_in, c_out, name, act_func='relu'):
+        """Temporal convolution layer"""
+        _, T, n, _ = x.shape
+        if c_in > c_out:
+            # a 1x1 convolution maps the input down to c_out channels,
+            # so the residual connection below has matching shapes.
+            x_input = fl.conv2d(
+                input=x,
+                num_filters=c_out,
+                filter_size=[1, 1],
+                stride=[1, 1],
+                padding="SAME",
+                data_format="NHWC",
+                param_attr=fluid.ParamAttr(name="%s_conv2d_1" % name))
+        elif c_in < c_out:
+            # if the size of input channel is less than the output,
+            # pad x to the same size of output channel.
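+            # (e.g. the first st_conv block maps c_in=1 to c_out=32, so 31
+            # all-zero channels are concatenated to x below to form the
+            # residual branch.)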
+            pad = fl.fill_constant_batch_size_like(
+                input=x,
+                shape=[-1, T, n, c_out - c_in],
+                dtype="float32",
+                value=0.0)
+            x_input = fl.concat([x, pad], axis=3)
+        else:
+            x_input = x
+
+        # with "SAME" padding the temporal length is preserved, so the
+        # cropping needed with valid padding is disabled:
+        # x_input = x_input[:, Kt - 1:T, :, :]
+        if act_func == 'GLU':
+            # gated linear unit
+            bt_init = fluid.initializer.ConstantInitializer(value=0.0)
+            bt = fl.create_parameter(
+                shape=[2 * c_out],
+                dtype="float32",
+                attr=fluid.ParamAttr(
+                    name="%s_bt" % name, trainable=True, initializer=bt_init))
+            x_conv = fl.conv2d(
+                input=x,
+                num_filters=2 * c_out,
+                filter_size=[Kt, 1],
+                stride=[1, 1],
+                padding="SAME",
+                data_format="NHWC",
+                param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
+            x_conv = x_conv + bt
+            # GLU: the first c_out channels (plus the residual) are gated by
+            # the sigmoid of the last c_out channels.
+            return (x_conv[:, :, :, 0:c_out] + x_input
+                    ) * fl.sigmoid(x_conv[:, :, :, -c_out:])
+        else:
+            bt_init = fluid.initializer.ConstantInitializer(value=0.0)
+            bt = fl.create_parameter(
+                shape=[c_out],
+                dtype="float32",
+                attr=fluid.ParamAttr(
+                    name="%s_bt" % name, trainable=True, initializer=bt_init))
+            x_conv = fl.conv2d(
+                input=x,
+                num_filters=c_out,
+                filter_size=[Kt, 1],
+                stride=[1, 1],
+                padding="SAME",
+                data_format="NHWC",
+                param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
+            x_conv = x_conv + bt
+            if act_func == "linear":
+                return x_conv
+            elif act_func == "sigmoid":
+                return fl.sigmoid(x_conv)
+            elif act_func == "relu":
+                return fl.relu(x_conv + x_input)
+            else:
+                raise ValueError(
+                    f'ERROR: activation function "{act_func}" is not defined.')
+
+    def spatio_conv_layer(self, x, Ks, c_in, c_out, name):
+        """Spatial graph convolution layer"""
+        _, T, n, _ = x.shape
+        if c_in > c_out:
+            x_input = fl.conv2d(
+                input=x,
+                num_filters=c_out,
+                filter_size=[1, 1],
+                stride=[1, 1],
+                padding="SAME",
+                data_format="NHWC",
+                param_attr=fluid.ParamAttr(name="%s_conv2d_1" % name))
+        elif c_in < c_out:
+            # if the size of input channel is less than the output,
+            # pad x to the same size of output channel.
+            pad = fl.fill_constant_batch_size_like(
+                input=x,
+                shape=[-1, T, n, c_out - c_in],
+                dtype="float32",
+                value=0.0)
+            x_input = fl.concat([x, pad], axis=3)
+        else:
+            x_input = x
+
+        for i in range(Ks):
+            # x_input shape: [B, T, num_nodes, c_out]
+            x_input = fl.reshape(x_input, [-1, c_out])
+
+            x_input = self.message_passing(
+                self.gw,
+                x_input,
+                name="%s_mp_%d" % (name, i),
+                norm=self.gw.node_feat["norm"])
+
+            x_input = fl.fc(x_input,
+                            size=c_out,
+                            bias_attr=False,
+                            param_attr=fluid.ParamAttr(name="%s_gcn_fc_%d" %
+                                                       (name, i)))
+
+            bias = fluid.layers.create_parameter(
+                shape=[c_out],
+                dtype='float32',
+                is_bias=True,
+                name='%s_gcn_bias_%d' % (name, i))
+            x_input = fluid.layers.elementwise_add(x_input, bias, act="relu")
+
+            x_input = fl.reshape(x_input, [-1, T, n, c_out])
+
+        return x_input
+
+    def message_passing(self, gw, feature, name, norm=None):
+        """Message passing layer"""
+
+        def send_src_copy(src_feat, dst_feat, edge_feat):
+            """send function: scale the source feature by the edge weight"""
+            return src_feat["h"] * edge_feat['w']
+
+        # symmetric normalization: h <- D^{-1/2} (I + A) D^{-1/2} h
+        if norm is not None:
+            feature = feature * norm
+
+        msg = gw.send(
+            send_src_copy,
+            nfeat_list=[("h", feature)],
+            efeat_list=[('w', gw.edge_feat['weights'])])
+        output = gw.recv(msg, "sum")
+
+        if norm is not None:
+            output = output * norm
+
+        return output
+
+    def output_layer(self, x, T, name, act_func='GLU'):
+        """Output layer"""
+        _, _, n, channel = x.shape
+
+        # maps multi-steps to one.
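+        # pipeline: a GLU temporal conv with kernel size T, layer norm, a
+        # sigmoid temporal conv, then a 1x1 conv (fully_con_layer) that maps
+        # the channels down to a single output value per node.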
+        x_i = self.temporal_conv_layer(
+            x=x,
+            Kt=T,
+            c_in=channel,
+            c_out=channel,
+            name="%s_in" % name,
+            act_func=act_func)
+        x_ln = fl.layer_norm(x_i)
+        x_o = self.temporal_conv_layer(
+            x=x_ln,
+            Kt=1,
+            c_in=channel,
+            c_out=channel,
+            name="%s_out" % name,
+            act_func='sigmoid')
+
+        # maps multi-channels to one.
+        x_fc = self.fully_con_layer(
+            x=x_o, n=n, channel=channel, name="%s_fc" % name)
+        return x_fc
+
+    def fully_con_layer(self, x, n, channel, name):
+        """Fully connected layer"""
+        bt_init = fluid.initializer.ConstantInitializer(value=0.0)
+        bt = fl.create_parameter(
+            shape=[n, 1],
+            dtype="float32",
+            attr=fluid.ParamAttr(
+                name="%s_bt" % name, trainable=True, initializer=bt_init))
+        x_conv = fl.conv2d(
+            input=x,
+            num_filters=1,
+            filter_size=[1, 1],
+            stride=[1, 1],
+            padding="SAME",
+            data_format="NHWC",
+            param_attr=fluid.ParamAttr(name="%s_conv2d" % name))
+        x_conv = x_conv + bt
+        return x_conv
diff --git a/examples/stgcn/models/tester.py b/examples/stgcn/models/tester.py
new file mode 100644
index 0000000..5c10cae
--- /dev/null
+++ b/examples/stgcn/models/tester.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file implements the testing process of the STGCN model.
+"""
+import os
+import sys
+import time
+import argparse
+import numpy as np
+import pandas as pd
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as fl
+import pgl
+from pgl.utils.logger import log
+
+from data_loader.data_utils import gen_batch
+from utils.math_utils import evaluation
+
+
+def multi_pred(exe, gw, gf, program, y_pred, seq, batch_size, \
+               n_his, n_pred, step_idx, dynamic_batch=True):
+    """multi step prediction"""
+    pred_list = []
+    for i in gen_batch(
+            seq, min(batch_size, len(seq)), dynamic_batch=dynamic_batch):
+
+        # Note: use np.copy() to avoid the modification of source data.
+        test_seq = np.copy(i[:, 0:n_his + 1, :, :]).astype(np.float32)
+        graph = gf.build_graph(i[:, 0:n_his, :, :])
+        feed = gw.to_feed(graph)
+        step_list = []
+        for j in range(n_pred):
+            feed['input'] = test_seq
+            pred = exe.run(program, feed=feed, fetch_list=[y_pred])
+            if isinstance(pred, list):
+                pred = np.array(pred[0])
+            # shift the window by one step and append the prediction.
+            test_seq[:, 0:n_his - 1, :, :] = test_seq[:, 1:n_his, :, :]
+            test_seq[:, n_his - 1, :, :] = pred
+            step_list.append(pred)
+        pred_list.append(step_list)
+    # pred_array -> [n_pred, len(seq), n_route, C_0]
+    pred_array = np.concatenate(pred_list, axis=1)
+    return pred_array, pred_array.shape[1]
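+
+
+# Note on multi_pred: forecasting is autoregressive. Each of the n_pred
+# iterations feeds the model's own prediction back into the input window,
+# so errors can compound over longer horizons.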
+
+
+def model_inference(exe, gw, gf, program, pred, inputs, args, step_idx,
+                    min_va_val, min_val):
+    """inference model"""
+    x_val, x_test, x_stats = inputs.get_data('val'), inputs.get_data(
+        'test'), inputs.get_stats()
+
+    if args.n_his + args.n_pred > x_val.shape[1]:
+        raise ValueError(
+            f'ERROR: the value of n_pred "{args.n_pred}" exceeds the length limit.'
+        )
+
+    # y_val shape: [n_pred, len(x_val), n_route, C_0]
+    y_val, len_val = multi_pred(exe, gw, gf, program, pred, \
+                                x_val, args.batch_size, args.n_his, args.n_pred, step_idx)
+
+    evl_val = evaluation(x_val[0:len_val, step_idx + args.n_his, :, :],
+                         y_val[step_idx], x_stats)
+
+    # chks: indicator that reflects the relationship of values between evl_val and min_va_val.
+    chks = evl_val < min_va_val
+    # update the metric on the test set if the model improved on the validation set.
+    if sum(chks):
+        min_va_val[chks] = evl_val[chks]
+        y_pred, len_pred = multi_pred(exe, gw, gf, program, pred, \
+                                      x_test, args.batch_size, args.n_his, args.n_pred, step_idx)
+
+        evl_pred = evaluation(x_test[0:len_pred, step_idx + args.n_his, :, :],
+                              y_pred[step_idx], x_stats)
+        min_val = evl_pred
+
+    return min_va_val, min_val
+
+
+def model_test(exe, gw, gf, program, pred, inputs, args):
+    """test model"""
+    if args.inf_mode == 'sep':
+        # for inference mode 'sep', the type of step index is int.
+        step_idx = args.n_pred - 1
+        tmp_idx = [step_idx]
+    elif args.inf_mode == 'merge':
+        # for inference mode 'merge', the type of step index is np.ndarray.
+        step_idx = tmp_idx = np.arange(3, args.n_pred + 1, 3) - 1
+        print(step_idx)
+    else:
+        raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')
+
+    x_test, x_stats = inputs.get_data('test'), inputs.get_stats()
+    y_test, len_test = multi_pred(exe, gw, gf, program, pred, \
+                                  x_test, args.batch_size, args.n_his, args.n_pred, step_idx)
+
+    # save result
+    gt = x_test[0:len_test, args.n_his:, :, :].reshape(-1, args.n_route)
+    # align the prediction layout with the ground truth:
+    # [len_test, n_pred, n_route, 1] before flattening.
+    y_pred = np.swapaxes(y_test, 0, 1).reshape(-1, args.n_route)
+    city_df = pd.read_csv(args.city_file)
+    city_df = city_df.drop(0)
+
+    np.savetxt(
+        os.path.join(args.output_path, "groundtruth.csv"),
+        gt.astype(np.int32),
+        fmt='%d',
+        delimiter=',',
+        header=",".join(city_df['city']))
+    np.savetxt(
+        os.path.join(args.output_path, "prediction.csv"),
+        y_pred.astype(np.int32),
+        fmt='%d',
+        delimiter=",",
+        header=",".join(city_df['city']))
+
+    for i in range(step_idx + 1):
+        # evaluate each predicted step against the ground truth at the same step.
+        evl = evaluation(x_test[0:len_test, i + args.n_his, :, :],
+                         y_test[i], x_stats)
+        for ix in tmp_idx:
+            te = evl[ix - 2:ix + 1]
+            print(
+                f'Time Step {i + 1}: MAPE {te[0]:7.3%}; MAE {te[1]:4.3f}; RMSE {te[2]:6.3f}.'
+            )
diff --git a/examples/stgcn/utils/math_utils.py b/examples/stgcn/utils/math_utils.py
new file mode 100644
index 0000000..ddc0b3f
--- /dev/null
+++ b/examples/stgcn/utils/math_utils.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation"""
+import os
+import sys
+import time
+import argparse
+import numpy as np
+
+
+def z_score(x, mean, std):
+    """z_score"""
+    return (x - mean) / std
+
+
+def z_inverse(x, mean, std):
+    """The inverse of function z_score"""
+    return x * std + mean
+
+
+def MAPE(v, v_):
+    """Mean absolute percentage error."""
+    # the small epsilon avoids division by zero.
+    return np.mean(np.abs(v_ - v) / (v + 1e-5))
+
+
+def RMSE(v, v_):
+    """Root mean squared error."""
+    return np.sqrt(np.mean((v_ - v)**2))
+
+
+def MAE(v, v_):
+    """Mean absolute error."""
+    return np.mean(np.abs(v_ - v))
+
+
+def evaluation(y, y_, x_stats):
+    """Calculate MAPE, MAE and RMSE between ground truth and prediction."""
+    dim = len(y_.shape)
+
+    if dim == 3:
+        # single_step case
+        v = z_inverse(y, x_stats['mean'], x_stats['std'])
+        v_ = z_inverse(y_, x_stats['mean'], x_stats['std'])
+        return np.array([MAPE(v, v_), MAE(v, v_), RMSE(v, v_)])
+    else:
+        # multi_step case
+        tmp_list = []
+        # y -> [time_step, batch_size, n_route, 1]
+        y = np.swapaxes(y, 0, 1)
+        # recursively call the single-step case for each step.
+        for i in range(y_.shape[0]):
+            tmp_res = evaluation(y[i], y_[i], x_stats)
+            tmp_list.append(tmp_res)
+        return np.concatenate(tmp_list, axis=-1)
--
GitLab