magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 245415f5
Authored on Aug 17, 2020 by yao_yf

modelzoo wide_and_deep_multitable

Parent: 68ba6532
Changes: 13 changed files, 1585 additions and 6 deletions (+1585 -6)
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc        +2   -2
model_zoo/official/recommend/wide_and_deep/src/datasets.py                         +1   -1
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py                    +2   -3
model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt             +3   -0
model_zoo/official/recommend/wide_and_deep_multitable/script/run_multinpu_train.sh +34  -0
model_zoo/official/recommend/wide_and_deep_multitable/src/__init__.py              +0   -0
model_zoo/official/recommend/wide_and_deep_multitable/src/callbacks.py             +96  -0
model_zoo/official/recommend/wide_and_deep_multitable/src/config.py                +95  -0
model_zoo/official/recommend/wide_and_deep_multitable/src/datasets.py              +341 -0
model_zoo/official/recommend/wide_and_deep_multitable/src/metrics.py               +153 -0
model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py         +638 -0
model_zoo/official/recommend/wide_and_deep_multitable/train_and_eval.py            +107 -0
model_zoo/official/recommend/wide_and_deep_multitable/train_and_eval_distribute.py +113 -0
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc

@@ -69,8 +69,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
-  const size_t thread_num = 8;
-  std::thread threads[8];
+  const size_t thread_num = 16;
+  std::thread threads[16];
   size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
   size_t i;
   size_t task_offset = 0;
model_zoo/official/recommend/wide_and_deep/src/datasets.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""train_imagenet."""
+"""train_dataset."""
 import os
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py

@@ -164,9 +164,6 @@ class WideDeepModel(nn.Cell):
         init_acts = [('Wide_b', [1], self.emb_init)]
         var_map = init_var_dict(self.init_args, init_acts)
         self.wide_b = var_map["Wide_b"]
-        if parameter_server:
-            self.wide_w.set_param_ps()
-            self.embedding_table.set_param_ps()
         self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
                                         self.all_dim_list[1],
                                         self.weight_bias_init,
@@ -217,6 +214,8 @@ class WideDeepModel(nn.Cell):
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
             self.embedding_table = self.deep_embeddinglookup.embedding_table
+            self.wide_w.set_param_ps()
+            self.embedding_table.set_param_ps()
         else:
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt
0 → 100644 (new file)

numpy
pandas
pickle
model_zoo/official/recommend/wide_and_deep_multitable/script/run_multinpu_train.sh
0 → 100644 (new file)

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# bash run_multinpu_train.sh
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4

for ((i=0; i<$RANK_SIZE; i++)); do
    rm -rf ${execute_path}/device_$i/
    mkdir ${execute_path}/device_$i/
    cd ${execute_path}/device_$i/ || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done
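Usage note (inferred from the positional parameters above; the commit itself only documents "bash run_multinpu_train.sh"): the script appears to be invoked as bash run_multinpu_train.sh RANK_SIZE EPOCH_SIZE DATASET RANK_TABLE_FILE, launching one background copy of train_and_eval_distribute.py per device, each in its own device_$i working directory with RANK_ID and DEVICE_ID set to the device index.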
model_zoo/official/recommend/wide_and_deep_multitable/src/__init__.py
0 → 100644 (new empty file)
model_zoo/official/recommend/wide_and_deep_multitable/src/callbacks.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
callbacks
"""
import time
from mindspore.train.callback import Callback


def add_write(file_path, out_str):
    with open(file_path, 'a+', encoding="utf-8") as file_out:
        file_out.write(out_str + "\n")


class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF terminating training.
    Note:
        If per_print_times is 0 do not print loss.
    Args:
        per_print_times (int): Print loss every times. Default: 1.
    """

    def __init__(self, config, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.config = config

    def step_end(self, run_context):
        """Monitor the loss in training."""
        cb_params = run_context.original_args()
        wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), \
            cb_params.net_outputs[1].asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num
        print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch,
              wide_loss, deep_loss, flush=True)

        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            loss_file = open(self.config.loss_file_name, "a+")
            loss_file.write("epoch: %s step: %s, wide_loss is %s, deep_loss is %s" %
                            (cb_params.cur_epoch_num, cur_step_in_epoch,
                             wide_loss, deep_loss))
            loss_file.write("\n")
            loss_file.close()
            print("epoch: %s step: %s, wide_loss is %s, deep_loss is %s" %
                  (cb_params.cur_epoch_num, cur_step_in_epoch,
                   wide_loss, deep_loss))


class EvalCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF terminating training.
    Note:
        If per_print_times is 0 do not print loss.
    Args:
        per_print_times (int): Print loss every times. Default: 1.
    """

    def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
        super(EvalCallBack, self).__init__()
        if not isinstance(print_per_step, int) or print_per_step < 0:
            raise ValueError("print_step must be int and >= 0.")
        self.print_per_step = print_per_step
        self.model = model
        self.eval_dataset = eval_dataset
        self.aucMetric = auc_metric
        self.aucMetric.clear()
        self.eval_file_name = config.eval_file_name

    def epoch_end(self, run_context):
        """Monitor the auc in training."""
        self.aucMetric.clear()
        start_time = time.time()
        out = self.model.eval(self.eval_dataset)
        end_time = time.time()
        eval_time = int(end_time - start_time)

        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(
            time_str, out.values(), eval_time)
        print(out_str)
        add_write(self.eval_file_name, out_str)
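A minimal sketch of how these callbacks are wired up, assuming config, model, ds_train, ds_eval and auc_metric are built the way train_and_eval.py below builds them (this mirrors that script rather than adding anything new):

    from src.callbacks import LossCallBack, EvalCallBack

    loss_cb = LossCallBack(config)                              # appends wide/deep loss lines to config.loss_file_name
    eval_cb = EvalCallBack(model, ds_eval, auc_metric, config)  # runs model.eval(ds_eval) at every epoch end
    model.train(config.epochs, ds_train, callbacks=[loss_cb, eval_cb])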
model_zoo/official/recommend/wide_and_deep_multitable/src/config.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" config. """
import argparse


def argparse_init():
    """
    argparse_init
    """
    parser = argparse.ArgumentParser(description='WideDeep')
    parser.add_argument("--data_path", type=str, default="./test_raw_data/")  # The location of the input data.
    parser.add_argument("--epochs", type=int, default=200)  # The number of epochs used to train.
    parser.add_argument("--batch_size", type=int, default=131072)  # Batch size for training and evaluation
    parser.add_argument("--eval_batch_size", type=int, default=131072)  # The batch size used for evaluation.
    parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128])  # The sizes of hidden layers for MLP
    parser.add_argument("--deep_layers_act", type=str, default='relu')  # The act of hidden layers for MLP
    parser.add_argument("--keep_prob", type=float, default=1.0)  # The Embedding size of MF model.
    parser.add_argument("--adam_lr", type=float, default=0.003)  # The Adam lr
    parser.add_argument("--ftrl_lr", type=float, default=0.1)  # The ftrl lr.
    parser.add_argument("--l2_coef", type=float, default=0.0)  # The l2 coefficient.
    parser.add_argument("--is_tf_dataset", type=bool, default=True)  # The l2 coefficient.
    parser.add_argument("--output_path", type=str, default="./output/")  # The location of the output file.
    parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")  # The location of the checkpoints file.
    parser.add_argument("--eval_file_name", type=str, default="eval.log")  # Eval output file.
    parser.add_argument("--loss_file_name", type=str, default="loss.log")  # Loss output file.
    return parser


class WideDeepConfig():
    """
    WideDeepConfig
    """

    def __init__(self):
        self.data_path = ''
        self.epochs = 200
        self.batch_size = 131072
        self.eval_batch_size = 131072
        self.deep_layers_act = 'relu'
        self.weight_bias_init = ['normal', 'normal']
        self.emb_init = 'normal'
        self.init_args = [-0.01, 0.01]
        self.dropout_flag = False
        self.keep_prob = 1.0
        self.l2_coef = 0.0

        self.adam_lr = 0.003
        self.ftrl_lr = 0.1

        self.is_tf_dataset = True
        self.input_emb_dim = 0
        self.output_path = "./output/"
        self.eval_file_name = "eval.log"
        self.loss_file_name = "loss.log"
        self.ckpt_path = "./checkpoints/"

    def argparse_init(self):
        """
        argparse_init
        """
        parser = argparse_init()
        args, _ = parser.parse_known_args()
        self.data_path = args.data_path
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.deep_layers_act = args.deep_layers_act
        self.keep_prob = args.keep_prob
        self.weight_bias_init = ['normal', 'normal']
        self.emb_init = 'normal'
        self.init_args = [-0.01, 0.01]
        self.dropout_flag = False
        self.l2_coef = args.l2_coef
        self.ftrl_lr = args.ftrl_lr
        self.adam_lr = args.adam_lr
        self.is_tf_dataset = args.is_tf_dataset
        self.output_path = args.output_path
        self.eval_file_name = args.eval_file_name
        self.loss_file_name = args.loss_file_name
        self.ckpt_path = args.ckpt_path
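A minimal usage sketch, mirroring how the training scripts in this commit consume the config (nothing here beyond what those scripts already do):

    from src.config import WideDeepConfig

    cfg = WideDeepConfig()   # starts from the hard-coded defaults above
    cfg.argparse_init()      # overrides them with --data_path, --epochs, --batch_size, ... from the command line
    print(cfg.data_path, cfg.epochs, cfg.batch_size)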
model_zoo/official/recommend/wide_and_deep_multitable/src/datasets.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_dataset."""
import os
import math
import pickle
import numpy as np
import pandas as pd
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype


class H5Dataset():
    """
    H5Dataset
    """
    input_length = 39

    def __init__(self, data_path, train_mode=True,
                 train_num_of_parts=21, test_num_of_parts=3):
        self._hdf_data_dir = data_path
        self._is_training = train_mode
        if self._is_training:
            self._file_prefix = 'train'
            self._num_of_parts = train_num_of_parts
        else:
            self._file_prefix = 'test'
            self._num_of_parts = test_num_of_parts
        self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix,
                                         self._num_of_parts)
        print("data_size: {}".format(self.data_size))

    def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
        size = 0
        for part in range(num_of_parts):
            _y = pd.read_hdf(os.path.join(hdf_data_dir,
                                          file_prefix + '_output_part_' + str(part) + '.h5'))
            size += _y.shape[0]
        return size

    def _iterate_hdf_files_(self, num_of_parts=None, shuffle_block=False):
        """
        iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts
        from the beginning, thus the data stream will never stop
        :param train_mode: True or false,false is eval_mode,
            this file iterator will go through the train set
        :param num_of_parts: number of files
        :param shuffle_block: shuffle block files at every round
        :return: input_hdf_file_name, output_hdf_file_name, finish_flag
        """
        parts = np.arange(num_of_parts)
        while True:
            if shuffle_block:
                for _ in range(int(shuffle_block)):
                    np.random.shuffle(parts)
            for i, p in enumerate(parts):
                yield os.path.join(self._hdf_data_dir,
                                   self._file_prefix + '_input_part_' + str(p) + '.h5'), \
                    os.path.join(self._hdf_data_dir,
                                 self._file_prefix + '_output_part_' + str(p) + '.h5'), \
                    i + 1 == len(parts)

    def _generator(self, X, y, batch_size, shuffle=True):
        """
        should be accessed only in private
        :param X:
        :param y:
        :param batch_size:
        :param shuffle:
        :return:
        """
        number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
        counter = 0
        finished = False
        sample_index = np.arange(X.shape[0])
        if shuffle:
            for _ in range(int(shuffle)):
                np.random.shuffle(sample_index)
        assert X.shape[0] > 0
        while True:
            batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
            X_batch = X[batch_index]
            y_batch = y[batch_index]
            counter += 1
            yield X_batch, y_batch, finished
            if counter == number_of_batches:
                counter = 0
                finished = True

    def batch_generator(self, batch_size=1000, random_sample=False, shuffle_block=False):
        """
        :param train_mode: True or false,false is eval_mode,
        :param batch_size
        :param num_of_parts: number of files
        :param random_sample: if True, will shuffle
        :param shuffle_block: shuffle file blocks at every round
        :return:
        """
        for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
                                                           shuffle_block):
            start = stop = None
            X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
            y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
            data_gen = self._generator(X_all, y_all, batch_size,
                                       shuffle=random_sample)
            finished = False

            while not finished:
                X, y, finished = data_gen.__next__()
                X_id = X[:, 0:self.input_length]
                X_va = X[:, self.input_length:]
                yield np.array(X_id.astype(dtype=np.int32)), \
                    np.array(X_va.astype(dtype=np.float32)), \
                    np.array(y.astype(dtype=np.float32))


def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
    """
    _get_h5_dataset
    """
    data_para = {
        'batch_size': batch_size,
    }
    if train_mode:
        data_para['random_sample'] = True
        data_para['shuffle_block'] = True

    h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode)
    numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)

    def _iter_h5_data():
        train_eval_gen = h5_dataset.batch_generator(**data_para)
        for _ in range(0, numbers_of_batch, 1):
            yield train_eval_gen.__next__()

    ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"])
    ds.set_dataset_size(numbers_of_batch)
    ds = ds.repeat(epochs)
    return ds


def _get_tf_dataset(data_dir, schema_dict, input_shape_dict, train_mode=True, epochs=1,
                    batch_size=4096, line_per_sample=4096, rank_size=None, rank_id=None):
    """
    _get_tf_dataset
    """
    dataset_files = []
    file_prefix_name = 'train' if train_mode else 'eval'
    shuffle = bool(train_mode)
    for (dirpath, _, filenames) in os.walk(data_dir):
        for filename in filenames:
            if file_prefix_name in filename and "tfrecord" in filename:
                dataset_files.append(os.path.join(dirpath, filename))
    schema = de.Schema()

    float_key_list = ["label", "continue_val"]

    columns_list = []
    for key, attr_dict in schema_dict.items():
        print("key: {}; shape: {}".format(key, attr_dict["tf_shape"]))
        columns_list.append(key)
        if key in set(float_key_list):
            ms_dtype = mstype.float32
        else:
            ms_dtype = mstype.int32
        schema.add_column(key, de_type=ms_dtype)

    if rank_size is not None and rank_id is not None:
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
                                schema=schema, num_parallel_workers=8,
                                num_shards=rank_size, shard_id=rank_id,
                                shard_equal_rows=True)
    else:
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
                                schema=schema, num_parallel_workers=8)
    ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)

    operations_list = []
    for key in columns_list:
        operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
    print("ssssssssssssssssssssss---------------------" * 10)
    print(input_shape_dict)
    print("---------------------" * 10)
    print(schema_dict)

    def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
        a = np.asarray(a.reshape(batch_size,))
        b = np.array(b).flatten().reshape(batch_size, -1)
        c = np.array(c).flatten().reshape(batch_size, -1)
        d = np.array(d).flatten().reshape(batch_size, -1)
        e = np.array(e).flatten().reshape(batch_size, -1)
        f = np.array(f).flatten().reshape(batch_size, -1)
        g = np.array(g).flatten().reshape(batch_size, -1)
        h = np.array(h).flatten().reshape(batch_size, -1)
        i = np.array(i).flatten().reshape(batch_size, -1)
        j = np.array(j).flatten().reshape(batch_size, -1)
        k = np.array(k).flatten().reshape(batch_size, -1)
        l = np.array(l).flatten().reshape(batch_size, -1)
        m = np.array(m).flatten().reshape(batch_size, -1)
        n = np.array(n).flatten().reshape(batch_size, -1)
        o = np.array(o).flatten().reshape(batch_size, -1)
        p = np.array(p).flatten().reshape(batch_size, -1)
        q = np.array(q).flatten().reshape(batch_size, -1)
        r = np.array(r).flatten().reshape(batch_size, -1)
        s = np.array(s).flatten().reshape(batch_size, -1)
        t = np.array(t).flatten().reshape(batch_size, -1)
        u = np.array(u).flatten().reshape(batch_size, -1)
        return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u

    ds = ds.map(operations=mixup,
                input_columns=['label', 'continue_val', 'indicator_id', 'emb_128_id',
                               'emb_64_single_id', 'multi_doc_ad_category_id',
                               'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
                               'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
                               'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
                               'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
                               'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
                               'multi_doc_ad_topic_id_mask', 'ad_id',
                               'display_ad_and_is_leak', 'display_id', 'is_leak'],
                columns_order=['label', 'continue_val', 'indicator_id', 'emb_128_id',
                               'emb_64_single_id', 'multi_doc_ad_category_id',
                               'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
                               'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
                               'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
                               'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
                               'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
                               'multi_doc_ad_topic_id_mask', 'display_id', 'ad_id',
                               'display_ad_and_is_leak', 'is_leak'],
                num_parallel_workers=8)

    ds = ds.repeat(epochs)
    return ds


def compute_emb_dim(config):
    """
    compute_emb_dim
    """
    with open(os.path.join(config.data_path + 'dataformat/', "input_shape_dict.pkl"),
              "rb") as file_in:
        input_shape_dict = pickle.load(file_in)
    input_field_size = {}
    for key, shape in input_shape_dict.items():
        if len(shape) < 2:
            input_field_size[key] = 1
        else:
            input_field_size[key] = shape[1]
    multi_key_list = ["multi_doc_event_topic_id", "multi_doc_event_entity_id",
                      "multi_doc_ad_category_id", "multi_doc_event_category_id",
                      "multi_doc_ad_entity_id", "multi_doc_ad_topic_id"]

    config.input_emb_dim = input_field_size["continue_val"] + \
        input_field_size["indicator_id"] * 64 + \
        input_field_size["emb_128_id"] * 128 + \
        input_field_size["emb_64_single_id"] * 64 + \
        len(multi_key_list) * 64


def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=4096,
                   is_tf_dataset=True, line_per_sample=4096, rank_size=None, rank_id=None):
    """
    create_dataset
    """
    if is_tf_dataset:
        with open(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl"), "rb") as file_in:
            print(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl"))
            schema_dict = pickle.load(file_in)
        with open(os.path.join(data_dir + 'dataformat/', "input_shape_dict.pkl"), "rb") as file_in:
            input_shape_dict = pickle.load(file_in)
        return _get_tf_dataset(data_dir, schema_dict, input_shape_dict,
                               train_mode, epochs, batch_size, line_per_sample,
                               rank_size=rank_size, rank_id=rank_id)
    if rank_size is not None and rank_size > 1:
        raise RuntimeError("please use tfrecord dataset.")
    return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)
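A minimal sketch of the dataset entry points, assuming the pickled schema_dict.pkl and input_shape_dict.pkl files exist under <data_path>/dataformat/ as create_dataset and compute_emb_dim expect (this mirrors the training scripts below):

    from src.config import WideDeepConfig
    from src.datasets import create_dataset, compute_emb_dim

    cfg = WideDeepConfig()
    cfg.argparse_init()
    compute_emb_dim(cfg)   # fills cfg.input_emb_dim from input_shape_dict.pkl
    ds_train = create_dataset(cfg.data_path, train_mode=True, epochs=1,
                              batch_size=cfg.batch_size, is_tf_dataset=cfg.is_tf_dataset)
    print(ds_train.get_dataset_size())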
model_zoo/official/recommend/wide_and_deep_multitable/src/metrics.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Area under cure metric
"""
import time
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score
from mindspore.nn.metrics import Metric


def groupby_df_v1(test_df, gb_key):
    """
    groupby_df_v1
    """
    data_groups = test_df.groupby(gb_key)
    return data_groups


def _compute_metric_v1(batch_groups, topk):
    """
    _compute_metric_v1
    """
    results = []
    for df in batch_groups:
        df = df.sort_values(by="preds", ascending=False)
        if df.shape[0] > topk:
            df = df.head(topk)
        preds = df["preds"].values
        labels = df["labels"].values
        if np.sum(labels) > 0:
            results.append(average_precision_score(labels, preds))
        else:
            results.append(0.0)
    return results


def mean_AP_topk(batch_labels, batch_preds, topk=12):
    """
    mean_AP_topk
    """
    def ap_score(label, y_preds, topk):
        ind_list = np.argsort(y_preds)[::-1]
        ind_list = ind_list[:topk]
        if label not in set(ind_list):
            return 0.0
        rank = list(ind_list).index(label)
        return 1.0 / (rank + 1)

    mAP_list = []
    for label, preds in zip(batch_labels, batch_preds):
        mAP = ap_score(label, preds, topk)
        mAP_list.append(mAP)
    return mAP_list


def new_compute_mAP(test_df, gb_key="display_ids", top_k=12):
    """
    new_compute_mAP
    """
    total_start = time.time()
    display_ids = test_df["display_ids"]
    labels = test_df["labels"]
    predictions = test_df["preds"]

    test_df.sort_values(by=[gb_key], inplace=True, ascending=True)
    display_ids = test_df["display_ids"]
    labels = test_df["labels"]
    predictions = test_df["preds"]

    _, display_ids_idx = np.unique(display_ids, return_index=True)
    preds = np.split(predictions.tolist(), display_ids_idx.tolist()[1:])
    labels = np.split(labels.tolist(), display_ids_idx.tolist()[1:])

    def pad_fn(ele_l):
        res_list = ele_l + [0.0 for i in range(30 - len(ele_l))]
        return res_list

    preds = list(map(lambda x: pad_fn(x.tolist()), preds))
    labels = [np.argmax(l) for l in labels]

    result_list = []

    batch_size = 100000
    for idx in range(0, len(labels), batch_size):
        batch_labels = labels[idx:idx + batch_size]
        batch_preds = preds[idx:idx + batch_size]
        meanAP = mean_AP_topk(batch_labels, batch_preds, topk=top_k)
        result_list.extend(meanAP)
    mean_AP = np.mean(result_list)
    print("compute time: {}".format(time.time() - total_start))
    print("mean_AP: {}".format(mean_AP))
    return mean_AP


class AUCMetric(Metric):
    """
    AUCMetric
    """

    def __init__(self):
        super(AUCMetric, self).__init__()
        self.index = 1

    def clear(self):
        """Clear the internal evaluation result."""
        self.true_labels = []
        self.pred_probs = []
        self.display_id = []

    def update(self, *inputs):
        """
        update
        """
        all_predict = inputs[1].asnumpy()  # predict
        all_label = inputs[2].asnumpy()  # label
        all_display_id = inputs[3].asnumpy()  # label
        self.true_labels.extend(all_label.flatten().tolist())
        self.pred_probs.extend(all_predict.flatten().tolist())
        self.display_id.extend(all_display_id.flatten().tolist())

    def eval(self):
        """
        eval
        """
        if len(self.true_labels) != len(self.pred_probs):
            raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')

        result_df = pd.DataFrame({
            "display_ids": self.display_id,
            "preds": self.pred_probs,
            "labels": self.true_labels,
        })
        auc = roc_auc_score(self.true_labels, self.pred_probs)
        MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
        print("=====" * 20 + " auc_metric end ")
        print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP))
        return auc
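A hypothetical standalone check of AUCMetric (not part of the commit); the four update() arguments mimic the eval network outputs it indexes as inputs[1]..inputs[3], i.e. predictions, labels and display ids:

    import numpy as np
    from mindspore import Tensor
    from src.metrics import AUCMetric

    metric = AUCMetric()
    metric.clear()
    preds = Tensor(np.array([0.9, 0.2, 0.4, 0.7], np.float32))
    labels = Tensor(np.array([1, 0, 0, 1], np.float32))
    display_ids = Tensor(np.array([1, 1, 2, 2], np.int32))
    metric.update(None, preds, labels, display_ids)  # inputs[0] is never read by update()
    print(metric.eval())                             # prints mAP@12 as a side effect, returns the ROC AUC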
model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py
0 → 100644 (new file, +638 lines)
(This diff is collapsed on the original page; its content is not shown here.)
model_zoo/official/recommend/wide_and_deep_multitable/train_and_eval.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" training_and_evaluating """
import os
import sys

from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor

from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, compute_emb_dim
from src.metrics import AUCMetric
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def get_WideDeep_net(config):
    """
    Get network of wide&deep model.
    """
    WideDeep_net = WideDeepModel(config)

    loss_net = NetWithLossClass(WideDeep_net, config)
    train_net = TrainStepWrap(loss_net, config)
    eval_net = PredictWithSigmoid(WideDeep_net)

    return train_net, eval_net


class ModelBuilder():
    """
    ModelBuilder.
    """

    def __init__(self):
        pass

    def get_hook(self):
        pass

    def get_train_hook(self):
        hooks = []
        callback = LossCallBack()
        hooks.append(callback)

        if int(os.getenv('DEVICE_ID')) == 0:
            pass
        return hooks

    def get_net(self, config):
        return get_WideDeep_net(config)


def train_and_eval(config):
    """
    train_and_eval.
    """
    data_path = config.data_path
    epochs = config.epochs
    print("epochs is {}".format(epochs))

    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset)
    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset)

    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)

    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path,
                                 config=ckptconfig)

    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()),
                           eval_callback, callback, ckpoint_cb])


if __name__ == "__main__":
    wide_and_deep_config = WideDeepConfig()
    wide_and_deep_config.argparse_init()
    compute_emb_dim(wide_and_deep_config)

    context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", save_graphs=True)
    train_and_eval(wide_and_deep_config)
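As a rough single-device invocation (flags taken from the parser in src/config.py; the values shown are just its defaults): python train_and_eval.py --data_path=./test_raw_data/ --epochs=200 --batch_size=131072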
model_zoo/official/recommend/wide_and_deep_multitable/train_and_eval_distribute.py
0 → 100644 (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" training_multinpu"""
import os
import sys

from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init

from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, compute_emb_dim
from src.metrics import AUCMetric
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def get_WideDeep_net(config):
    """
    get_WideDeep_net
    """
    WideDeep_net = WideDeepModel(config)

    loss_net = NetWithLossClass(WideDeep_net, config)
    train_net = TrainStepWrap(loss_net, config)
    eval_net = PredictWithSigmoid(WideDeep_net)

    return train_net, eval_net


class ModelBuilder():
    """
    ModelBuilder
    """

    def __init__(self):
        pass

    def get_hook(self):
        pass

    def get_train_hook(self):
        hooks = []
        callback = LossCallBack()
        hooks.append(callback)

        if int(os.getenv('DEVICE_ID')) == 0:
            pass
        return hooks

    def get_net(self, config):
        return get_WideDeep_net(config)


def train_and_eval(config):
    """
    train_and_eval
    """
    data_path = config.data_path
    epochs = config.epochs
    print("epochs is {}".format(epochs))

    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=config.batch_size,
                              is_tf_dataset=config.is_tf_dataset,
                              rank_id=get_rank(), rank_size=get_group_size())
    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                             batch_size=config.batch_size,
                             is_tf_dataset=config.is_tf_dataset,
                             rank_id=get_rank(), rank_size=get_group_size())

    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)

    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path,
                                 config=ckptconfig)

    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()),
                           eval_callback, callback, ckpoint_cb])


if __name__ == "__main__":
    wide_and_deep_config = WideDeepConfig()
    wide_and_deep_config.argparse_init()
    compute_emb_dim(wide_and_deep_config)

    context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", save_graphs=True)
    init()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True,
                                      device_num=get_group_size())
    train_and_eval(wide_and_deep_config)