From da1c712de3c7fc6e5a525ec8d512365effdd80cc Mon Sep 17 00:00:00 2001 From: malin10 Date: Thu, 9 Jul 2020 23:20:40 +0800 Subject: [PATCH 1/3] add linear regression --- models/rank/linear_regression/__init__.py | 13 ++ models/rank/linear_regression/config.yaml | 72 +++++++++ .../data/download_preprocess.py | 37 +++++ .../rank/linear_regression/data/preprocess.py | 146 ++++++++++++++++++ models/rank/linear_regression/data/split.py | 56 +++++++ .../linear_regression/data/test_data/data | 0 .../linear_regression/data/train_data/data | 0 models/rank/linear_regression/data_prepare.sh | 15 ++ models/rank/linear_regression/model.py | 75 +++++++++ models/rank/linear_regression/parse_param.py | 64 ++++++++ 10 files changed, 478 insertions(+) create mode 100644 models/rank/linear_regression/__init__.py create mode 100644 models/rank/linear_regression/config.yaml create mode 100644 models/rank/linear_regression/data/download_preprocess.py create mode 100644 models/rank/linear_regression/data/preprocess.py create mode 100644 models/rank/linear_regression/data/split.py create mode 100644 models/rank/linear_regression/data/test_data/data create mode 100644 models/rank/linear_regression/data/train_data/data create mode 100644 models/rank/linear_regression/data_prepare.sh create mode 100644 models/rank/linear_regression/model.py create mode 100644 models/rank/linear_regression/parse_param.py diff --git a/models/rank/linear_regression/__init__.py b/models/rank/linear_regression/__init__.py new file mode 100644 index 00000000..abf198b9 --- /dev/null +++ b/models/rank/linear_regression/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/rank/linear_regression/config.yaml b/models/rank/linear_regression/config.yaml new file mode 100644 index 00000000..29e81409 --- /dev/null +++ b/models/rank/linear_regression/config.yaml @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# global settings +debug: false +workspace: "/home/aistudio/PaddleRec-master/models/rank/linear_regression" + + +dataset: + - name: dataset_train + type: QueueDataset + batch_size: 1 + data_path: "{workspace}/data/train_data/" + sparse_slots: "userid gender age occupation movieid title genres" + dense_slots: "label:1" + - name: dataset_infer + type: QueueDataset + batch_size: 1 + data_path: "{workspace}/data/test_data/" + sparse_slots: "userid gender age occupation movieid title genres" + dense_slots: "label:1" + +hyper_parameters: + optimizer: + class: SGD + learning_rate: 0.0001 + sparse_feature_number: 1000000 + sparse_feature_dim: 1 + reg: 0.001 + + +mode: train_runner +# if infer, change mode to "infer_runner" and change phase to "infer_phase" + +runner: + - name: train_runner + class: train + epochs: 1 + device: cpu + init_model_path: "" + save_checkpoint_interval: 1 + save_inference_interval: 1 + save_checkpoint_path: "increment" + save_inference_path: "inference" + print_interval: 100 + - name: infer_runner + class: infer + device: cpu + init_model_path: "increment/0" + print_interval: 1 + + +phase: +- name: phase1 + model: "{workspace}/model.py" + dataset_name: dataset_train + thread_num: 12 +#- name: infer_phase +# model: "{workspace}/model.py" +# dataset_name: infer_sample +# thread_num: 1 diff --git a/models/rank/linear_regression/data/download_preprocess.py b/models/rank/linear_regression/data/download_preprocess.py new file mode 100644 index 00000000..ab2f7cc6 --- /dev/null +++ b/models/rank/linear_regression/data/download_preprocess.py @@ -0,0 +1,37 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys + +LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) +TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools") +sys.path.append(TOOLS_PATH) + +from paddlerec.tools.tools import download_file_and_uncompress, download_file + +if __name__ == '__main__': + url = "http://files.grouplens.org/datasets/movielens/ml-1m.zip" + + print("download and extract starting...") + download_file_and_uncompress(url) + print("download and extract finished") + + # print("preprocessing...") + # os.system("python preprocess.py") + # print("preprocess done") + + # shutil.rmtree("raw_data") + print("done") diff --git a/models/rank/linear_regression/data/preprocess.py b/models/rank/linear_regression/data/preprocess.py new file mode 100644 index 00000000..7392deca --- /dev/null +++ b/models/rank/linear_regression/data/preprocess.py @@ -0,0 +1,146 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#coding=utf8 +import os +import sys +reload(sys) +sys.setdefaultencoding('utf-8') +import random +import json + +user_fea = ["userid", "gender", "age", "occupation"] +movie_fea = ["movieid", "title", "genres"] +rating_fea = ["userid", "movieid", "rating", "time"] +dict_size = 1000000 +hash_dict = dict() + +data_path = "ml-1m" +test_user_path = "online_user" + + +def process(path, output_path): + user_dict = parse_data(data_path + "/users.dat", user_fea) + movie_dict = parse_movie_data(data_path + "/movies.dat", movie_fea) + + res = [] + for line in open(path): + line = line.strip() + arr = line.split("::") + userid = arr[0] + movieid = arr[1] + out_str = "time:%s\t%s\t%s\tlabel:%s" % (arr[3], user_dict[userid], + movie_dict[movieid], arr[2]) + log_id = hash(out_str) % 1000000000 + res.append("%s\t%s" % (log_id, out_str)) + with open(output_path, 'w') as fout: + for line in res: + fout.write(line) + fout.write("\n") + + +def parse_data(file_name, feas): + dict = {} + for line in open(file_name): + line = line.strip() + arr = line.split("::") + out_str = "" + for i in range(0, len(feas)): + out_str += "%s:%s\t" % (feas[i], arr[i]) + + dict[arr[0]] = out_str.strip() + return dict + + +def parse_movie_data(file_name, feas): + dict = {} + for line in open(file_name): + line = line.strip() + arr = line.split("::") + title_str = "" + genres_str = "" + + for term in arr[1].split(" "): + term = term.strip() + if term != "": + title_str += "%s " % (term) + for term in arr[2].split("|"): + term = term.strip() + if term != "": + genres_str += "%s " % (term) + out_str = "movieid:%s\ttitle:%s\tgenres:%s" % ( + arr[0], title_str.strip(), genres_str.strip()) + dict[arr[0]] = out_str.strip() + return dict + + +def to_hash(in_str): + feas = in_str.split(":")[0] + arr = in_str.split(":")[1] + out_str = "%s:%s" % (feas, (arr + arr[::-1] + arr[::-2] + arr[::-3])) + hash_id = hash(out_str) % dict_size + # if hash_id in hash_dict and hash_dict[hash_id] != out_str: + # print(hash_id, out_str, hash(out_str)) + # print("conflict") + # exit(-1) + + return "%s:%s" % (feas, hash_id) + + +def to_hash_list(in_str): + arr = in_str.split(":") + tmp_arr = arr[1].split(" ") + out_str = "" + for item in tmp_arr: + item = item.strip() + if item != "": + key = "%s:%s" % (arr[0], item) + out_str += "%s " % (to_hash(key)) + return out_str.strip() + + +def get_hash(path): + #0-34831 1-time:974673057 2-userid:2021 3-gender:M 4-age:25 5-occupation:0 6-movieid:1345 7-title:Carrie (1976) 8-genres:Horror 9-label:2 + for line in open(path): + arr = line.strip().split("\t") + out_str = "logid:%s %s %s %s %s %s %s %s %s %s" % \ + (arr[0], arr[1], to_hash(arr[2]), to_hash(arr[3]), to_hash(arr[4]), to_hash(arr[5]), \ + to_hash(arr[6]), to_hash_list(arr[7]), to_hash_list(arr[8]), arr[9]) + print out_str + + +def split(path, output_dir, num=24): + contents = [] + with open(path) as f: + contents = f.readlines() + lines_per_file = len(contents) / num + print("contents: ", str(len(contents))) + print("lines_per_file: ", str(lines_per_file)) + + for i in range(1, num + 1): + with open(os.path.join(output_dir, "part_" + str(i)), 'w') as fout: + data = contents[(i - 1) * lines_per_file:min(i * lines_per_file, + len(contents))] + for line in data: + fout.write(line) + + +if __name__ == "__main__": + random.seed(1111111) + if sys.argv[1] == "process_raw": + process(sys.argv[2], sys.argv[3]) + elif sys.argv[1] == "hash": + get_hash(sys.argv[2]) + elif sys.argv[1] == "split": + split(sys.argv[2], sys.argv[3]) diff --git a/models/rank/linear_regression/data/split.py b/models/rank/linear_regression/data/split.py new file mode 100644 index 00000000..c763faf3 --- /dev/null +++ b/models/rank/linear_regression/data/split.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +train = dict() +test = dict() +data_path = "ml-1m" + +for line in open(data_path + "/ratings.dat"): + fea = line.rstrip().split("::") + if fea[0] not in train: + train[fea[0]] = [line] + elif fea[0] not in test: + test[fea[0]] = dict() + test[fea[0]]['time'] = int(fea[3]) + test[fea[0]]['content'] = line + else: + time = int(fea[3]) + if time <= test[fea[0]]['time']: + train[fea[0]].append(line) + else: + train[fea[0]].append(test[fea[0]]['content']) + test[fea[0]]['time'] = time + test[fea[0]]['content'] = line + +train_data = [] +for key in train: + for line in train[key]: + train_data.append(line) + +random.shuffle(train_data) +train_num = 10000 +idx = 0 + +with open(data_path + "/train.dat", 'w') as f: + for line in train_data: + idx += 1 + if idx > train_num: + break + f.write(line) + +with open(data_path + "/test.dat", 'w') as f: + for key in test: + f.write(test[key]['content']) diff --git a/models/rank/linear_regression/data/test_data/data b/models/rank/linear_regression/data/test_data/data new file mode 100644 index 00000000..e69de29b diff --git a/models/rank/linear_regression/data/train_data/data b/models/rank/linear_regression/data/train_data/data new file mode 100644 index 00000000..e69de29b diff --git a/models/rank/linear_regression/data_prepare.sh b/models/rank/linear_regression/data_prepare.sh new file mode 100644 index 00000000..6e9f9877 --- /dev/null +++ b/models/rank/linear_regression/data_prepare.sh @@ -0,0 +1,15 @@ +cd data +# 1. download data +python download_preprocess.py + +# 2. split data +python split.py + +# 3. 数据拼接 +python preprocess.py process_raw ml-1m/train.dat raw_train +python preprocess.py process_raw ml-1m/test.dat raw_test + +# 4. hash +python preprocess.py hash raw_train > train_data/data +python preprocess.py hash raw_test > test_data/data +cd .. diff --git a/models/rank/linear_regression/model.py b/models/rank/linear_regression/model.py new file mode 100644 index 00000000..1680092a --- /dev/null +++ b/models/rank/linear_regression/model.py @@ -0,0 +1,75 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import paddle.fluid as fluid + +from paddlerec.core.utils import envs +from paddlerec.core.model import ModelBase + + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def _init_hyper_parameters(self): + self.sparse_feature_number = envs.get_global_env( + "hyper_parameters.sparse_feature_number", None) + self.reg = envs.get_global_env("hyper_parameters.reg", 1e-4) + + def net(self, inputs, is_infer=False): + init_value_ = 0.1 + is_distributed = True if envs.get_trainer() == "CtrTrainer" else False + + # ------------------------- network input -------------------------- + + sparse_var = self._sparse_data_var + self.label = self._dense_data_var[0] + + def embedding_layer(input): + emb = fluid.embedding( + input=input, + is_sparse=True, + is_distributed=is_distributed, + size=[self.sparse_feature_number + 1, 1], + padding_idx=0, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_), + regularizer=fluid.regularizer.L1DecayRegularizer( + self.reg))) + reshape_emb = fluid.layers.reshape(emb, shape=[-1, 1]) + return reshape_emb + + sparse_embed_seq = list(map(embedding_layer, sparse_var)) + weight = fluid.layers.concat(sparse_embed_seq, axis=0) + weight_sum = fluid.layers.reduce_sum(weight) + b_linear = fluid.layers.create_parameter( + shape=[1], + dtype='float32', + default_initializer=fluid.initializer.ConstantInitializer(value=0)) + + self.predict = fluid.layers.relu(weight_sum + b_linear) + cost = fluid.layers.square_error_cost( + input=self.predict, label=self.label) + avg_cost = fluid.layers.reduce_sum(cost) + + self._cost = avg_cost + + self._metrics["COST"] = self._cost + self._metrics["Predict"] = self.predict + if is_infer: + self._infer_results["Predict"] = self.predict + self._infer_results["COST"] = self._cost diff --git a/models/rank/linear_regression/parse_param.py b/models/rank/linear_regression/parse_param.py new file mode 100644 index 00000000..dba6706b --- /dev/null +++ b/models/rank/linear_regression/parse_param.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import sys + +params = [] +with open(sys.argv[1]) as f: + for line in f: + line = line.strip().strip('data: ').strip(',').split(',') + line = map(float, line) + params.append(line) + +feas = [] +with open(sys.argv[2]) as f: + for line in f: + line = line.strip().split('\t') + feas.append(line) + +score = [] +with open(sys.argv[3]) as f: + for line in f: + line = float(line.strip().strip('data: ').strip()[1:-1]) + score.append(line) + +assert (len(params) == len(feas)) +length = len(params) + +bias = None +for i in range(length): + label = feas[i][-1] + tmp = feas[i][2:-3] + tmp_fea = feas[i][-3].split(":") + _ = tmp_fea[1].split(" ") + for j in range(len(_)): + if _[j] != "": + tmp.append(tmp_fea[0] + ":" + _[j]) + tmp_fea = feas[i][-2].split(":") + _ = tmp_fea[1].split(" ") + for j in range(len(_)): + if _[j] != "": + tmp.append(tmp_fea[0] + ":" + _[j]) + sort_p = np.argsort(np.array(params[i]))[::-1] + + res = [] + for j in range(len(sort_p)): + res.append(tmp[sort_p[j]] + "_" + str(params[i][sort_p[j]])) + + res.append(label) + res.append(str(score[i])) + bias = score[i] - sum(params[i]) + print("; ".join(res)) + assert (len(params[i]) == len(tmp)) -- GitLab From 354472a7b85b3477079e10250a1b167cdbd538f0 Mon Sep 17 00:00:00 2001 From: malin10 Date: Thu, 9 Jul 2020 23:36:27 +0800 Subject: [PATCH 2/3] bug fix --- models/rank/linear_regression/config.yaml | 6 +++--- models/rank/linear_regression/data/test_data/data | 0 models/rank/linear_regression/data/test_data/data.txt | 10 ++++++++++ models/rank/linear_regression/data/train_data/data | 0 models/rank/linear_regression/data/train_data/data.txt | 10 ++++++++++ models/rank/linear_regression/model.py | 2 ++ 6 files changed, 25 insertions(+), 3 deletions(-) delete mode 100644 models/rank/linear_regression/data/test_data/data create mode 100644 models/rank/linear_regression/data/test_data/data.txt delete mode 100644 models/rank/linear_regression/data/train_data/data create mode 100644 models/rank/linear_regression/data/train_data/data.txt diff --git a/models/rank/linear_regression/config.yaml b/models/rank/linear_regression/config.yaml index 29e81409..db29a8ed 100644 --- a/models/rank/linear_regression/config.yaml +++ b/models/rank/linear_regression/config.yaml @@ -14,7 +14,7 @@ # global settings debug: false -workspace: "/home/aistudio/PaddleRec-master/models/rank/linear_regression" +workspace: "paddlerec.models.rank.linear_regression" dataset: @@ -53,7 +53,7 @@ runner: save_inference_interval: 1 save_checkpoint_path: "increment" save_inference_path: "inference" - print_interval: 100 + print_interval: 1 - name: infer_runner class: infer device: cpu @@ -68,5 +68,5 @@ phase: thread_num: 12 #- name: infer_phase # model: "{workspace}/model.py" -# dataset_name: infer_sample +# dataset_name: dataset_infer # thread_num: 1 diff --git a/models/rank/linear_regression/data/test_data/data b/models/rank/linear_regression/data/test_data/data deleted file mode 100644 index e69de29b..00000000 diff --git a/models/rank/linear_regression/data/test_data/data.txt b/models/rank/linear_regression/data/test_data/data.txt new file mode 100644 index 00000000..2c1f9878 --- /dev/null +++ b/models/rank/linear_regression/data/test_data/data.txt @@ -0,0 +1,10 @@ +logid:406201811 time:959974725 userid:782959 gender:713968 age:367871 occupation:887870 movieid:593474 title:389382 title:420426 title:697263 title:759044 genres:220363 label:2 +logid:731667808 time:957756905 userid:317013 gender:715500 age:959669 occupation:956770 movieid:436098 title:195371 genres:43405 genres:22907 label:5 +logid:914109775 time:956936565 userid:466839 gender:713968 age:655161 occupation:113614 movieid:495392 title:304668 title:924099 title:819199 genres:893545 genres:750128 label:4 +logid:130407756 time:956926305 userid:845353 gender:713968 age:367871 occupation:113678 movieid:278462 title:631279 title:75815 title:683892 genres:697846 genres:893545 genres:803235 label:4 +logid:162344002 time:956945570 userid:467391 gender:713968 age:930457 occupation:113614 movieid:757292 title:437909 title:67566 title:663931 title:622433 title:973752 genres:455809 label:2 +logid:172473319 time:956934181 userid:263429 gender:713968 age:655161 occupation:113862 movieid:889490 title:186300 title:677155 title:937049 title:622433 title:855502 genres:220363 label:2 +logid:814895539 time:956890758 userid:734503 gender:715500 age:930457 occupation:113614 movieid:312477 title:594818 title:491338 title:234546 genres:220363 genres:532872 label:3 +logid:721088277 time:973289822 userid:278745 gender:713968 age:367871 occupation:300174 movieid:673898 title:774924 title:61321 title:755193 genres:891961 genres:455809 label:5 +logid:487566723 time:956907971 userid:192959 gender:713968 age:367871 occupation:113614 movieid:722671 title:677184 title:108182 title:974616 title:797604 genres:697846 genres:893545 genres:220363 genres:803235 label:5 +logid:960590831 time:978252470 userid:297629 gender:715500 age:292885 occupation:113738 movieid:301175 title:508793 title:829566 title:199430 title:158723 title:467516 genres:43405 label:3 diff --git a/models/rank/linear_regression/data/train_data/data b/models/rank/linear_regression/data/train_data/data deleted file mode 100644 index e69de29b..00000000 diff --git a/models/rank/linear_regression/data/train_data/data.txt b/models/rank/linear_regression/data/train_data/data.txt new file mode 100644 index 00000000..cee804d8 --- /dev/null +++ b/models/rank/linear_regression/data/train_data/data.txt @@ -0,0 +1,10 @@ +logid:966771740 time:962590666 userid:170728 gender:713968 age:367871 occupation:645306 movieid:790755 title:824881 title:127016 genres:697846 genres:893545 genres:803235 label:4 +logid:353175809 time:968974066 userid:414067 gender:713968 age:292885 occupation:113738 movieid:729061 title:372631 title:838478 title:144774 title:593993 genres:220363 genres:853044 label:5 +logid:376358153 time:973455150 userid:867983 gender:713968 age:655161 occupation:113862 movieid:546406 title:697513 title:9603 title:37653 genres:220363 label:5 +logid:532644334 time:959798907 userid:854849 gender:713968 age:367871 occupation:113738 movieid:310450 title:748457 title:214082 genres:220363 label:3 +logid:499989270 time:967918013 userid:852221 gender:713968 age:655161 occupation:113614 movieid:427238 title:665231 title:622433 title:225707 genres:43405 label:4 +logid:580343064 time:966375464 userid:886104 gender:713968 age:367871 occupation:113738 movieid:715507 title:668677 title:234546 genres:697846 genres:220363 genres:803235 genres:455809 label:3 +logid:866282024 time:974636987 userid:606130 gender:713968 age:263859 occupation:991290 movieid:818748 title:28827 title:590550 title:622433 title:973752 genres:43405 genres:22907 label:4 +logid:56511742 time:975233847 userid:535824 gender:713968 age:367871 occupation:113614 movieid:468606 title:491533 title:887940 title:622433 title:549750 genres:43405 genres:532872 genres:22907 label:5 +logid:909168601 time:965277736 userid:157 gender:713968 age:688501 occupation:113862 movieid:753872 title:26739 title:622433 title:797604 genres:697846 genres:532872 genres:455809 label:3 +logid:641667514 time:960083820 userid:875780 gender:715500 age:367871 occupation:113882 movieid:956826 title:191999 title:108182 title:496485 title:549750 genres:220363 genres:532872 label:4 diff --git a/models/rank/linear_regression/model.py b/models/rank/linear_regression/model.py index 1680092a..8ee32e17 100644 --- a/models/rank/linear_regression/model.py +++ b/models/rank/linear_regression/model.py @@ -55,6 +55,8 @@ class Model(ModelBase): sparse_embed_seq = list(map(embedding_layer, sparse_var)) weight = fluid.layers.concat(sparse_embed_seq, axis=0) + if is_infer: + fluid.layers.Print(weight) weight_sum = fluid.layers.reduce_sum(weight) b_linear = fluid.layers.create_parameter( shape=[1], -- GitLab From 90fd23ad691ed39d5b3179809eeff60e4969c397 Mon Sep 17 00:00:00 2001 From: malin10 Date: Thu, 9 Jul 2020 23:42:19 +0800 Subject: [PATCH 3/3] aistudio link --- README.md | 1 + README_CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 86990286..cc0e8e29 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ python -m paddlerec.run -m paddlerec.models.rank.dnn ### Introductory Project * [Get start of PaddleRec in ten minutes](https://aistudio.baidu.com/aistudio/projectdetail/559336) +* [Tutorial of Linear Regression and Analysis of Feature Importance](https://aistudio.baidu.com/aistudio/projectdetail/618918) ### Introductory tutorial * [Data](doc/slot_reader.md) diff --git a/README_CN.md b/README_CN.md index 3f3f8f0e..41a0d5f3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -135,6 +135,7 @@ python -m paddlerec.run -m paddlerec.models.rank.dnn ### 快速开始 * [十分钟上手PaddleRec](https://aistudio.baidu.com/aistudio/projectdetail/559336) +* [乘风破浪的调参侠!玩转特征重要性~从此精通LR](https://aistudio.baidu.com/aistudio/projectdetail/618918) ### 入门教程 * [数据准备](doc/slot_reader.md) -- GitLab