提交 504e7981 编写于 作者: L liweibin

add stgcn model

上级 a545a4d2
# STGCN: Spatio-Temporal Graph Convolutional Network
[Spatio-Temporal Graph Convolutional Network \(STGCN\)](https://arxiv.org/pdf/1709.04875.pdf) is a novel deep learning framework for tackling time series prediction problems. Based on PGL, we reproduce the STGCN algorithm to predict the number of newly confirmed patients in a set of cities from historical immigration records.
### Datasets
You can build your own dataset using the following file formats:
* input.csv: Historical immigration records with shape of [num\_time\_steps * num\_cities].
* output.csv: New confirmed patients records with shape of [num\_time\_steps * num\_cities].
* W.csv: Weighted Adjacency Matrix with shape of [num\_cities * num\_cities].
* city.csv: Each line is a number and the corresponding city name.
### Dependencies
- paddlepaddle 1.6
- pgl 1.0.0
### How to run
For example, to train STGCN on your dataset using a GPU:
```
python main.py --use_cuda --input_file dataset/input.csv --label_file dataset/output.csv --adj_mat_file dataset/W.csv --city_file dataset/city.csv
```
#### Hyperparameters
- n\_route: Number of cities.
- n\_his: "n\_his" time steps of previous observations of historical immigration records.
- n\_pred: Next "n\_pred" time steps of New confirmed patients records.
- Ks: Number of GCN layers.
- Kt: Kernel size of temporal convolution.
- use\_cuda: Use the GPU when the `--use_cuda` flag is given; otherwise run on the CPU.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""data processing
"""
import numpy as np
import pandas as pd
from utils.math_utils import z_score
class Dataset(object):
    """Hold the dataset splits together with their normalization statistics.

    Args:
        data: mapping from a split name ('train', 'val' or 'test') to the
            samples of that split.
        stats: mapping that provides the keys 'mean' and 'std'.
    """

    def __init__(self, data, stats):
        self.__data = data
        self.mean, self.std = stats['mean'], stats['std']

    def get_data(self, type):
        """Return the samples stored under the split name ``type``."""
        split = self.__data[type]
        return split

    def get_stats(self):
        """Return the statistics as a ``{'mean': ..., 'std': ...}`` dict."""
        return dict(mean=self.mean, std=self.std)

    def get_len(self, type):
        """Return ``len()`` of the split named ``type``."""
        split = self.__data[type]
        return len(split)

    def z_inverse(self, type):
        """Return the split named ``type`` scaled by std and shifted by mean."""
        scaled = self.__data[type] * self.std
        return scaled + self.mean
def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
    """Cut ``data_seq`` into sliding windows of ``n_frame`` consecutive rows.

    The rows are treated as ``day_slot``-sized groups. For each of the
    ``len_seq`` groups starting at group ``offset``, every window of
    ``n_frame`` rows that starts inside the first
    ``day_slot - n_frame + 1`` rows of the group is emitted.

    Args:
        len_seq: number of groups to process.
        data_seq: 2-D array indexed as ``[row, route]``.
        offset: index of the first group to process.
        n_frame: number of rows per window.
        n_route: second dimension of each reshaped window.
        day_slot: number of rows per group.
        C_0: size of the trailing (channel) axis of the output.

    Returns:
        Array of shape ``[len_seq * n_slot, n_frame, n_route, C_0]``.
    """
    windows_per_group = day_slot - n_frame + 1
    out = np.zeros((len_seq * windows_per_group, n_frame, n_route, C_0))
    for group in range(len_seq):
        group_start = (group + offset) * day_slot
        for shift in range(windows_per_group):
            first = group_start + shift
            window = data_seq[first:first + n_frame, :]
            out[group * windows_per_group + shift] = np.reshape(
                window, [n_frame, n_route, C_0])
    return out
def adj_matrx_gen_custom(input_file, city_file,
                         output_file="dataset/W_74.csv"):
    """Generate a weighted adjacency matrix from migration records.

    The weight of the directed pair (i, j) is the mean of the head-count
    column over all 2020 records whose origin city contains the name of
    city i and whose destination city contains the name of city j. Pairs
    without any record get NaN (the mean of an empty selection); the
    diagonal stays 0.

    Args:
        input_file: tab-separated file without header whose columns are
            date (YYYYMMDD), origin province, origin city, destination
            province, destination city and head count.
        city_file: CSV file with the columns 'num' and 'city'. The first
            data row is dropped (Wuhan, per the original comment), and
            ``num - 1`` is used as the matrix index of each remaining city.
        output_file: path the matrix is written to as comma-separated text.
            Defaults to the location the function always used before.
    """
    print("generate adj_matrix data (take long time)...")
    df = pd.read_csv(
        input_file,
        sep='\t',
        names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
    # Only the records from 2020 are needed. A boolean mask on the index
    # year is used because selecting rows with ``df['2020']`` is no longer
    # supported by current pandas releases.
    df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
    df = df.set_index('date')
    df = df[df.index.year == 2020]

    city_df = pd.read_csv(city_file)
    # Exclude Wuhan (the first row of the city list).
    city_df = city_df.drop(0)
    num = len(city_df)
    matrix = np.zeros([num, num])

    cities = list(city_df['city'])
    # Resolve every city's matrix index once instead of per pair.
    index_of = {
        city: int(number) - 1
        for city, number in zip(city_df['city'], city_df['num'])
    }

    for i in cities:
        # The origin filter does not depend on j, so compute it once per i.
        out_of_i = df[df['迁出城市'].str.contains(i)]
        for j in cities:
            if i == j:
                continue
            # Daily head counts travelling from i to j ...
            cut = out_of_i[out_of_i['迁入城市'].str.contains(j)]
            # ... averaged to form the edge weight.
            matrix[index_of[i], index_of[j]] = cut['人数'].mean()
    np.savetxt(output_file, matrix, delimiter=",")
def data_gen_custom(input_file, output_file, city_file, n, n_his, n_pred,
                    n_config):
    """Build a Dataset from raw migration records and a label file.

    Inputs are the 2020 head counts of records leaving Wuhan, one column
    per destination city listed in ``city_file``. Labels are read from
    ``output_file``. Each sample stacks ``n_his`` input rows on top of
    ``n_pred`` label rows that start at the same row index.

    Args:
        input_file: tab-separated file without header whose columns are
            date (YYYYMMDD), origin province, origin city, destination
            province, destination city and head count.
        output_file: CSV label file with a 'date' column and one column per
            city (including a '武汉' column, which is dropped).
        city_file: CSV file with a 'city' column.
        n: number of nodes (cities) per frame; used to reshape the samples.
        n_his: number of input rows per sample.
        n_pred: number of label rows per sample.
        n_config: ``(n_val, n_test)`` sizes of the validation/test splits.

    Returns:
        A Dataset whose splits have shape ``[?, n_his + n_pred, n, 1]``.
        The stored values are not normalized here; the mean/std of the
        whole array are only recorded in the Dataset's stats.
    """
    print("generate training data...")
    df = pd.read_csv(
        input_file,
        sep='\t',
        names=['date', '迁出省份', '迁出城市', '迁入省份', '迁入城市', '人数'])
    # Only the records from 2020 are needed. A boolean mask on the index
    # year is used because selecting rows with ``df['2020']`` is no longer
    # supported by current pandas releases.
    df['date'] = pd.to_datetime(df['date'], format="%Y%m%d")
    df = df.set_index('date')
    df = df[df.index.year == 2020]

    city_df = pd.read_csv(city_file)
    input_df = pd.DataFrame()
    out_df_wuhan = df[df['迁出城市'].str.contains('武汉')]
    for i in city_df['city']:
        # Keep the records whose destination is city i.
        in_df_i = out_df_wuhan[out_df_wuhan['迁入城市'].str.contains(i)]
        # NOTE(review): rows are assumed to already be in ascending date
        # order; the original code left the explicit sort commented out.
        # Re-index from 0 so the columns line up row by row. A new frame is
        # created instead of mutating the filtered slice in place.
        in_df_i = in_df_i.reset_index(drop=True)
        input_df[i] = in_df_i['人数']
    # Replace missing values with 0.
    input_df = input_df.replace(np.nan, 0)

    x = input_df
    y = pd.read_csv(output_file)
    # Drop any leftover "Unnamed" index column.
    x = x.drop(
        columns=x.columns[x.columns.str.contains(
            'unnamed', case=False)])
    y = y.drop(columns=['date'])
    # Exclude the data of Wuhan itself.
    x = x.drop(columns=['武汉'])
    y = y.drop(columns=['武汉'])

    n_val, n_test = n_config
    # NOTE(review): the reason for the extra "- 2" is not visible here.
    n_train = len(y) - n_val - n_test - 2

    # Collect all windows first and concatenate once: DataFrame.append was
    # removed from pandas and re-copied the whole frame on every call.
    frames = []
    for i in range(len(y) - n_pred + 1):
        frames.append(x[i:i + n_his])
        frames.append(y[i:i + n_pred])
    if frames:
        df = pd.concat(frames, sort=False)
    else:
        df = pd.DataFrame(columns=x.columns)

    # Original note: (?, 26, 74, 1); n == num_nodes == number of cities.
    data = df.values.reshape(-1, n_his + n_pred, n, 1)
    x_stats = {'mean': np.mean(data), 'std': np.std(data)}

    x_train = data[:n_train]
    x_val = data[n_train:n_train + n_val]
    x_test = data[n_train + n_val:]
    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    print("generate successfully!")
    return dataset
def data_gen_mydata(input_file, label_file, n, n_his, n_pred, n_config):
    """Build a Dataset from a pre-processed input CSV and a label CSV.

    Each sample stacks ``n_his`` rows of the input file on top of
    ``n_pred`` rows of the label file that start at the same row index.

    Args:
        input_file: CSV file with a 'date' column and one column per city
            (including a '武汉' column, which is dropped).
        label_file: CSV file with the same layout as ``input_file``.
        n: number of nodes (cities) per frame; used to reshape the samples.
        n_his: number of input rows per sample.
        n_pred: number of label rows per sample.
        n_config: ``(n_val, n_test)`` sizes of the validation/test splits.

    Returns:
        A Dataset whose splits have shape ``[?, n_his + n_pred, n, 1]``.
        The stored values are not normalized here; the mean/std of the
        whole array are only recorded in the Dataset's stats.
    """
    x = pd.read_csv(input_file)
    y = pd.read_csv(label_file)

    x = x.drop(columns=['date'])
    y = y.drop(columns=['date'])

    x = x.drop(columns=['武汉'])
    y = y.drop(columns=['武汉'])

    n_val, n_test = n_config
    # NOTE(review): the reason for the extra "- 2" is not visible here.
    n_train = len(y) - n_val - n_test - 2

    # Collect all windows first and concatenate once: DataFrame.append was
    # removed from pandas and re-copied the whole frame on every call.
    frames = []
    for i in range(len(y) - n_pred + 1):
        frames.append(x[i:i + n_his])
        frames.append(y[i:i + n_pred])
    if frames:
        df = pd.concat(frames, sort=False)
    else:
        df = pd.DataFrame(columns=x.columns)

    # Original note: (?, 26, 74, 1)
    data = df.values.reshape(-1, n_his + n_pred, n, 1)
    x_stats = {'mean': np.mean(data), 'std': np.std(data)}

    x_train = data[:n_train]
    x_val = data[n_train:n_train + n_val]
    x_test = data[n_train + n_val:]
    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    return dataset
def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
    """Load the source file and generate a normalized dataset.

    Args:
        file_path: path of a header-less CSV file holding the sequence.
        data_config: ``(n_train, n_val, n_test)`` sizes passed to
            ``seq_gen`` as the number of ``day_slot``-sized groups per split.
        n_route: number of routes (columns) per frame.
        n_frame: number of rows per generated window.
        day_slot: number of rows per group.

    Returns:
        A Dataset whose splits are z-scored with the mean and standard
        deviation of the training windows.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    n_train, n_val, n_test = data_config
    # generate training, validation and test data
    try:
        data_seq = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        # Re-raise: continuing would only fail later with an unrelated
        # UnboundLocalError on ``data_seq``.
        raise

    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route,
                       day_slot)

    # x_stats: dict, the stats for the train dataset, including the value
    # of mean and standard deviation.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}

    # x_train, x_val, x_test: np.array,
    # [sample_size, n_frame, n_route, channel_size].
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    return dataset
def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
    """Yield consecutive batches taken from ``inputs``.

    Args:
        inputs: np.ndarray, [len_seq, n_frame, n_route, C_0], standard
            sequence units.
        batch_size: int, size of batch.
        dynamic_batch: bool, if True the last batch is yielded even when it
            holds fewer than ``batch_size`` samples; otherwise it is dropped.
        shuffle: bool, if True the samples are drawn in a random order
            (one permutation of the indices per call).

    Yields:
        ``inputs`` indexed by the slice or index array of each batch.
    """
    total = len(inputs)
    order = None
    if shuffle:
        order = np.arange(total)
        np.random.shuffle(order)

    for lo in range(0, total, batch_size):
        hi = lo + batch_size
        if hi > total:
            if not dynamic_batch:
                # Incomplete trailing batch is discarded.
                return
            hi = total
        picker = slice(lo, hi) if order is None else order[lo:hi]
        yield inputs[picker]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PGL Graph
"""
import sys
import os
import numpy as np
import pandas as pd
from pgl.graph import Graph
def weight_matrix(file_path, sigma2=0.1, epsilon=0.5, scaling=True):
    """Load a weight matrix from a header-less CSV file.

    Args:
        file_path: path of the CSV file holding the square matrix.
        sigma2: divisor of the squared weights inside the exponential.
        epsilon: entries whose kernel value is below this are set to 0.
        scaling: whether to apply the exponential kernel. Forced to False
            when the matrix contains exactly the values 0 and 1.

    Returns:
        The matrix as loaded when scaling is off; otherwise
        ``exp(-W^2 / sigma2)`` with entries below ``epsilon`` and the
        diagonal zeroed, where ``W`` is the loaded matrix divided by 10000.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    try:
        W = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        # Re-raise: continuing would only fail later with an unrelated
        # UnboundLocalError on ``W``.
        raise

    # check whether W is a 0/1 matrix.
    if set(np.unique(W)) == {0, 1}:
        print('The input graph is a 0/1 matrix; set "scaling" to False.')
        scaling = False

    if not scaling:
        return W

    n = W.shape[0]
    # NOTE(review): 10000 is a fixed rescaling constant; presumably chosen
    # for the magnitude of the raw weights -- confirm for new datasets.
    W = W / 10000.
    W2 = W * W
    # Mask that removes the diagonal (self connections).
    W_mask = np.ones([n, n]) - np.identity(n)
    # refer to Eq.10 (presumably of the STGCN paper)
    kernel = np.exp(-W2 / sigma2)
    return kernel * (kernel >= epsilon) * W_mask
class GraphFactory(object):
    """Precompute the edge list of the city graph and build batched graphs.

    The constructor loads the adjacency matrix named by
    ``args.adj_mat_file`` and derives:

    * ``edges``: every ordered node pair ``[i, j]`` in row-major order,
    * ``weights``: the matching entries of ``I + A`` as a float32 column,
    * ``norm``: ``D ** -0.5`` as a float32 column, where ``D`` is the row
      sum of ``A`` (0 where the row sum is not positive).
    """

    def __init__(self, args):
        self.args = args
        self.adj_matrix = weight_matrix(self.args.adj_mat_file)

        # Adjacency with self-loops supplies the edge weights; the degree
        # is taken from the plain adjacency matrix.
        with_self_loops = np.eye(self.adj_matrix.shape[0]) + self.adj_matrix
        degree = np.sum(self.adj_matrix, axis=1)

        pairs = [[src, dst]
                 for src in range(self.adj_matrix.shape[0])
                 for dst in range(self.adj_matrix.shape[1])]
        pair_weights = [with_self_loops[src][dst] for src, dst in pairs]

        self.edges = np.array(pairs, dtype=np.int64)
        self.weights = np.array(pair_weights, dtype=np.float32).reshape(-1, 1)

        inv_sqrt_degree = np.zeros_like(degree, dtype=np.float32)
        positive = degree > 0
        inv_sqrt_degree[positive] = np.power(degree[positive], -0.5)
        self.norm = inv_sqrt_degree.reshape(-1, 1)

    def build_graph(self, x_batch):
        """Build one graph holding a copy of the city graph per (sample, step).

        Args:
            x_batch: array whose shape unpacks as ``(B, T, n, _)``.

        Returns:
            A ``Graph`` with ``B * T * n`` nodes; copy ``k`` uses the node
            ids ``k * n .. k * n + n - 1`` and carries the tiled 'norm'
            node feature and 'weights' edge feature.
        """
        B, T, n, _ = x_batch.shape
        copies = B * T
        shifted = [self.edges + (k * n) for k in range(copies)]
        all_edges = np.vstack(shifted)

        return Graph(
            num_nodes=B * T * n,
            edges=all_edges,
            node_feat={'norm': np.tile(self.norm, [copies, 1])},
            edge_feat={'weights': np.tile(self.weights, [copies, 1])})
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implement the training process of STGCN model.
"""
import os
import sys
import time
import argparse
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils.logger import log
from data_loader.data_utils import data_gen_mydata, gen_batch
from data_loader.graph import GraphFactory
from models.model import STGCNModel
from models.tester import model_inference, model_test
def main(args):
    """Train the STGCN model, evaluating after every epoch.

    Steps: load the dataset, build the graph factory, construct the
    training and inference programs, then loop over epochs. After each
    epoch the model is evaluated with ``model_inference``; every fifth
    epoch ``model_test`` additionally writes the test predictions.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section
            for the available options); ``args.blocks`` must be set.

    Raises:
        ValueError: if ``args.opt`` or ``args.inf_mode`` is not supported.
    """
    PeMS = data_gen_mydata(args.input_file, args.label_file, args.n_route,
                           args.n_his, args.n_pred, (args.n_val, args.n_test))
    log.info(PeMS.get_stats())
    log.info(PeMS.get_len('train'))

    gf = GraphFactory(args)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.Program()
    startup_program = fluid.Program()

    with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.GraphWrapper(
            "gw",
            place,
            node_feat=[('norm', [None, 1], "float32")],
            edge_feat=[('weights', [None, 1], "float32")])

        model = STGCNModel(args, gw)
        train_loss, y_pred = model.forward()

    # Clone before the optimizer ops are appended so the inference program
    # contains the forward pass only.
    infer_program = train_program.clone(for_test=True)

    with fluid.program_guard(train_program, startup_program):
        epoch_step = int(PeMS.get_len('train') / args.batch_size) + 1
        lr = fl.exponential_decay(
            learning_rate=args.lr,
            decay_steps=5 * epoch_step,
            decay_rate=0.7,
            staircase=True)
        if args.opt == 'RMSProp':
            fluid.optimizer.RMSPropOptimizer(lr).minimize(train_loss)
        elif args.opt == 'ADAM':
            fluid.optimizer.Adam(lr).minimize(train_loss)
        else:
            # Without this check an unknown name would silently leave the
            # program without any optimizer, i.e. nothing would be trained.
            raise ValueError(
                f'ERROR: optimizer "{args.opt}" is not defined.')

    exe = fluid.Executor(place)
    exe.run(startup_program)

    if args.inf_mode == 'sep':
        # for inference mode 'sep', the type of step index is int.
        step_idx = args.n_pred - 1
        tmp_idx = [step_idx]
        min_val = min_va_val = np.array([4e1, 1e5, 1e5])
    elif args.inf_mode == 'merge':
        # for inference mode 'merge', the type of step index is np.ndarray.
        step_idx = tmp_idx = np.arange(3, args.n_pred + 1, 3) - 1
        min_val = min_va_val = np.array([4e1, 1e5, 1e5]) * len(step_idx)
    else:
        raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')

    for epoch in range(1, args.epochs + 1):
        for idx, x_batch in enumerate(
                gen_batch(
                    PeMS.get_data('train'),
                    args.batch_size,
                    dynamic_batch=True,
                    shuffle=True)):
            # The graph only needs the shape of the history frames.
            x = np.array(x_batch[:, 0:args.n_his, :, :], dtype=np.float32)
            graph = gf.build_graph(x)
            feed = gw.to_feed(graph)
            # History frames plus the single next frame used as the label.
            feed['input'] = np.array(
                x_batch[:, 0:args.n_his + 1, :, :], dtype=np.float32)
            b_loss, b_lr = exe.run(train_program,
                                   feed=feed,
                                   fetch_list=[train_loss, lr])

            if idx % 5 == 0:
                log.info("epoch %d | step %d | lr %.6f | loss %.6f" %
                         (epoch, idx, b_lr[0], b_loss[0]))

        min_va_val, min_val = \
            model_inference(exe, gw, gf, infer_program, y_pred, PeMS, args,
                            step_idx, min_va_val, min_val)

        # NOTE(review): the ``ix - 2:ix + 1`` slice assumes three metrics
        # per reported step and, in 'sep' mode, n_pred == 3 -- confirm
        # before changing n_pred.
        for ix in tmp_idx:
            va, te = min_va_val[ix - 2:ix + 1], min_val[ix - 2:ix + 1]
            print(f'Time Step {ix + 1}: '
                  f'MAPE {va[0]:7.3%}, {te[0]:7.3%}; '
                  f'MAE  {va[1]:4.3f}, {te[1]:4.3f}; '
                  f'RMSE {va[2]:6.3f}, {te[2]:6.3f}.')

        if epoch % 5 == 0:
            model_test(exe, gw, gf, infer_program, y_pred, PeMS, args)
if __name__ == "__main__":
    # Command-line entry point: parse the options, attach the fixed channel
    # configuration of the two ST-Conv blocks and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_route', type=int, default=74)
    parser.add_argument('--n_his', type=int, default=23)
    parser.add_argument('--n_pred', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--save', type=int, default=10)
    parser.add_argument('--Ks', type=int, default=3)  #equal to num_layers
    parser.add_argument('--Kt', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--keep_prob', type=float, default=1.0)
    parser.add_argument('--opt', type=str, default='RMSProp')
    parser.add_argument('--inf_mode', type=str, default='sep')
    parser.add_argument('--input_file', type=str, default='dataset/input.csv')
    parser.add_argument('--label_file', type=str, default='dataset/output.csv')
    parser.add_argument(
        '--city_file', type=str, default='dataset/crawl_list.csv')
    parser.add_argument('--adj_mat_file', type=str, default='dataset/W_74.csv')
    parser.add_argument('--output_path', type=str, default='./outputs/')
    # The split sizes are used in integer arithmetic, so they must be
    # parsed as int; as str any value given on the command line would
    # break the dataset split.
    parser.add_argument('--n_val', type=int, default=1)
    parser.add_argument('--n_test', type=int, default=1)
    parser.add_argument('--use_cuda', action='store_true')
    args = parser.parse_args()

    # Channel sizes of each ST-Conv block, unpacked by the model as
    # (c_si, c_t, c_oo).
    blocks = [[1, 32, 64], [64, 32, 128]]
    args.blocks = blocks
    log.info(args)

    # exist_ok avoids the check-then-create race of a separate exists test.
    os.makedirs(args.output_path, exist_ok=True)

    main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file implement the STGCN model.
"""
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
class STGCNModel(object):
    """Implementation of Spatio-Temporal Graph Convolutional Networks.

    The network input is a float32 tensor named "input" of shape
    ``[batch, n_his + 1, n_route, 1]``: the first ``n_his`` frames are the
    history and the last frame is the training label.

    Args:
        args: namespace providing ``n_his``, ``n_route``, ``blocks``,
            ``Ks``, ``Kt`` and ``keep_prob``.
        gw: PGL graph wrapper exposing the node feature "norm" and the edge
            feature "weights".
    """

    def __init__(self, args, gw):
        self.args = args
        self.gw = gw
        self.input = fl.data(
            name="input",
            shape=[None, args.n_his + 1, args.n_route, 1],
            dtype="float32")

    def forward(self):
        """Build the network.

        Returns:
            ``(train_loss, single_pred)`` where ``train_loss`` is the sum
            of squared differences between the network output and the
            label frame, and ``single_pred`` is the output at time index 0
            with shape ``[batch, n, 1]``.

        Raises:
            ValueError: if ``args.n_his`` is not greater than 1.
        """
        x = self.input[:, 0:self.args.n_his, :, :]
        # Ko>0: kernel size of temporal convolution in the output layer.
        Ko = self.args.n_his
        # ST-Block
        for i, channels in enumerate(self.args.blocks):
            x = self.st_conv_block(
                x,
                self.args.Ks,
                self.args.Kt,
                channels,
                "st_conv_%d" % i,
                self.args.keep_prob,
                act_func='GLU')

        # output layer
        if Ko > 1:
            y = self.output_layer(x, Ko, 'output_layer')
        else:
            # Adjacent literals instead of a backslash continuation, so the
            # message never picks up source indentation.
            raise ValueError(
                'ERROR: kernel size Ko must be greater than 1, '
                f'but received "{Ko}".')

        label = self.input[:, self.args.n_his:self.args.n_his + 1, :, :]
        train_loss = fl.reduce_sum((y - label) * (y - label))
        single_pred = y[:, 0, :, :]  # shape: [batch, n, 1]

        return train_loss, single_pred

    def st_conv_block(self,
                      x,
                      Ks,
                      Kt,
                      channels,
                      name,
                      keep_prob,
                      act_func='GLU'):
        """Spatio-Temporal convolution block.

        Applies a temporal convolution, a spatial (graph) convolution and a
        second temporal convolution, followed by layer normalization and
        dropout with probability ``1 - keep_prob``.

        Args:
            x: input tensor laid out as [batch, time, node, channel].
            Ks: number of graph convolution layers.
            Kt: kernel size of the temporal convolutions.
            channels: ``(c_si, c_t, c_oo)`` input, hidden and output sizes.
            name: prefix of the parameter names created by the block.
            keep_prob: probability of keeping a unit in dropout.
            act_func: activation of the first temporal convolution.
        """
        c_si, c_t, c_oo = channels

        x_s = self.temporal_conv_layer(
            x, Kt, c_si, c_t, "%s_tconv_in" % name, act_func=act_func)
        x_t = self.spatio_conv_layer(x_s, Ks, c_t, c_t, "%s_sonv" % name)
        x_o = self.temporal_conv_layer(x_t, Kt, c_t, c_oo,
                                       "%s_tconv_out" % name)

        x_ln = fl.layer_norm(x_o)
        return fl.dropout(x_ln, dropout_prob=(1.0 - keep_prob))

    def _align_channels(self, x, c_in, c_out, name):
        """Return ``x`` with its channel axis resized from c_in to c_out.

        A 1x1 convolution (parameter "<name>_conv2d_1") reduces the
        channels when ``c_in > c_out``; zero padding on the channel axis
        extends them when ``c_in < c_out``; otherwise ``x`` is returned
        unchanged.
        """
        if c_in > c_out:
            return fl.conv2d(
                input=x,
                num_filters=c_out,
                filter_size=[1, 1],
                stride=[1, 1],
                padding="SAME",
                data_format="NHWC",
                param_attr=fluid.ParamAttr(name="%s_conv2d_1" % name))
        if c_in < c_out:
            # if the size of input channel is less than the output,
            # padding x to the same size of output channel.
            _, T, n, _ = x.shape
            pad = fl.fill_constant_batch_size_like(
                input=x,
                shape=[-1, T, n, c_out - c_in],
                dtype="float32",
                value=0.0)
            return fl.concat([x, pad], axis=3)
        return x

    def temporal_conv_layer(self, x, Kt, c_in, c_out, name, act_func='relu'):
        """Temporal convolution layer.

        Convolves along the time axis with a ``[Kt, 1]`` kernel and "SAME"
        padding, so the number of time steps is kept.

        Args:
            x: input tensor laid out as [batch, time, node, channel].
            Kt: kernel size along the time axis.
            c_in: number of input channels.
            c_out: number of output channels.
            name: prefix of the parameter names created by the layer.
            act_func: one of 'GLU', 'linear', 'sigmoid' or 'relu'.

        Raises:
            ValueError: if ``act_func`` is not one of the supported names.
        """
        # Residual branch, brought to c_out channels.
        x_input = self._align_channels(x, c_in, c_out, name)

        if act_func == 'GLU':
            # gated linear unit: the convolution yields 2 * c_out channels;
            # the first half (plus the residual) is gated by the sigmoid of
            # the second half.
            bt_init = fluid.initializer.ConstantInitializer(value=0.0)
            bt = fl.create_parameter(
                shape=[2 * c_out],
                dtype="float32",
                attr=fluid.ParamAttr(
                    name="%s_bt" % name, trainable=True, initializer=bt_init),
            )
            x_conv = fl.conv2d(
                input=x,
                num_filters=2 * c_out,
                filter_size=[Kt, 1],
                stride=[1, 1],
                padding="SAME",
                data_format="NHWC",
                param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
            x_conv = x_conv + bt
            return (x_conv[:, :, :, 0:c_out] + x_input
                    ) * fl.sigmoid(x_conv[:, :, :, -c_out:])
        else:
            bt_init = fluid.initializer.ConstantInitializer(value=0.0)
            bt = fl.create_parameter(
                shape=[c_out],
                dtype="float32",
                attr=fluid.ParamAttr(
                    name="%s_bt" % name, trainable=True, initializer=bt_init),
            )
            x_conv = fl.conv2d(
                input=x,
                num_filters=c_out,
                filter_size=[Kt, 1],
                stride=[1, 1],
                padding="SAME",
                data_format="NHWC",
                param_attr=fluid.ParamAttr(name="%s_conv2d_wt" % name))
            x_conv = x_conv + bt
            if act_func == "linear":
                return x_conv
            elif act_func == "sigmoid":
                return fl.sigmoid(x_conv)
            elif act_func == "relu":
                return fl.relu(x_conv + x_input)
            else:
                raise ValueError(
                    f'ERROR: activation function "{act_func}" is not defined.')

    def spatio_conv_layer(self, x, Ks, c_in, c_out, name):
        """Spatio convolution layer.

        Stacks ``Ks`` graph convolution layers. Each layer performs message
        passing over the wrapped graph, a bias-free fully connected
        projection, and a bias addition with ReLU.

        Args:
            x: input tensor laid out as [batch, time, node, channel].
            Ks: number of stacked graph convolution layers.
            c_in: number of input channels.
            c_out: number of output channels.
            name: prefix of the parameter names created by the layer.

        Returns:
            Tensor reshaped to ``[-1, T, n, c_out]``.
        """
        _, T, n, _ = x.shape
        x_input = self._align_channels(x, c_in, c_out, name)

        for i in range(Ks):
            # x_input shape: [B,T, num_nodes, c_out]
            # Flatten to one row per graph node for message passing.
            x_input = fl.reshape(x_input, [-1, c_out])

            x_input = self.message_passing(
                self.gw,
                x_input,
                name="%s_mp_%d" % (name, i),
                norm=self.gw.node_feat["norm"])

            x_input = fl.fc(x_input,
                            size=c_out,
                            bias_attr=False,
                            param_attr=fluid.ParamAttr(name="%s_gcn_fc_%d" %
                                                       (name, i)))

            bias = fluid.layers.create_parameter(
                shape=[c_out],
                dtype='float32',
                is_bias=True,
                name='%s_gcn_bias_%d' % (name, i))
            x_input = fluid.layers.elementwise_add(x_input, bias, act="relu")

            x_input = fl.reshape(x_input, [-1, T, n, c_out])

        return x_input

    def message_passing(self, gw, feature, name, norm=None):
        """Message passing layer.

        Every edge sends its source feature multiplied by the edge weight;
        messages are summed at the destination. When ``norm`` is given the
        features are multiplied by it both before sending and after
        receiving.

        Args:
            gw: graph wrapper providing ``send``, ``recv`` and the edge
                feature "weights".
            feature: node feature tensor with one row per node.
            name: layer name (not used inside this method).
            norm: optional per-node normalization factor.
        """

        def send_src_copy(src_feat, dst_feat, edge_feat):
            """send function"""
            return src_feat["h"] * edge_feat['w']

        if norm is not None:
            feature = feature * norm

        msg = gw.send(
            send_src_copy,
            nfeat_list=[("h", feature)],
            efeat_list=[('w', gw.edge_feat['weights'])])
        output = gw.recv(msg, "sum")

        if norm is not None:
            output = output * norm

        return output

    def output_layer(self, x, T, name, act_func='GLU'):
        """Output layer.

        A temporal convolution with kernel size ``T``, layer normalization,
        a sigmoid temporal convolution with kernel size 1 and a final 1x1
        convolution that maps the channels to a single one.

        Args:
            x: input tensor laid out as [batch, time, node, channel].
            T: kernel size of the first temporal convolution.
            name: prefix of the parameter names created by the layer.
            act_func: activation of the first temporal convolution.
        """
        _, _, n, channel = x.shape

        # maps multi-steps to one.
        x_i = self.temporal_conv_layer(
            x=x,
            Kt=T,
            c_in=channel,
            c_out=channel,
            name="%s_in" % name,
            act_func=act_func)
        x_ln = fl.layer_norm(x_i)
        x_o = self.temporal_conv_layer(
            x=x_ln,
            Kt=1,
            c_in=channel,
            c_out=channel,
            name="%s_out" % name,
            act_func='sigmoid')

        # maps multi-channels to one.
        x_fc = self.fully_con_layer(
            x=x_o, n=n, channel=channel, name="%s_fc" % name)
        return x_fc

    def fully_con_layer(self, x, n, channel, name):
        """Fully connected layer.

        Implemented as a 1x1 convolution with one output filter plus a
        bias of shape ``[n, 1]``.

        Args:
            x: input tensor laid out as [batch, time, node, channel].
            n: number of nodes; sets the bias shape.
            channel: number of input channels (not used inside this method).
            name: prefix of the parameter names created by the layer.
        """
        bt_init = fluid.initializer.ConstantInitializer(value=0.0)
        bt = fl.create_parameter(
            shape=[n, 1],
            dtype="float32",
            attr=fluid.ParamAttr(
                name="%s_bt" % name, trainable=True, initializer=bt_init), )
        x_conv = fl.conv2d(
            input=x,
            num_filters=1,
            filter_size=[1, 1],
            stride=[1, 1],
            padding="SAME",
            data_format="NHWC",
            param_attr=fluid.ParamAttr(name="%s_conv2d" % name))
        x_conv = x_conv + bt
        return x_conv
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file implement the testing process of STGCN model.
"""
import os
import sys
import time
import argparse
import numpy as np
import pandas as pd
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import pgl
from pgl.utils.logger import log
from data_loader.data_utils import gen_batch
from utils.math_utils import evaluation
def multi_pred(exe, gw, gf, program, y_pred, seq, batch_size,
               n_his, n_pred, step_idx, dynamic_batch=True):
    """Predict ``n_pred`` steps ahead by feeding predictions back as input.

    For each batch the program is run ``n_pred`` times. After every run
    the history window is shifted by one frame and the new prediction is
    written into its last history slot.

    Args:
        exe: executor used to run ``program``.
        gw: graph wrapper that converts a graph into a feed dict.
        gf: graph factory providing ``build_graph``.
        program: inference program.
        y_pred: variable to fetch from the program.
        seq: samples indexed as [sample, frame, route, channel].
        batch_size: upper bound of the batch size.
        n_his: number of history frames.
        n_pred: number of steps to predict.
        step_idx: accepted for interface compatibility; not used here.
        dynamic_batch: forwarded to ``gen_batch``.

    Returns:
        ``(pred_array, n_samples)`` with ``pred_array`` of shape
        ``[n_pred, n_samples, n_route, C_0]``.
    """
    per_batch = []
    effective_bs = min(batch_size, len(seq))
    for chunk in gen_batch(seq, effective_bs, dynamic_batch=dynamic_batch):
        # Work on a copy so that the source data is never modified.
        window = np.copy(chunk[:, 0:n_his + 1, :, :]).astype(np.float32)
        feed = gw.to_feed(gf.build_graph(chunk[:, 0:n_his, :, :]))

        rollout = []
        for _ in range(n_pred):
            feed['input'] = window
            out = exe.run(program, feed=feed, fetch_list=[y_pred])
            if isinstance(out, list):
                out = np.array(out[0])
            # Slide the history forward and append the fresh prediction.
            window[:, 0:n_his - 1, :, :] = window[:, 1:n_his, :, :]
            window[:, n_his - 1, :, :] = out
            rollout.append(out)
        per_batch.append(rollout)

    # stacked -> [n_pred, len(seq), n_route, C_0]
    stacked = np.concatenate(per_batch, axis=1)
    return stacked, stacked.shape[1]
def model_inference(exe, gw, gf, program, pred, inputs, args, step_idx,
                    min_va_val, min_val):
    """Evaluate on the validation split and refresh the best metrics.

    The test split is evaluated only when at least one validation metric
    improved on ``min_va_val``.

    Args:
        exe, gw, gf, program, pred: forwarded to ``multi_pred``.
        inputs: dataset providing ``get_data`` and ``get_stats``.
        args: namespace providing ``n_his``, ``n_pred`` and ``batch_size``.
        step_idx: index (or indices) of the prediction step(s) to score.
        min_va_val: best validation metrics so far; updated in place.
        min_val: test metrics recorded at the last validation improvement.

    Returns:
        ``(min_va_val, min_val)`` after the update.

    Raises:
        ValueError: if the validation samples hold fewer than
            ``n_his + n_pred`` frames.
    """
    x_val = inputs.get_data('val')
    x_test = inputs.get_data('test')
    x_stats = inputs.get_stats()

    if args.n_his + args.n_pred > x_val.shape[1]:
        raise ValueError(
            f'ERROR: the value of n_pred "{args.n_pred}" exceeds the length limit.'
        )

    # y_val shape: [n_pred, len(x_val), n_route, C_0)
    y_val, len_val = multi_pred(exe, gw, gf, program, pred,
                                x_val, args.batch_size, args.n_his,
                                args.n_pred, step_idx)
    truth_val = x_val[0:len_val, step_idx + args.n_his, :, :]
    evl_val = evaluation(truth_val, y_val[step_idx], x_stats)

    # improved: marks the metrics that beat the best validation value.
    improved = evl_val < min_va_val
    if not sum(improved):
        return min_va_val, min_val

    # Validation got better: record it and refresh the test metrics.
    min_va_val[improved] = evl_val[improved]
    y_pred, len_pred = multi_pred(exe, gw, gf, program, pred,
                                  x_test, args.batch_size, args.n_his,
                                  args.n_pred, step_idx)
    truth_test = x_test[0:len_pred, step_idx + args.n_his, :, :]
    evl_pred = evaluation(truth_test, y_pred[step_idx], x_stats)
    return min_va_val, evl_pred
def model_test(exe, gw, gf, program, pred, inputs, args):
    """Run the model on the test split, save the results and print metrics.

    Writes "groundtruth.csv" and "prediction.csv" into
    ``args.output_path``. Both files hold one row per (sample, step) pair
    in the same order, with the city names as header. Values are cast to
    int32 before being written.

    Args:
        exe, gw, gf, program, pred: forwarded to ``multi_pred``.
        inputs: dataset providing ``get_data`` and ``get_stats``.
        args: namespace providing ``inf_mode``, ``n_pred``, ``n_his``,
            ``n_route``, ``batch_size``, ``city_file`` and ``output_path``.

    Raises:
        ValueError: if ``args.inf_mode`` is neither 'sep' nor 'merge'.
    """
    if args.inf_mode == 'sep':
        # for inference mode 'sep', the type of step index is int.
        step_idx = args.n_pred - 1
        # Every predicted step is reported.
        report_steps = range(args.n_pred)
    elif args.inf_mode == 'merge':
        # for inference mode 'merge', the type of step index is np.ndarray.
        step_idx = np.arange(3, args.n_pred + 1, 3) - 1
        # Only the selected steps are reported. The previous
        # ``range(step_idx + 1)`` could not handle an array here.
        report_steps = step_idx
        print(step_idx)
    else:
        raise ValueError(f'ERROR: test mode "{args.inf_mode}" is not defined.')

    x_test, x_stats = inputs.get_data('test'), inputs.get_stats()
    y_test, len_test = multi_pred(exe, gw, gf, program, pred,
                                  x_test, args.batch_size, args.n_his,
                                  args.n_pred, step_idx)

    # save result
    gt = x_test[0:len_test, args.n_his:, :, :].reshape(-1, args.n_route)
    # y_test is [n_pred, n_samples, n_route, C_0] while the ground truth is
    # sample-major; swap the first two axes so both files list their rows
    # in the same (sample, step) order.
    y_pred = np.swapaxes(y_test, 0, 1).reshape(-1, args.n_route)
    city_df = pd.read_csv(args.city_file)
    # Exclude the first listed city, which is not part of the graph.
    city_df = city_df.drop(0)

    np.savetxt(
        os.path.join(args.output_path, "groundtruth.csv"),
        gt.astype(np.int32),
        fmt='%d',
        delimiter=',',
        header=",".join(city_df['city']))
    np.savetxt(
        os.path.join(args.output_path, "prediction.csv"),
        y_pred.astype(np.int32),
        fmt='%d',
        delimiter=",",
        header=",".join(city_df['city']))

    for i in report_steps:
        # Score step i against the ground truth of the same step. The
        # previous code always compared against the last step.
        evl = evaluation(x_test[0:len_test, i + args.n_his, :, :],
                         y_test[i], x_stats)
        print(
            f'Time Step {i + 1}: MAPE {evl[0]:7.3%}; MAE  {evl[1]:4.3f}; RMSE {evl[2]:6.3f}.'
        )
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation"""
import os
import sys
import time
import argparse
import numpy as np
def z_score(x, mean, std):
    """Standardize ``x`` by subtracting ``mean`` and dividing by ``std``."""
    centered = x - mean
    return centered / std
def z_inverse(x, mean, std):
    """Undo ``z_score``: scale ``x`` by ``std`` and add ``mean``."""
    scaled = x * std
    return scaled + mean
def MAPE(v, v_):
    """Mean absolute percentage error.

    Args:
        v: ground truth values.
        v_: predicted values.

    Returns:
        Mean of ``|v_ - v| / (|v| + 1e-5)``. The small constant avoids a
        division by zero for zero-valued ground truth. The magnitude of
        ``v`` is used in the denominator so that negative ground truth
        values cannot produce negative (or cancelling) percentages; results
        for non-negative ``v`` are unchanged.
    """
    return np.mean(np.abs(v_ - v) / (np.abs(v) + 1e-5))
def RMSE(v, v_):
    """Root mean squared error between ground truth ``v`` and prediction ``v_``."""
    residual = v_ - v
    mean_sq = np.mean(residual**2)
    return np.sqrt(mean_sq)
def MAE(v, v_):
    """Mean absolute error between ground truth ``v`` and prediction ``v_``."""
    abs_err = np.abs(v_ - v)
    return np.mean(abs_err)
def evaluation(y, y_, x_stats):
    """Calculate MAPE, MAE and RMSE between ground truth and prediction.

    Args:
        y: ground truth. Indexed [batch, n_route, 1] in the single-step
            case and [batch, time_step, n_route, 1] in the multi-step case.
        y_: prediction. Three-dimensional in the single-step case;
            otherwise its first axis enumerates the time steps.
        x_stats: dict with 'mean' and 'std', used to undo the z-score
            before the metrics are computed.

    Returns:
        ``np.array([MAPE, MAE, RMSE])`` for a single step, or the per-step
        triples concatenated along the last axis for multiple steps.
    """
    if len(y_.shape) == 3:
        # single_step case
        mean, std = x_stats['mean'], x_stats['std']
        truth = z_inverse(y, mean, std)
        guess = z_inverse(y_, mean, std)
        return np.array([metric(truth, guess) for metric in (MAPE, MAE, RMSE)])

    # multi_step case
    # y -> [time_step, batch_size, n_route, 1]
    y_by_step = np.swapaxes(y, 0, 1)
    # Evaluate every step on its own (recursive call).
    per_step = [
        evaluation(y_by_step[k], y_[k], x_stats) for k in range(y_.shape[0])
    ]
    return np.concatenate(per_step, axis=-1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册