diff --git a/demo/traffic_prediction/dataprovider.py b/demo/traffic_prediction/dataprovider.py index b91506726076c12d574eb0aba283fb59071503ef..19719350f2d25ab333e11d703d4f8efb36414e77 100644 --- a/demo/traffic_prediction/dataprovider.py +++ b/demo/traffic_prediction/dataprovider.py @@ -34,11 +34,11 @@ def initHook(settings, file_list, **kwargs): settings.pool_size = sys.maxint #Use a time seires of the past as feature. #Dense_vector's expression form is [float,float,...,float] - settings.slots = [dense_vector(TERM_NUM)] + settings.input_types = [dense_vector(TERM_NUM)] #There are next FORECASTING_NUM fragments you need predict. #Every predicted condition at time point has four states. for i in range(FORECASTING_NUM): - settings.slots.append(integer_value(LABEL_VALUE_NUM)) + settings.input_types.append(integer_value(LABEL_VALUE_NUM)) @provider( @@ -57,7 +57,7 @@ def process(settings, file_name): pre_spd = map(float, speeds[i - TERM_NUM:i]) # Integer value need predicting, values start from 0, so every one minus 1. - fol_spd = [i - 1 for i in speeds[i:i + FORECASTING_NUM]] + fol_spd = [j - 1 for j in speeds[i:i + FORECASTING_NUM]] # Predicting label is missing, abandon the sample. if -1 in fol_spd: @@ -67,7 +67,7 @@ def process(settings, file_name): def predict_initHook(settings, file_list, **kwargs): settings.pool_size = sys.maxint - settings.slots = [dense_vector(TERM_NUM)] + settings.input_types = [dense_vector(TERM_NUM)] @provider(init_hook=predict_initHook, should_shuffle=False) diff --git a/demo/traffic_prediction/gen_result.py b/demo/traffic_prediction/gen_result.py index cb8f6e68322cc27031ab58b67e1763ba3bd337ee..d6c1b033700813a54c5936e7f9cd2237dc56a56d 100644 --- a/demo/traffic_prediction/gen_result.py +++ b/demo/traffic_prediction/gen_result.py @@ -1,3 +1,17 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + res = [] with open('./rank-00000') as f: for line in f: diff --git a/demo/traffic_prediction/trainer_config.py b/demo/traffic_prediction/trainer_config.py index c8755f7f3c28624e9825ba136609f454e4d1c236..bb6a4ac98755aac0b35ad8fe8198643940bd7752 100755 --- a/demo/traffic_prediction/trainer_config.py +++ b/demo/traffic_prediction/trainer_config.py @@ -1,5 +1,16 @@ -#!/usr/bin/env/python -#-*python-*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from paddle.trainer_config_helpers import * ################################### DATA Configuration #############################################