提交 c34d6b53 编写于 作者: M malin10

reset gnn

上级 4e84e280
...@@ -49,31 +49,31 @@ runner: ...@@ -49,31 +49,31 @@ runner:
- name: train_runner - name: train_runner
class: train class: train
# num of epochs # num of epochs
epochs: 5 epochs: 2
# device to run training or infer # device to run training or infer
device: cpu device: cpu
save_checkpoint_interval: 1 # save model interval of epochs save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 1 # save inference save_inference_interval: 1 # save inference
save_checkpoint_path: "increment_gnn" # save checkpoint path save_checkpoint_path: "increment" # save checkpoint path
save_inference_path: "inference_gnn" # save inference path save_inference_path: "inference" # save inference path
save_inference_feed_varnames: [] # feed vars of save inference save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path init_model_path: "" # load model path
print_interval: 10 print_interval: 1
- name: infer_runner - name: infer_runner
class: infer class: infer
# device to run training or infer # device to run training or infer
device: cpu device: cpu
print_interval: 1 print_interval: 1
init_model_path: "increment_gnn" # load model path init_model_path: "increment/0" # load model path
# runner will run all the phase in each epoch # runner will run all the phase in each epoch
phase: phase:
- name: phase_train - name: phase1
model: "{workspace}/model.py" # user-defined model model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_train # select dataset by name dataset_name: dataset_train # select dataset by name
thread_num: 1 thread_num: 1
# - name: phase_infer # - name: phase2
# model: "{workspace}/model.py" # user-defined model # model: "{workspace}/model.py" # user-defined model
# dataset_name: dataset_infer # select dataset by name # dataset_name: dataset_infer # select dataset by name
# thread_num: 1 # thread_num: 1
...@@ -57,10 +57,5 @@ def _download_file(url, savepath, print_progress): ...@@ -57,10 +57,5 @@ def _download_file(url, savepath, print_progress):
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
if sys.argv[1] == "diginetica": _download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
_download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv", "./train-item-views.csv", True)
"./train-item-views.csv", True)
elif sys.argv[1] == "yoochoose":
_download_file(
"https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
"./yoochoose-clicks.dat", True)
...@@ -41,29 +41,39 @@ with open(dataset, "r") as f: ...@@ -41,29 +41,39 @@ with open(dataset, "r") as f:
curdate = None curdate = None
for data in reader: for data in reader:
sessid = data['session_id'] sessid = data['session_id']
date = '' if curdate and not curid == sessid:
date = ''
if opt.dataset == 'yoochoose':
date = time.mktime(
time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
else:
date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
sess_date[curid] = date
curid = sessid
if opt.dataset == 'yoochoose': if opt.dataset == 'yoochoose':
item = data['item_id'] item = data['item_id']
date = time.mktime(
time.strptime(data['timestamp'][:19], '%Y-%m-%dT%H:%M:%S'))
else: else:
item = data['item_id'], int(data['timeframe']) item = data['item_id'], int(data['timeframe'])
date = time.mktime(time.strptime(data['eventdate'], '%Y-%m-%d')) curdate = ''
if opt.dataset == 'yoochoose':
if sessid not in sess_date: curdate = data['timestamp']
sess_date[sessid] = date else:
elif date > sess_date[sessid]: curdate = data['eventdate']
sess_date[sessid] = date
if sessid in sess_clicks: if sessid in sess_clicks:
sess_clicks[sessid] += [item] sess_clicks[sessid] += [item]
else: else:
sess_clicks[sessid] = [item] sess_clicks[sessid] = [item]
ctr += 1 ctr += 1
if opt.dataset != 'yoochoose': date = ''
if opt.dataset == 'yoochoose':
date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
else:
date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
for i in list(sess_clicks): for i in list(sess_clicks):
sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1)) sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1))
sess_clicks[i] = [c[0] for c in sorted_clicks] sess_clicks[i] = [c[0] for c in sorted_clicks]
sess_date[curid] = date
print("-- Reading data @ %ss" % datetime.datetime.now()) print("-- Reading data @ %ss" % datetime.datetime.now())
# Filter out length 1 sessions # Filter out length 1 sessions
...@@ -150,7 +160,7 @@ def obtian_tra(): ...@@ -150,7 +160,7 @@ def obtian_tra():
train_dates += [date] train_dates += [date]
train_seqs += [outseq] train_seqs += [outseq]
print(item_ctr) # 43098, 37484 print(item_ctr) # 43098, 37484
with open("./config.txt", "w") as fout: with open("./diginetica/config.txt", "w") as fout:
fout.write(str(item_ctr) + "\n") fout.write(str(item_ctr) + "\n")
return train_ids, train_dates, train_seqs return train_ids, train_dates, train_seqs
......
...@@ -15,31 +15,21 @@ ...@@ -15,31 +15,21 @@
# limitations under the License. # limitations under the License.
set -e set -e
dataset=$1
src=$1
if [[ $src == "yoochoose1_4" || $src == "yoochoose1_64" ]];then
src="yoochoose"
elif [[ $src == "diginetica" ]];then
src="diginetica"
else
echo "Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]"
exit 1
fi
echo "begin to download data" echo "begin to download data"
cd data && python download.py $src
mkdir $dataset cd data && python download.py
python preprocess.py --dataset $src mkdir diginetica
python preprocess.py --dataset diginetica
echo "begin to convert data (binary -> txt)" echo "begin to convert data (binary -> txt)"
python convert_data.py --data_dir $dataset python convert_data.py --data_dir diginetica
cat ${dataset}/train.txt | wc -l >> config.txt cat diginetica/train.txt | wc -l >> diginetica/config.txt
rm -rf train && mkdir train rm -rf train && mkdir train
mv ${dataset}/train.txt train mv diginetica/train.txt train
rm -rf test && mkdir test rm -rf test && mkdir test
mv ${dataset}/test.txt test mv diginetica/test.txt test
mv diginetica/config.txt ./config.txt
...@@ -20,7 +20,6 @@ import paddle.fluid.layers as layers ...@@ -20,7 +20,6 @@ import paddle.fluid.layers as layers
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase from paddlerec.core.model import ModelBase
from paddlerec.core.metrics import Precision
class Model(ModelBase): class Model(ModelBase):
...@@ -236,16 +235,16 @@ class Model(ModelBase): ...@@ -236,16 +235,16 @@ class Model(ModelBase):
softmax = layers.softmax_with_cross_entropy( softmax = layers.softmax_with_cross_entropy(
logits=logits, label=inputs[6]) # [batch_size, 1] logits=logits, label=inputs[6]) # [batch_size, 1]
self.loss = layers.reduce_mean(softmax) # [1] self.loss = layers.reduce_mean(softmax) # [1]
acc = Precision(input=logits, label=inputs[6], k=20) self.acc = layers.accuracy(input=logits, label=inputs[6], k=20)
self._cost = self.loss
self._cost = self.loss
if is_infer: if is_infer:
self._infer_results['P@20'] = acc self._infer_results['acc'] = self.acc
self._infer_results['LOSS'] = self.loss self._infer_results['loss'] = self.loss
return return
self._metrics["LOSS"] = self.loss self._metrics["LOSS"] = self.loss
self._metrics["Train_P@20"] = acc self._metrics["train_acc"] = self.acc
def optimizer(self): def optimizer(self):
step_per_epoch = self.corpus_size // self.train_batch_size step_per_epoch = self.corpus_size // self.train_batch_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册