Unverified commit 221a3df4, authored by anpark, committed by GitHub

add PaddleST/Research/CIKM2019-MONOPOLY paper code & readme (#3199)

* add CIKM2019 MONOPOLY paper code

* add CIKM2019 MONOPOLY readme

* add CIKM2019 MONOPOLY readme
Parent d9f31608
# MONOPOLY
## Introduction
### Task Description
Web maps, as a classic platform for spatio-temporal big data, collect a large amount of information about fixed assets (Points of Interest, POIs), travel trajectories, place queries, and more.
Monopoly is a POI business-intelligence algorithm that uses a small number of real-estate prices to estimate the value of a large number of other fixed assets.
Paper link: XXXX
### Significance and Findings
1) Monopoly helps us discover, with quantitative analysis, how residents of each city weigh the prices of different types of public assets.
2) Monopoly helps us explore, with quantitative analysis, how residents of different cities assess private housing prices.
3) Monopoly helps us determine the spatial range that needs to be considered when appraising a fixed asset.
### Results
## Installation
1. Install PaddlePaddle
This project requires Paddle Fluid 1.5.1 or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start).
2. Download the code
Clone the repositories locally. This code depends on the [Paddle-EPEP framework](https://github.com/PaddlePaddle/epep):
```
git clone https://github.com/PaddlePaddle/epep.git
cd epep
git clone https://github.com/PaddlePaddle/models.git
ln -s models/PaddleST/Research/CIKM2019-MONOPOLY/conf/house_price conf/house_price
ln -s models/PaddleST/Research/CIKM2019-MONOPOLY/datasets/house_price datasets/house_price
ln -s models/PaddleST/Research/CIKM2019-MONOPOLY/nets/house_price nets/house_price
```
3. Environment
Python 2.7 is required.
### Running the Model for the First Time
1. Data preparation
TODO
```
#script to download
```
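The download script is not included yet. Judging from `parse_oneline` in `datasets/house_price/house_price.py`, each sample is four (or five, when car distances are used) tab-separated columns. The snippet below builds one purely synthetic line to illustrate the assumed layout (field meanings inferred from the parsing code, not official documentation):
```
# synthetic example of the assumed input-line layout
line = "\t".join([
    "6.3 2 1",           # col0: label (price), h_num nearby houses, p_num public POIs
    "12 3 7 2 5",        # col1: one-hot ids of the house: business, wuye, kfs, age, lou
    "4:6.1 9:5.8 1024",  # col2: h_num "house_idx:price" items, then p_num public-poi ids
    "0.3 0.7 0.5",       # col3: distances aligned with col2 (kept when <= dis_radius)
])
print(line)
```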
2. Model training (the optional `-g 0` presumably selects the GPU card)
```
sh run.sh -c conf/house_price/house_price.local.conf -m train [ -g 0 ]
```
3. Model evaluation
```
sh run.sh -c conf/house_price/house_price.local.conf -m pred
cat $c.out | grep ^qid | python utils/calc_metric.py
```
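`utils/calc_metric.py` comes from the EPEP framework rather than this commit, and `$c.out` presumably holds the prediction output for city `$c`. Given that `pred_format` in `nets/house_price/house_price.py` prints one `qid\t<label>\t<pred>` line per sample, and that the sklearn baseline reports MAE / RMSE / R2, a minimal stand-in would be (a sketch under those assumptions, not the actual script):
```
#!/usr/bin/env python
# Assumed stand-in for utils/calc_metric.py: read "qid\t<label>\t<pred>"
# lines from stdin and report MAE / RMSE / R2.
import sys
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

labels, preds = [], []
for line in sys.stdin:
    parts = line.rstrip('\r\n').split('\t')
    if len(parts) < 3 or parts[0] != 'qid':
        continue
    labels.append(float(parts[1]))
    preds.append(float(parts[2]))
print("MAE=%.6f RMSE=%.6f R2=%.6f" % (
    mean_absolute_error(labels, preds),
    np.sqrt(mean_squared_error(labels, preds)),
    r2_score(labels, preds)))
```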
## Citing this paper
=====
[DEFAULT]
sample_seed: 1234
# The values in the `DEFAULT` section can be referenced by other sections.
# For convenience, we put the variables which change frequently here and
# let other sections refer to them.
# Input settings
dataset_name: HousePrice
max_house_num: 100
max_public_num: 100
batch_shuffle: False
CUDA_VISIBLE_DEVICES: 0
FLAGS_fraction_of_gpu_memory_to_use: 0.8
# Input settings
#reader: dataset | pyreader | async | datafeed | sync
#data_reader: pyreader
data_reader: datafeed
dataset_mode: Memory
#local-cpu | local-gpu
platform: local-gpu
#platform: local-cpu
dis_radius: 1.0
avg_eval: False
with_car_dis: False
with_house_attr: False
bj_batch_size: 5256
#bj_batch_size: 7573
#bj_batch_size: 423
sh_batch_size: 8126
#sh_batch_size: 11604
#sh_batch_size: 822
gz_batch_size: 4560
#gz_batch_size: 6508
#gz_batch_size: 367
sz_batch_size: 2693
#sz_batch_size: 3849
#sz_batch_size: 192
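# NOTE (assumption): `<c>` below is a template placeholder, presumably
# substituted with a city short name (bj / sh / gz / sz, matching the
# *_batch_size keys above) by run.sh before parsing; `<s>` in [Evaluate]
# looks like an epoch/step placeholder. `${SECTION:key}` is ConfigParser-style
# value interpolation within this file.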
city_name: <c>
num_samples_train: ${DEFAULT:<c>_batch_size}
train_batch_size: ${DEFAULT:<c>_batch_size}
#train_batch_size: 2
num_samples_eval: 10
eval_batch_size: 10
kv_path: None
# Model settings
model_name: HousePrice
preprocessing_name: None
file_pattern: part-
num_in_dimension: 3
num_out_dimension: 1
# Learning options
max_number_of_steps: None
init_learning_rate: 0.2
emb_lr: ${DEFAULT:init_learning_rate}
fc_lr: ${DEFAULT:init_learning_rate}
base_lr: ${DEFAULT:init_learning_rate}
[Convert]
# The name of the dataset to convert
dataset_name: ${DEFAULT:dataset_name}
#dataset_dir: ${DEFAULT:dataset_dir}
dataset_dir: stream
# The output Records file name prefix.
dataset_split_name: train
# The number of Records per shard
num_per_shard: 100000
# The dimensions of the net input vectors; currently only used by the svm
# dataset, whose inputs are sparse tensors
num_in_dimension: ${DEFAULT:num_in_dimension}
# The output file name pattern with two placeholders ("%s" and "%d");
# it must correspond to the glob `file_pattern' in the Train and Evaluate
# config sections
[Train]
#######################
# Dataset Configure #
#######################
# The name of the dataset to load
dataset_name: ${DEFAULT:dataset_name}
# The directory where the dataset files are stored
dataset_dir: ${DEFAULT:dataset_dir}
file_list: ../tmp/data/poi/raw/<c>/poi_sample.train
# dataset_split_name
dataset_split_name: train
# The glob pattern for the data path; `file_pattern' must contain only one "%s",
# which is the placeholder for the split name (such as 'train' or 'validation')
file_pattern: ${DEFAULT:file_pattern}
# The file type: text or record
file_type: record
# kv path, used in image_sim
kv_path: ${DEFAULT:kv_path}
# The number of input samples for training
num_samples: ${DEFAULT:num_samples_train}
# The number of parallel readers that read data from the dataset
num_readers: 2
# The number of threads used to create the batches
num_preprocessing_threads: 4
# Number of epochs from dataset source
num_epochs_input: 200
###########################
# Basic Train Configure #
###########################
# Directory where checkpoints and event logs are written to.
train_dir: ../tmp/model/house_price/save_model/${DEFAULT:city_name}
# The max number of ckpt files to store variables
save_max_to_keep: 40
# The frequency with which the model is saved, in steps.
save_model_steps: 5
# The name of the architecture to train
model_name: ${DEFAULT:model_name}
# The dimensions of the net input vectors; currently only used by the svm
# dataset, whose inputs are sparse tensors
num_in_dimension: ${DEFAULT:num_in_dimension}
# The dimensions of the net output vector; it will be the number of classes in an image classification task
num_out_dimension: ${DEFAULT:num_out_dimension}
#####################################
# Training Optimization Configure #
#####################################
# The number of samples in each batch
batch_size: ${DEFAULT:train_batch_size}
# The maximum number of training steps
max_number_of_steps: ${DEFAULT:max_number_of_steps}
# The weight decay on the model weights
#weight_decay: 0.00000001
weight_decay: None
# The decay to use for the moving average. If left as None, then moving averages are not used
moving_average_decay: None
# ***************** learning rate options ***************** #
# Initial learning rate
init_learning_rate: ${DEFAULT:init_learning_rate}
# Specifies how the learning rate is decayed. One of "fixed", "exponential" or "polynomial"
learning_rate_decay_type: fixed
# Learning rate decay factor
learning_rate_decay_factor: 0.1
num_learning_rate_warmup_epochs: None
# The minimal end learning rate used by a polynomial decay learning rate
end_learning_rate: 0.0001
# Number of epochs after which learning rate decays
num_epochs_per_decay: 10
# A boolean, whether or not it should cycle beyond decay_steps
learning_rate_polynomial_decay_cycle: False
# ******************* optimizer options ******************* #
# The name of the optimizer, one of the following:
# "adadelta", "adagrad", "adam", "ftrl", "momentum", "sgd" or "rmsprop"
#optimizer: weight_decay_adam
optimizer: adam
#optimizer: sgd
# Epsilon term for the optimizer, used for adadelta, adam, rmsprop
opt_epsilon: 1e-6
# conf for adadelta
# The decay rate for adadelta
adadelta_rho: 0.95
# Starting value for the AdaGrad accumulators
adagrad_initial_accumulator_value: 0.1
# conf for adam
# The exponential decay rate for the 1st moment estimates
adam_beta1: 0.9
# The exponential decay rate for the 2nd moment estimates
adam_beta2: 0.999
adam_weight_decay: 0.01
#adam_exclude_from_weight_decay: LayerNorm,layer_norm,bias
# conf for ftrl
# The learning rate power
ftrl_learning_rate_power: -0.1
# Starting value for the FTRL accumulators
ftrl_initial_accumulator_value: 0.1
# The FTRL l1 regularization strength
ftrl_l1: 0.0
# The FTRL l2 regularization strength
ftrl_l2: 0.01
# conf for momentum
# The momentum for the MomentumOptimizer and RMSPropOptimizer
momentum: 0.9
# conf for rmsprop
# Decay term for RMSProp
rmsprop_decay: 0.9
# Number of model clones to deploy
num_gpus: 1
# The frequency, in steps, with which logs are traced.
trace_every_n_steps: 5
[Evaluate]
#######################
# Dataset Configure #
#######################
# The name of the dataset to load
dataset_name: ${DEFAULT:dataset_name}
# The name of the train/test split
#dataset_split_name: validation
dataset_split_name: train
# The glob pattern for the data path; `file_pattern' must contain only one "%s",
# which is the placeholder for the split name (such as 'train' or 'validation')
file_pattern: ${DEFAULT:file_pattern}
#reader: dataset | pyreader | async | datafeed | sync
data_reader: datafeed
#local-cpu | local-gpu
platform: local-cpu
file_list: ../tmp/data/poi/raw/<c>/poi_sample.test
# The file type: text or record
file_type: text
# kv path, used in image_sim
kv_path: ${DEFAULT:kv_path}
# The number of input samples for evaluation
num_samples: ${DEFAULT:num_samples_eval}
# The number of parallel readers that read data from the dataset
num_readers: 2
# The number of threads used to create the batches
num_preprocessing_threads: 2
# Number of epochs from dataset source
num_epochs_input: 1
# The name of the architecture to evaluate
model_name: ${DEFAULT:model_name}
# The dimensions of the net input vectors; currently only used by the svm
# dataset, whose inputs are sparse tensors
num_in_dimension: ${DEFAULT:num_in_dimension}
# The dimensions of the net output vector; it will be the number of classes in an image classification task
num_out_dimension: ${DEFAULT:num_out_dimension}
# Directory where the results are saved to
eval_dir: ${Train:train_dir}/epoch<s>
# The number of samples in each batch
batch_size: ${DEFAULT:eval_batch_size}
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
File: baseline_sklearn.py
"""
import sys
import numpy as np
from sklearn import linear_model
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
class CityInfo(object):
"""
city info
"""
def __init__(self, city):
self._set_env(city)
def _set_env(self, city):
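        # Per-city one-hot vocabulary sizes and dataset statistics. The field
        # names transliterate Chinese terms (assumed glosses): wuye = property
        # management company, kfs (kaifashang) = developer, lou = building
        # type; average_price appears to be in units of 10k CNY per m^2.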
if city == 'sh':
#sh
self.business_num = 389
self.wuye_num = 2931
self.kfs_num = 4056
self.age_num = 104
self.lou_num = 11
self.average_price = 5.5669712771458115
self.house_num = 11604
self.public_num = 970566 + 1
elif city == 'gz':
#gz
self.business_num = 246
self.wuye_num = 1436
self.kfs_num = 1753
self.age_num = 48
self.lou_num = 12
self.average_price = 3.120921450522434
self.house_num = 6508
self.public_num = 810465 + 1
elif city == 'sz':
#sz
self.business_num = 127
self.wuye_num = 1096
self.kfs_num = 1426
self.age_num = 40
self.lou_num = 15
self.average_price = 5.947788464536243
self.house_num = 3849
self.public_num = 724320 + 1
else:#bj, default
self.business_num = 429
self.wuye_num = 1548
self.kfs_num = 1730
self.age_num = 80
self.lou_num = 15
self.average_price = 6.612481698138123
self.house_num = 7573
self.public_num = 843426 + 1
if __name__ == '__main__':
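    # Usage: python baseline_sklearn.py <feature_file> <lr|gb|nn>
    # <feature_file> holds the "train|test <label> <features...>" lines
    # printed by feature_preprocess.py; any model name other than lr / gb
    # falls back to the MLP regressor.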
svd = sys.argv[1]
model = sys.argv[2]
if model == 'lr':
clf = linear_model.LinearRegression()
elif model == 'gb':
clf = GradientBoostingRegressor()
#clf = RandomForestRegressor()
#clf = DecisionTreeRegressor()
else:
clf = MLPRegressor(hidden_layer_sizes=(20, ))
x_train = []
y_train = []
x_test = []
y_test = []
with open(svd, 'r') as f:
for line in f:
ll = line.strip('\r\n').split()
if ll[0] == 'train':
y_train.append(float(ll[1]))
                x_train.append(list(map(float, ll[2:])))  # list() keeps this working on Python 3 too
else:
y_test.append(float(ll[1]))
                x_test.append(list(map(float, ll[2:])))
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)
mae = mean_absolute_error(y_test, y_pred)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
r2 = r2_score(y_test, y_pred)
print("%s\t%s\t%s\t%s" % (model, mae, rmse, r2))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
File: feature_preprocess.py
"""
import sys
import random
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import coo_matrix
from datasets.house_price.baseline_sklearn import CityInfo
def parse_line(line, labels, data, row, col, radius, city_info, num_row,
max_house_num, max_public_num):
"""
parse line
"""
ll = line.strip('\r\n').split('\t')
labels.append(ll[0].split()[0])
business = int(ll[1].split()[0])
dis_info = ll[3].split()
if business >= 0 and business < city_info.business_num:
data.append(1)
row.append(num_row)
col.append(business)
idx = 0
h_num = 0
p_num = 0
for i in ll[2].split():
if float(dis_info[idx]) > radius:
idx += 1
continue
        if ':' in i:
            if h_num >= max_house_num:  # >= matches the cap used in datasets/house_price
                idx += 1  # bug fix: keep the dis_info index aligned when skipping
                continue
h_num += 1
data.append(1)
row.append(num_row)
col.append(city_info.business_num + int(i.split(':')[0]) - 1)
else:
            if p_num >= max_public_num:  # >= matches the cap used in datasets/house_price
break
p_num += 1
data.append(1)
row.append(num_row)
col.append(city_info.business_num + city_info.house_num + int(i) - 1)
idx += 1
if __name__ == '__main__':
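    # Usage: cat <train_file> | python feature_preprocess.py <test_file> \
    #            <radius> <max_house_num> <max_public_num> <bj|sh|gz|sz>
    # Training lines come from stdin and test lines from <test_file>; both are
    # SVD-reduced to 200 dims and printed as "train|test <label> <features>".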
test = sys.argv[1]
radius = float(sys.argv[2])
    max_house_num = int(sys.argv[3])  # item caps are counts, so parse as int
    max_public_num = int(sys.argv[4])
city_info = CityInfo(sys.argv[5])
train_data = []
train_row = []
train_col = []
train_labels = []
num_row = 0
for line in sys.stdin:
parse_line(line, train_labels, train_data, train_row, train_col, radius, city_info,
num_row, max_house_num, max_public_num)
num_row += 1
coo = coo_matrix((train_data, (train_row, train_col)),
shape=(num_row, city_info.business_num + city_info.house_num + city_info.public_num))
svd = TruncatedSVD(n_components=200, n_iter=10, random_state=0)
svd.fit(coo.tocsr())
x_train = svd.transform(coo.tocsr())
for i in range(len(x_train)):
print("train %s %s" % (train_labels[i], " ".join(map(str, x_train[i]))))
test_data = []
test_row = []
test_col = []
test_labels = []
with open(test, 'r') as f:
num_row = 0
for line in f:
parse_line(line, test_labels, test_data, test_row, test_col, radius, city_info,
num_row, max_house_num, max_public_num)
num_row += 1
coo = coo_matrix((test_data, (test_row, test_col)),
shape=(num_row, city_info.business_num + city_info.house_num + city_info.public_num))
x_test = svd.transform(coo.tocsr())
for i in range(len(x_test)):
print("test %s %s" % (test_labels[i], " ".join(map(str, x_test[i]))))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
File: house_price.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
import random
import paddle.fluid as fluid
from datasets.base_dataset import BaseDataset
from datasets.house_price.baseline_sklearn import CityInfo
class HousePrice(BaseDataset):
"""
shop location dataset
"""
def __init__(self, flags):
super(HousePrice, self).__init__(flags)
self.city_info = CityInfo(flags.city_name)
    def parse_context(self, inputs):
        """
        Provide the input context.

        Set inputs_kv: please use the same key as layer.data.name.
        Notice:
        (1) If a user-defined "inputs key" differs from layer.data.name,
            the framework will rewrite the "inputs key" with layer.data.name.
        (2) The param "inputs" is passed to the user-defined nets class
            through its interface function net(self, FLAGS, inputs).
        """
inputs['label'] = fluid.layers.data(name="label", shape=[1], dtype="float32", lod_level=0)
#if self._flags.dataset_split_name != 'train':
# inputs['qid'] = fluid.layers.data(name='qid', shape=[1], dtype="int32", lod_level=0)
#house self feature
inputs['house_business'] = fluid.layers.data(name="house_business", shape=[self.city_info.business_num],
dtype="float32", lod_level=0)
inputs['house_wuye'] = fluid.layers.data(name="house_wuye", shape=[self.city_info.wuye_num],
dtype="float32", lod_level=0)
inputs['house_kfs'] = fluid.layers.data(name="house_kfs", shape=[self.city_info.kfs_num],
dtype="float32", lod_level=0)
inputs['house_age'] = fluid.layers.data(name="house_age", shape=[self.city_info.age_num],
dtype="float32", lod_level=0)
inputs['house_lou'] = fluid.layers.data(name="house_lou", shape=[self.city_info.lou_num],
dtype="float32", lod_level=0)
#nearby house and public poi
inputs['house_price'] = fluid.layers.data(name="house_price", shape=[self._flags.max_house_num],
dtype="float32", lod_level=0)
inputs['public_bid'] = fluid.layers.data(name="public_bid", shape=[1],
dtype="int64", lod_level=1)
inputs['house_dis'] = fluid.layers.data(name="house_dis", shape=[self._flags.max_house_num * 2],
dtype="float32", lod_level=0)
inputs['public_dis'] = fluid.layers.data(name="public_dis", shape=[self._flags.max_public_num * 2],
dtype="float32", lod_level=0)
inputs['house_num'] = fluid.layers.data(name="house_num", shape=[1], dtype="float32", lod_level=0)
inputs['public_num'] = fluid.layers.data(name="public_num", shape=[1], dtype="float32", lod_level=0)
context = {"inputs": inputs}
#set debug list, print info during training
#debug_list = [key for key in inputs]
#context["debug_list"] = ["label", "house_num"]
return context
    def _normalize_distance_factor(self, dis_vec):
        # inverse-distance weights, normalized to sum to 1
        total = 0.0
        for d in dis_vec:
            total += 1.0 / d
        ret = []
        for d in dis_vec:
            ret.append(1.0 / (d * total))
        return ret
    def parse_oneline(self, line):
        """
        Parse one tab-separated sample line.

        col0: "<label> <h_num> <p_num>"; col1: five one-hot ids of the house
        itself (business, wuye, kfs, age, lou); col2: h_num "idx:price" house
        items followed by p_num public-poi ids; col3: distances aligned with
        col2; col4 (only when with_car_dis): car-driving distances.
        """
cols = line.strip('\r\n').split('\t')
max_house_num = self._flags.max_house_num
max_public_num = self._flags.max_public_num
pred = False if self._flags.dataset_split_name == 'train' else True
radius = self._flags.dis_radius
p_info = cols[0].split()
label = float(p_info[0])
samples = [('label', [float(label)])]
#house self info
h_num = int(p_info[1])
p_num = int(p_info[2])
onehot_ids = cols[1].split()
def _get_onehot(idx, num):
onehot = [0.0] * num
if idx >= 0 and idx < num:
onehot[idx] = 1.0
return onehot
onehot_business = _get_onehot(int(onehot_ids[0]), self.city_info.business_num)
onehot_wuye = _get_onehot(int(onehot_ids[1]), self.city_info.wuye_num)
onehot_kfs = _get_onehot(int(onehot_ids[2]), self.city_info.kfs_num)
onehot_age = _get_onehot(int(onehot_ids[3]), self.city_info.age_num)
onehot_lou = _get_onehot(int(onehot_ids[4]), self.city_info.lou_num)
#nearby house and public info
h_p_info = cols[2].split()
h_p_dis = cols[3].split()
h_p_car = []
if self._flags.with_car_dis:
h_p_car = cols[4].split()
assert(len(h_p_car) == len(h_p_dis))
#if h_num < 1 or p_num < 1:
# print("%s, invalid h_num or p_num." % line, file=sys.stderr)
# return
assert(len(h_p_info) == (h_num + p_num) and len(h_p_info) == len(h_p_dis))
p_id = []
p_dis = []
h_price = []
h_dis = []
for i in range(h_num + p_num):
if float(h_p_dis[i]) > radius or (len(h_p_car) > 0 and float(h_p_car[i]) < 0):
continue
if i < h_num:
if len(h_price) >= max_house_num:
continue
pinfo = h_p_info[i].split(':')
#h_price += float(pinfo[1]) * float(h_p_dis[i])
h_price.append(float(pinfo[1]))
if len(h_p_car) > 0:
h_dis.extend([float(h_p_dis[i]), float(h_p_car[i])])
else:
h_dis.append(float(h_p_dis[i]))
else:
if len(p_id) >= max_public_num:
break
p_id.append(int(h_p_info[i]))
if len(h_p_car) > 0:
p_dis.extend([float(h_p_dis[i]), float(h_p_car[i])])
else:
p_dis.append(float(h_p_dis[i]))
qid = 0
if self._flags.avg_eval:
if len(h_price) > 0:
avg_h = np.average(h_price)
h_dis = self._normalize_distance_factor(h_dis)
weight_h = np.sum(np.array(h_price) * h_dis / np.sum(h_dis))
else:
avg_h = self.city_info.average_price
weight_h = self.city_info.average_price
print("%s\t%s\t%s\t%s\t%s" % (qid, label, avg_h, weight_h, self.city_info.average_price))
return
if len(h_price) < 1 and len(p_id) < 1:
#sys.stderr.write("invalid line.\n")
return
h_num = len(h_price)
p_num = len(p_id)
#if pred:
# samples.append(('qid', [qid]))
samples.append(('house_business', onehot_business))
samples.append(('house_wuye', onehot_wuye))
samples.append(('house_kfs', onehot_kfs))
samples.append(('house_age', onehot_age))
samples.append(('house_lou', onehot_lou))
while len(h_price) < max_house_num:
h_price.append(self.city_info.average_price)
if len(h_p_car) > 0:
h_dis.extend([radius, 2 * radius])
else:
h_dis.append(radius)
while len(p_id) < max_public_num:
p_id.append(0)
if len(h_p_car) > 0:
p_dis.extend([radius, 2 * radius])
else:
p_dis.append(radius)
samples.append(('house_price', h_price))
samples.append(('public_bid', p_id))
samples.append(('house_dis', h_dis))
samples.append(('public_dis', p_dis))
samples.append(('house_num', [h_num]))
samples.append(('public_num', [p_num]))
yield samples
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
File: nets/house_price/house_price.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import math
import numpy as np
import paddle.fluid as fluid
from nets.base_net import BaseNet
from datasets.house_price.baseline_sklearn import CityInfo
class HousePrice(BaseNet):
"""
net class: construct net
"""
def __init__(self, FLAGS):
super(HousePrice, self).__init__(FLAGS)
self.city_info = CityInfo(FLAGS.city_name)
def emb_lookup_fn(self, input, dict_dim, emb_dim, layer_name, FLAGS,
padding_idx=None, init_val=0.0):
"""
get embedding out with params
"""
output = fluid.layers.embedding(
input=input,
size=[dict_dim, emb_dim],
padding_idx=padding_idx,
param_attr=fluid.ParamAttr(
name=layer_name,
initializer=fluid.initializer.ConstantInitializer(init_val)),
is_sparse=True)
return output
def fc_fn(self, input, output_size, act, layer_name, FLAGS, num_flatten_dims=1):
"""
pack fc op
"""
dev = 1.0 / math.sqrt(output_size)
_fc = fluid.layers.fc(
input=input,
size=output_size,
num_flatten_dims=num_flatten_dims,
param_attr=fluid.ParamAttr(
name=layer_name + "_fc_w",
initializer=fluid.initializer.Xavier(uniform=False)),
#initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=dev)),
bias_attr=fluid.ParamAttr(
name=layer_name + "_fc_bias",
initializer=fluid.initializer.Constant(value=0.0)),
act=act)
return _fc
def pred_format(self, result):
"""
format pred output
"""
if result is None or result in ['_PRE_']:
return
def _softmax(x):
return np.exp(x) / np.sum(np.exp(x), axis=0)
if result == '_POST_':
h_attr_w = fluid.global_scope().find_var("house_self_fc_w").get_tensor()
h_attr_b = fluid.global_scope().find_var("house_self_fc_bias").get_tensor()
dis_w = fluid.global_scope().find_var("dis_w").get_tensor()
bids = fluid.global_scope().find_var("bids").get_tensor()
print("h_attr_w: %s" % (" ".join(map(str, _softmax(np.array(h_attr_w).flatten())))))
print("h_attr_b: %s" % (" ".join(map(str, np.array(h_attr_b)))))
print("dis_w: %s" % (" ".join(map(str, _softmax(np.array(np.mean(dis_w, 0)))))))
print("bids: %s" % (" ".join(map(str, np.array(bids).flatten()))))
return
label = np.array(result[0]).T.flatten().tolist()
pred = np.array(result[1]).T.flatten().tolist()
for i in range(len(pred)):
print("qid\t%s\t%s" % (label[i], pred[i]))
    def net(self, inputs):
        """
        User-defined interface: build the net.

        inputs: dict of input layers, e.g. {"label": ..., "house_business": ...}.
        The prediction is sigmoid(fc(house attributes)) multiplied by the sum
        of nearby prices (real house prices plus learned public-POI price
        embeddings) weighted by a distance-based softmax.
        """
FLAGS = self._flags
label = inputs['label']
public_bids = inputs['public_bid']
max_house_num = FLAGS.max_house_num
max_public_num = FLAGS.max_public_num
#step1. get house self feature
if FLAGS.with_house_attr:
def _get_house_attr(name, attr_vec_size):
h_onehot = fluid.layers.reshape(inputs[name], [-1, attr_vec_size])
h_attr = self.fc_fn(h_onehot, 1, act=None, layer_name=name, FLAGS=FLAGS)
return h_attr
house_business = _get_house_attr("house_business", self.city_info.business_num)
house_wuye = _get_house_attr("house_wuye", self.city_info.wuye_num)
house_kfs = _get_house_attr("house_kfs", self.city_info.kfs_num)
house_age = _get_house_attr("house_age", self.city_info.age_num)
house_lou = _get_house_attr("house_lou", self.city_info.lou_num)
house_vec = fluid.layers.concat([house_business, house_wuye, house_kfs, house_age, house_lou], 1)
else:
#no house attr
house_vec = fluid.layers.reshape(inputs["house_business"], [-1, self.city_info.business_num])
house_self = self.fc_fn(house_vec, 1, act='sigmoid', layer_name='house_self', FLAGS=FLAGS)
house_self = fluid.layers.reshape(house_self, [-1, 1])
#step2. get nearby house and public poi feature
#public poi embeddings matrix
bid_embed = self.emb_lookup_fn(public_bids, self.city_info.public_num, 1, 'bids', FLAGS, None,
self.city_info.average_price)
dis_dim = 1 #only line dis
if FLAGS.with_car_dis:
dis_dim = 2 #add car drive dis
#nearby house and public poi distance weight matrix
dis_w = fluid.layers.create_parameter(shape=[max_house_num + max_public_num, dis_dim],
dtype='float32', name='dis_w')
house_price = inputs['house_price']
public_price = fluid.layers.reshape(bid_embed, [-1, max_public_num])
#nearby price
price_vec = fluid.layers.concat([house_price, public_price], 1)
#nearby price weight
house_dis = fluid.layers.reshape(inputs['house_dis'], [-1, max_house_num, dis_dim])
public_dis = fluid.layers.reshape(inputs['public_dis'], [-1, max_public_num, dis_dim])
dis_vec = fluid.layers.concat([house_dis, public_dis], 1)
dis_w = fluid.layers.reshape(dis_w, [max_house_num + max_public_num, dis_dim])
dis_vec = fluid.layers.reduce_sum(dis_vec * dis_w, 2)
house_mask = fluid.layers.sequence_mask(fluid.layers.reshape(inputs['house_num'], [-1]),
max_house_num) #remove padded
public_mask = fluid.layers.sequence_mask(fluid.layers.reshape(inputs['public_num'], [-1]),
max_public_num) #remove padded
combine_mask = fluid.layers.cast(x=fluid.layers.concat([house_mask, public_mask], 1),
dtype="float32")
adder = (1.0 - combine_mask) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
dis_vec += adder
price_weight = fluid.layers.softmax(dis_vec)
combine_price = price_vec * price_weight
#step3. merge house_self and nearby house and public price: [-1, 1] * [-1, 1]
pred = house_self * fluid.layers.unsqueeze(fluid.layers.reduce_sum(combine_price, 1), [1])
#fluid.layers.Print(pred, message=None, summarize=-1)
#fluid.layers.Print(label, message=None, summarize=-1)
loss = fluid.layers.square_error_cost(input=pred, label=label)
avg_cost = fluid.layers.mean(loss)
# debug output info during training
debug_output = {}
model_output = {}
net_output = {"debug_output": debug_output,
"model_output": model_output}
model_output['feeded_var_names'] = inputs.keys()
model_output['target_vars'] = [label, pred]
model_output['loss'] = avg_cost
#debug_output['pred'] = pred
debug_output['loss'] = avg_cost
#debug_output['label'] = label
#debug_output['public_bids'] = public_bids
return net_output