Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
1370b9b6
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1370b9b6
编写于
4月 28, 2020
作者:
O
overlord
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/models
into ncf_04281221
上级
fbf706b6
6896c7cd
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
684 addition
and
61 deletion
+684
-61
PaddleCV/gan/README.md
PaddleCV/gan/README.md
+11
-7
PaddleCV/gan/metric/compute_fid.py
PaddleCV/gan/metric/compute_fid.py
+183
-0
PaddleCV/gan/metric/inception.py
PaddleCV/gan/metric/inception.py
+302
-0
PaddleCV/gan/network/base_network.py
PaddleCV/gan/network/base_network.py
+1
-1
PaddleCV/gan/requirements.txt
PaddleCV/gan/requirements.txt
+1
-0
PaddleNLP/machine_translation/transformer/predict.py
PaddleNLP/machine_translation/transformer/predict.py
+8
-3
PaddleRec/README.md
PaddleRec/README.md
+30
-17
PaddleRec/img/paddlerec.png
PaddleRec/img/paddlerec.png
+0
-0
PaddleRec/text_classification/README.md
PaddleRec/text_classification/README.md
+0
-0
PaddleRec/text_classification/net.py
PaddleRec/text_classification/net.py
+0
-0
PaddleRec/text_classification/train.py
PaddleRec/text_classification/train.py
+0
-0
PaddleRec/word2vec/README.md
PaddleRec/word2vec/README.md
+2
-1
PaddleRec/word2vec/cluster_train.py
PaddleRec/word2vec/cluster_train.py
+29
-15
PaddleRec/word2vec/net.py
PaddleRec/word2vec/net.py
+83
-0
PaddleRec/word2vec/train.py
PaddleRec/word2vec/train.py
+29
-15
PaddleRec/word2vec/utils.py
PaddleRec/word2vec/utils.py
+5
-2
未找到文件。
PaddleCV/gan/README.md
浏览文件 @
1370b9b6
...
...
@@ -18,12 +18,12 @@
本图像生成模型库包含CGAN
\[
[
3
](
#参考文献
)
\]
, DCGAN
\[
[
4
](
#参考文献
)
\]
, Pix2Pix
\[
[
5
](
#参考文献
)
\]
, CycleGAN
\[
[
6
](
#参考文献
)
\]
, StarGAN
\[
[
7
](
#参考文献
)
\]
, AttGAN
\[
[
8
](
#参考文献
)
\]
, STGAN
\[
[
9
](
#参考文献
)
\]
, SPADE
\[
[
13
](
#参考文献
)
\]
。
注意:
1.
StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练
。
2.
GAN模型目前仅仅验证了单机单卡训练和预测结果
。
3.
CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。cityscapes数据集需要从
[
官方
](
https://www.cityscapes-dataset.com
)
下载数据,下载完之后使用
`scripts/prepare_cityscapes_dataset.py`
处理,处理后的文件夹命名为cityscapes并放入data目录下即可
。
4.
PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装
。
5.
中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换
。
6.
infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)。
1.
GAN模型目前仅仅验证了单机单卡训练和预测结果
。
2.
CGAN和DCGAN两个模型训练使用的数据集为MNIST数据集;StarGAN,AttGAN和STGAN的数据集为CelebA数据集。Pix2Pix和CycleGAN支持的数据集可以参考download.py中的cycle_pix_dataset。cityscapes数据集需要从
[
官方
](
https://www.cityscapes-dataset.com
)
下载数据,下载完之后使用
`scripts/prepare_cityscapes_dataset.py`
处理,处理后的文件夹命名为cityscapes并放入data目录下即可
。
3.
PaddlePaddle1.5.1及之前的版本不支持在AttGAN和STGAN模型里的判别器加上的instance norm。如果要在判别器中加上instance norm,请源码编译develop分支并安装
。
4.
中间效果图保存在${output_dir}/test文件夹中。对于Pix2Pix来说,inputA 和inputB 代表输入的两种风格的图片,fakeB表示生成图片;对于CycleGAN来说,inputA表示输入图片,fakeB表示inputA根据生成的图片,cycA表示fakeB经过生成器重构出来的对应于inputA的重构图片;对于StarGAN,AttGAN和STGAN来说,第一行表示原图,之后的每一行都代表一种属性变换
。
5.
infer过程使用的test_list文件和训练过程中使用的train_list具有相同格式,第一行为样本数量,第二行为属性,之后的行中第一个表示图片名称,之后的-1和1表示该图片是否拥有该属性(1为有该属性,-1为没有该属性)
。
6.
metric中的fid评价指标需要先下载inceptionV3模型参数,模型参数下载链接:
[
inceptionV3
](
https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
)
图像生成模型库库的目录结构如下:
```
...
...
@@ -58,6 +58,10 @@
│ ├── celeba
│ ├── ${image_dir} 存放实际图片
│ ├── list 文件
│
├── metric 评价指标
│ ├── compute_fid.py 计算fid_score的文件
│ ├── inception.py 计算fid_score所需要的inceptionV3网络结构
```
...
...
@@ -71,7 +75,7 @@
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.7.1或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据
[
安装文档
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html
)
中的说明来更新PaddlePaddle。
其他依赖包:
1.
`pip install
imageio`
或者
`pip install
-r requirements.txt`
安装imageio包(保存图片代码中所依赖的包)
1.
`pip install -r requirements.txt`
安装imageio包(保存图片代码中所依赖的包)
### 任务简介
...
...
PaddleCV/gan/metric/compute_fid.py
0 → 100644
浏览文件 @
1370b9b6
import
os
import
fnmatch
import
numpy
as
np
import
cv2
from
cv2
import
imread
from
scipy
import
linalg
import
paddle.fluid
as
fluid
from
inception
import
InceptionV3
def
tqdm
(
x
):
return
x
""" based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
"""
"""
inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
"""
def
_calculate_frechet_distance
(
mu1
,
sigma1
,
mu2
,
sigma2
,
eps
=
1e-6
):
m1
=
np
.
atleast_1d
(
mu1
)
m2
=
np
.
atleast_1d
(
mu2
)
sigma1
=
np
.
atleast_2d
(
sigma1
)
sigma2
=
np
.
atleast_2d
(
sigma2
)
assert
mu1
.
shape
==
mu2
.
shape
,
'Training and test mean vectors have different lengths'
assert
sigma1
.
shape
==
sigma2
.
shape
,
'Training and test covariances have different dimensions'
diff
=
mu1
-
mu2
t
=
sigma1
.
dot
(
sigma2
)
covmean
,
_
=
linalg
.
sqrtm
(
sigma1
.
dot
(
sigma2
),
disp
=
False
)
if
not
np
.
isfinite
(
covmean
).
all
():
msg
=
(
'fid calculation produces singular product; '
'adding %s to diagonal of cov estimates'
)
%
eps
print
(
msg
)
offset
=
np
.
eye
(
sigma1
.
shape
[
0
])
*
eps
covmean
=
linalg
.
sqrtm
((
sigma1
+
offset
).
dot
(
sigma2
+
offset
))
# Numerical error might give slight imaginary component
if
np
.
iscomplexobj
(
covmean
):
if
not
np
.
allclose
(
np
.
diagonal
(
covmean
).
imag
,
0
,
atol
=
1e-3
):
m
=
np
.
max
(
np
.
abs
(
covmean
.
imag
))
raise
ValueError
(
'Imaginary component {}'
.
format
(
m
))
covmean
=
covmean
.
real
tr_covmean
=
np
.
trace
(
covmean
)
return
(
diff
.
dot
(
diff
)
+
np
.
trace
(
sigma1
)
+
np
.
trace
(
sigma2
)
-
2
*
tr_covmean
)
def
_build_program
(
model
):
main_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
,
startup_program
):
images
=
fluid
.
data
(
name
=
'images'
,
shape
=
[
None
,
3
,
None
,
None
])
output
=
model
.
network
(
images
,
class_dim
=
1008
)
pred
=
fluid
.
layers
.
pool2d
(
output
[
0
],
global_pooling
=
True
)
test_program
=
main_program
.
clone
(
for_test
=
True
)
return
pred
,
test_program
,
startup_program
def
_get_activations_from_ims
(
img
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
):
n_batches
=
(
len
(
img
)
+
batch_size
-
1
)
//
batch_size
n_used_img
=
len
(
img
)
pred_arr
=
np
.
empty
((
n_used_img
,
dims
))
for
i
in
tqdm
(
range
(
n_batches
)):
start
=
i
*
batch_size
end
=
start
+
batch_size
if
end
>
len
(
img
):
end
=
len
(
img
)
images
=
img
[
start
:
end
]
if
images
.
shape
[
1
]
!=
3
:
images
=
images
.
transpose
((
0
,
3
,
1
,
2
))
images
/=
255
output
,
main_program
,
startup_program
=
_build_program
(
model
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
fluid
.
load
(
main_program
,
os
.
path
.
join
(
premodel_path
,
'paddle_inceptionv3'
),
exe
)
pred
=
exe
.
run
(
main_program
,
feed
=
{
'images'
:
images
},
fetch_list
=
[
output
])[
0
]
pred_arr
[
start
:
end
]
=
pred
.
reshape
(
end
-
start
,
-
1
)
return
pred_arr
def
_compute_statistic_of_img
(
img
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
):
act
=
_get_activations_from_ims
(
img
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
mu
=
np
.
mean
(
act
,
axis
=
0
)
sigma
=
np
.
cov
(
act
,
rowvar
=
False
)
return
mu
,
sigma
def
calculate_fid_given_img
(
img_fake
,
img_real
,
batch_size
,
use_gpu
,
dims
,
premodel_path
,
model
=
None
):
assert
os
.
path
.
exists
(
premodel_path
),
'pretrain_model path {} is not exists! Please download it first'
.
format
(
premodel_path
)
if
model
is
None
:
block_idx
=
InceptionV3
.
BLOCK_INDEX_BY_DIM
[
dims
]
model
=
InceptionV3
([
block_idx
])
m1
,
s1
=
_compute_statistic_of_img
(
img_fake
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
m2
,
s2
=
_compute_statistic_of_img
(
img_real
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
return
fid_value
def
_get_activations
(
files
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
):
if
len
(
files
)
%
batch_size
!=
0
:
print
((
'Warning: number of images is not a multiple of the '
'batch size. Some samples are going to be ignored.'
))
if
batch_size
>
len
(
files
):
print
((
'Warning: batch size is bigger than the datasets size. '
'Setting batch size to datasets size'
))
batch_size
=
len
(
files
)
n_batches
=
len
(
files
)
//
batch_size
n_used_imgs
=
n_batches
*
batch_size
pred_arr
=
np
.
empty
((
n_used_imgs
,
dims
))
for
i
in
tqdm
(
range
(
n_batches
)):
start
=
i
*
batch_size
end
=
start
+
batch_size
images
=
np
.
array
([
imread
(
str
(
f
)).
astype
(
np
.
float32
)
for
f
in
files
[
start
:
end
]])
if
len
(
images
.
shape
)
!=
4
:
images
=
imread
(
str
(
files
[
start
]))
images
=
cv2
.
cvtColor
(
images
,
cv2
.
COLOR_BGR2GRAY
)
images
=
np
.
array
([
images
.
astype
(
np
.
float32
)])
images
=
images
.
transpose
((
0
,
3
,
1
,
2
))
images
/=
255
output
,
main_program
,
startup_program
=
_build_program
(
model
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
fluid
.
load
(
main_program
,
os
.
path
.
join
(
premodel_path
,
'paddle_inceptionv3'
),
exe
)
pred
=
exe
.
run
(
main_program
,
feed
=
{
'images'
:
images
},
fetch_list
=
[
output
])[
0
]
pred_arr
[
start
:
end
]
=
pred
.
reshape
(
end
-
start
,
-
1
)
return
pred_arr
def
_calculate_activation_statistics
(
files
,
model
,
premodel_path
,
batch_size
=
50
,
dims
=
2048
,
use_gpu
=
False
):
act
=
_get_activations
(
files
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
mu
=
np
.
mean
(
act
,
axis
=
0
)
sigma
=
np
.
cov
(
act
,
rowvar
=
False
)
return
mu
,
sigma
def
_compute_statistics_of_path
(
path
,
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
):
if
path
.
endswith
(
'.npz'
):
f
=
np
.
load
(
path
)
m
,
s
=
f
[
'mu'
][:],
f
[
'sigma'
][:]
f
.
close
()
else
:
files
=
[]
for
root
,
dirnames
,
filenames
in
os
.
walk
(
path
):
for
filename
in
fnmatch
.
filter
(
filenames
,
'*.jpg'
)
or
fnmatch
.
filter
(
filenames
,
'*.png'
):
files
.
append
(
os
.
path
.
join
(
root
,
filename
))
m
,
s
=
_calculate_activation_statistics
(
files
,
model
,
premodel_path
,
batch_size
,
dims
,
use_gpu
)
return
m
,
s
def
calculate_fid_given_paths
(
paths
,
batch_size
,
use_gpu
,
dims
,
premodel_path
,
model
=
None
):
assert
os
.
path
.
exists
(
premodel_path
),
'pretrain_model path {} is not exists! Please download it first'
.
format
(
premodel_path
)
for
p
in
paths
:
if
not
os
.
path
.
exists
(
p
):
raise
RuntimeError
(
'Invalid path: %s'
%
p
)
if
model
is
None
:
block_idx
=
InceptionV3
.
BLOCK_INDEX_BY_DIM
[
dims
]
model
=
InceptionV3
([
block_idx
])
m1
,
s1
=
_compute_statistics_of_path
(
paths
[
0
],
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
m2
,
s2
=
_compute_statistics_of_path
(
paths
[
1
],
model
,
batch_size
,
dims
,
use_gpu
,
premodel_path
)
fid_value
=
_calculate_frechet_distance
(
m1
,
s1
,
m2
,
s2
)
return
fid_value
PaddleCV/gan/metric/inception.py
0 → 100644
浏览文件 @
1370b9b6
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
__all__
=
[
'InceptionV3'
]
class
InceptionV3
:
DEFAULT_BLOCK_INDEX
=
3
BLOCK_INDEX_BY_DIM
=
{
64
:
0
,
# First max pooling features
192
:
1
,
# Second max pooling featurs
768
:
2
,
# Pre-aux classifier features
2048
:
3
# Final average pooling features
}
def
__init__
(
self
,
output_blocks
=
[
DEFAULT_BLOCK_INDEX
],
resize_input
=
True
,
normalize_input
=
True
):
self
.
resize_input
=
resize_input
self
.
normalize_input
=
normalize_input
self
.
output_blocks
=
sorted
(
output_blocks
)
self
.
last_needed_block
=
max
(
output_blocks
)
assert
self
.
last_needed_block
<=
3
,
'Last possible output block index is 3'
def
network
(
self
,
x
,
class_dim
=
1000
,
aux_logits
=
False
):
if
self
.
resize_input
:
x
=
fluid
.
layers
.
resize_bilinear
(
x
,
out_shape
=
[
299
,
299
],
align_corners
=
False
,
align_mode
=
0
)
if
self
.
normalize_input
:
x
=
x
*
2
-
1
out
,
_
,
=
self
.
fid_inceptionv3
(
x
,
class_dim
,
aux_logits
)
return
out
def
fid_inceptionv3
(
self
,
x
,
num_classes
=
1000
,
aux_logits
=
False
):
""" inception v3 model for FID computation
"""
out
=
[]
aux
=
None
### block0
x
=
self
.
conv_bn_layer
(
x
,
32
,
3
,
stride
=
2
,
name
=
'Conv2d_1a_3x3'
)
x
=
self
.
conv_bn_layer
(
x
,
32
,
3
,
name
=
'Conv2d_2a_3x3'
)
x
=
self
.
conv_bn_layer
(
x
,
64
,
3
,
padding
=
1
,
name
=
'Conv2d_2b_3x3'
)
x
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
2
,
pool_type
=
'max'
)
if
0
in
self
.
output_blocks
:
out
.
append
(
x
)
if
self
.
last_needed_block
>=
1
:
### block1
x
=
self
.
conv_bn_layer
(
x
,
80
,
1
,
name
=
'Conv2d_3b_1x1'
)
x
=
self
.
conv_bn_layer
(
x
,
192
,
3
,
name
=
'Conv2d_4a_3x3'
)
x
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
2
,
pool_type
=
'max'
)
if
1
in
self
.
output_blocks
:
out
.
append
(
x
)
if
self
.
last_needed_block
>=
2
:
### block2
### Mixed_5b 5c 5d
x
=
self
.
fid_inceptionA
(
x
,
pool_features
=
32
,
name
=
'Mixed_5b'
)
x
=
self
.
fid_inceptionA
(
x
,
pool_features
=
64
,
name
=
'Mixed_5c'
)
x
=
self
.
fid_inceptionA
(
x
,
pool_features
=
64
,
name
=
'Mixed_5d'
)
### Mixed_6
x
=
self
.
inceptionB
(
x
,
name
=
'Mixed_6a'
)
x
=
self
.
fid_inceptionC
(
x
,
c7
=
128
,
name
=
'Mixed_6b'
)
x
=
self
.
fid_inceptionC
(
x
,
c7
=
160
,
name
=
'Mixed_6c'
)
x
=
self
.
fid_inceptionC
(
x
,
c7
=
160
,
name
=
'Mixed_6d'
)
x
=
self
.
fid_inceptionC
(
x
,
c7
=
192
,
name
=
'Mixed_6e'
)
if
2
in
self
.
output_blocks
:
out
.
append
(
x
)
if
aux_logits
:
aux
=
self
.
inceptionAux
(
x
,
num_classes
,
name
=
'AuxLogits'
)
if
self
.
last_needed_block
>=
3
:
### block3
### Mixed_7
x
=
self
.
inceptionD
(
x
,
name
=
'Mixed_7a'
)
x
=
self
.
fid_inceptionE_1
(
x
,
name
=
'Mixed_7b'
)
x
=
self
.
fid_inceptionE_2
(
x
,
name
=
'Mixed_7c'
)
x
=
fluid
.
layers
.
pool2d
(
x
,
global_pooling
=
True
,
pool_type
=
'avg'
)
out
.
append
(
x
)
#x = fluid.layers.dropout(x, dropout_prob=0.5)
#x = fluid.layers.flatten(x, axis=1)
#x = fluid.layers.fc(x, size=num_classes, param_attr=ParamAttr(name='fc.weight'), bias_attr=ParamAttr(name='fc.bias'))
return
out
,
aux
def
inceptionA
(
self
,
x
,
pool_features
,
name
=
None
):
branch1x1
=
self
.
conv_bn_layer
(
x
,
64
,
1
,
name
=
name
+
'.branch1x1'
)
branch5x5
=
self
.
conv_bn_layer
(
x
,
48
,
1
,
name
=
name
+
'.branch5x5_1'
)
branch5x5
=
self
.
conv_bn_layer
(
branch5x5
,
64
,
5
,
padding
=
2
,
name
=
name
+
'.branch5x5_2'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
64
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_3'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
pool_features
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch5x5
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
inceptionB
(
self
,
x
,
name
=
None
):
branch3x3
=
self
.
conv_bn_layer
(
x
,
384
,
3
,
stride
=
2
,
name
=
name
+
'.branch3x3'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
64
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
stride
=
2
,
name
=
name
+
'.branch3x3dbl_3'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
2
,
pool_type
=
'max'
)
return
fluid
.
layers
.
concat
([
branch3x3
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
inceptionC
(
self
,
x
,
c7
,
name
=
None
):
branch1x1
=
self
.
conv_bn_layer
(
x
,
192
,
1
,
name
=
name
+
'.branch1x1'
)
branch7x7
=
self
.
conv_bn_layer
(
x
,
c7
,
1
,
name
=
name
+
'.branch7x7_1'
)
branch7x7
=
self
.
conv_bn_layer
(
branch7x7
,
c7
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7_2'
)
branch7x7
=
self
.
conv_bn_layer
(
branch7x7
,
192
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7_3'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
x
,
c7
,
1
,
name
=
name
+
'.branch7x7dbl_1'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7dbl_2'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7dbl_3'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7dbl_4'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
192
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7dbl_5'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
192
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch7x7
,
branch7x7dbl
,
branch_pool
],
axis
=
1
)
def
inceptionD
(
self
,
x
,
name
=
None
):
branch3x3
=
self
.
conv_bn_layer
(
x
,
192
,
1
,
name
=
name
+
'.branch3x3_1'
)
branch3x3
=
self
.
conv_bn_layer
(
branch3x3
,
320
,
3
,
stride
=
2
,
name
=
name
+
'.branch3x3_2'
)
branch7x7x3
=
self
.
conv_bn_layer
(
x
,
192
,
1
,
name
=
name
+
'.branch7x7x3_1'
)
branch7x7x3
=
self
.
conv_bn_layer
(
branch7x7x3
,
192
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7x3_2'
)
branch7x7x3
=
self
.
conv_bn_layer
(
branch7x7x3
,
192
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7x3_3'
)
branch7x7x3
=
self
.
conv_bn_layer
(
branch7x7x3
,
192
,
3
,
stride
=
2
,
name
=
name
+
'.branch7x7x3_4'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
2
,
pool_type
=
'max'
)
return
fluid
.
layers
.
concat
([
branch3x3
,
branch7x7x3
,
branch_pool
],
axis
=
1
)
def
inceptionE
(
self
,
x
,
name
=
None
):
branch1x1
=
self
.
conv_bn_layer
(
x
,
320
,
1
,
name
=
name
+
'.branch1x1'
)
branch3x3
=
self
.
conv_bn_layer
(
x
,
384
,
1
,
name
=
name
+
'.branch3x3_1'
)
branch3x3_2a
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3_2a'
)
branch3x3_2b
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3_2b'
)
branch3x3
=
fluid
.
layers
.
concat
([
branch3x3_2a
,
branch3x3_2b
],
axis
=
1
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
448
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl_3a
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3dbl_3a'
)
branch3x3dbl_3b
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3dbl_3b'
)
branch3x3dbl
=
fluid
.
layers
.
concat
([
branch3x3dbl_3a
,
branch3x3dbl_3b
],
axis
=
1
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
192
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch3x3
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
inceptionAux
(
self
,
x
,
num_classes
,
name
=
None
):
x
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
5
,
pool_stride
=
3
,
pool_type
=
'avg'
)
x
=
self
.
conv_bn_layer
(
x
,
128
,
1
,
name
=
name
+
'.conv0'
)
x
=
self
.
conv_bn_layer
(
x
,
768
,
5
,
name
=
name
+
'.conv1'
)
x
=
fluid
.
layers
.
pool2d
(
x
,
global_pooling
=
True
,
pool_type
=
'avg'
)
x
=
fluid
.
layers
.
flatten
(
x
,
axis
=
1
)
x
=
fluid
.
layers
.
fc
(
x
,
size
=
num_classes
)
return
x
def
fid_inceptionA
(
self
,
x
,
pool_features
,
name
=
None
):
""" FID block in inception v3
"""
branch1x1
=
self
.
conv_bn_layer
(
x
,
64
,
1
,
name
=
name
+
'.branch1x1'
)
branch5x5
=
self
.
conv_bn_layer
(
x
,
48
,
1
,
name
=
name
+
'.branch5x5_1'
)
branch5x5
=
self
.
conv_bn_layer
(
branch5x5
,
64
,
5
,
padding
=
2
,
name
=
name
+
'.branch5x5_2'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
64
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
96
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_3'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
exclusive
=
True
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
pool_features
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch5x5
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
fid_inceptionC
(
self
,
x
,
c7
,
name
=
None
):
""" FID block in inception v3
"""
branch1x1
=
self
.
conv_bn_layer
(
x
,
192
,
1
,
name
=
name
+
'.branch1x1'
)
branch7x7
=
self
.
conv_bn_layer
(
x
,
c7
,
1
,
name
=
name
+
'.branch7x7_1'
)
branch7x7
=
self
.
conv_bn_layer
(
branch7x7
,
c7
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7_2'
)
branch7x7
=
self
.
conv_bn_layer
(
branch7x7
,
192
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7_3'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
x
,
c7
,
1
,
name
=
name
+
'.branch7x7dbl_1'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7dbl_2'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7dbl_3'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
c7
,
(
7
,
1
),
padding
=
(
3
,
0
),
name
=
name
+
'.branch7x7dbl_4'
)
branch7x7dbl
=
self
.
conv_bn_layer
(
branch7x7dbl
,
192
,
(
1
,
7
),
padding
=
(
0
,
3
),
name
=
name
+
'.branch7x7dbl_5'
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
exclusive
=
True
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
192
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch7x7
,
branch7x7dbl
,
branch_pool
],
axis
=
1
)
def
fid_inceptionE_1
(
self
,
x
,
name
=
None
):
""" FID block in inception v3
"""
branch1x1
=
self
.
conv_bn_layer
(
x
,
320
,
1
,
name
=
name
+
'.branch1x1'
)
branch3x3
=
self
.
conv_bn_layer
(
x
,
384
,
1
,
name
=
name
+
'.branch3x3_1'
)
branch3x3_2a
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3_2a'
)
branch3x3_2b
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3_2b'
)
branch3x3
=
fluid
.
layers
.
concat
([
branch3x3_2a
,
branch3x3_2b
],
axis
=
1
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
448
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl_3a
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3dbl_3a'
)
branch3x3dbl_3b
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3dbl_3b'
)
branch3x3dbl
=
fluid
.
layers
.
concat
([
branch3x3dbl_3a
,
branch3x3dbl_3b
],
axis
=
1
)
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
exclusive
=
True
,
pool_type
=
'avg'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
192
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch3x3
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
fid_inceptionE_2
(
self
,
x
,
name
=
None
):
""" FID block in inception v3
"""
branch1x1
=
self
.
conv_bn_layer
(
x
,
320
,
1
,
name
=
name
+
'.branch1x1'
)
branch3x3
=
self
.
conv_bn_layer
(
x
,
384
,
1
,
name
=
name
+
'.branch3x3_1'
)
branch3x3_2a
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3_2a'
)
branch3x3_2b
=
self
.
conv_bn_layer
(
branch3x3
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3_2b'
)
branch3x3
=
fluid
.
layers
.
concat
([
branch3x3_2a
,
branch3x3_2b
],
axis
=
1
)
branch3x3dbl
=
self
.
conv_bn_layer
(
x
,
448
,
1
,
name
=
name
+
'.branch3x3dbl_1'
)
branch3x3dbl
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
3
,
padding
=
1
,
name
=
name
+
'.branch3x3dbl_2'
)
branch3x3dbl_3a
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
1
,
3
),
padding
=
(
0
,
1
),
name
=
name
+
'.branch3x3dbl_3a'
)
branch3x3dbl_3b
=
self
.
conv_bn_layer
(
branch3x3dbl
,
384
,
(
3
,
1
),
padding
=
(
1
,
0
),
name
=
name
+
'.branch3x3dbl_3b'
)
branch3x3dbl
=
fluid
.
layers
.
concat
([
branch3x3dbl_3a
,
branch3x3dbl_3b
],
axis
=
1
)
### same with paper
branch_pool
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
3
,
pool_stride
=
1
,
pool_padding
=
1
,
pool_type
=
'max'
)
branch_pool
=
self
.
conv_bn_layer
(
branch_pool
,
192
,
1
,
name
=
name
+
'.branch_pool'
)
return
fluid
.
layers
.
concat
([
branch1x1
,
branch3x3
,
branch3x3dbl
,
branch_pool
],
axis
=
1
)
def
conv_bn_layer
(
self
,
data
,
num_filters
,
filter_size
,
stride
=
1
,
padding
=
0
,
groups
=
1
,
act
=
'relu'
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weight"
),
bias_attr
=
False
,
name
=
name
)
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
epsilon
=
0.001
,
name
=
name
+
'.bn'
,
param_attr
=
ParamAttr
(
name
=
name
+
".bn.weight"
),
bias_attr
=
ParamAttr
(
name
=
name
+
".bn.bias"
),
moving_mean_name
=
name
+
'.bn.running_mean'
,
moving_variance_name
=
name
+
'.bn.running_var'
)
PaddleCV/gan/network/base_network.py
浏览文件 @
1370b9b6
...
...
@@ -42,7 +42,7 @@ def norm_layer(input,
if
norm_type
==
'batch_norm'
:
if
affine
==
True
:
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_w'
,
initializer
=
fluid
.
initializer
.
Constant
(
1.0
))
name
=
name
+
'_w'
,
initializer
=
fluid
.
initializer
.
Normal
(
loc
=
1.0
,
scale
=
0.02
))
bias_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_b'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.0
))
...
...
PaddleCV/gan/requirements.txt
浏览文件 @
1370b9b6
numpy >= 1.15.0
opencv-python
PaddleNLP/machine_translation/transformer/predict.py
浏览文件 @
1370b9b6
...
...
@@ -92,9 +92,14 @@ def do_predict(args):
input_field_names
=
desc
.
encoder_data_input_fields
+
desc
.
fast_decoder_data_input_fields
input_descs
=
desc
.
get_input_descs
(
args
.
args
)
input_slots
=
[{
"name"
:
name
,
"shape"
:
input_descs
[
name
][
0
],
"dtype"
:
input_descs
[
name
][
1
]
"name"
:
name
,
"shape"
:
input_descs
[
name
][
0
],
"dtype"
:
input_descs
[
name
][
1
],
"lod_level"
:
input_descs
[
name
][
2
]
if
len
(
input_descs
[
name
])
>
2
else
0
}
for
name
in
input_field_names
]
input_field
=
InputField
(
input_slots
)
...
...
PaddleRec/README.md
浏览文件 @
1370b9b6
...
...
@@ -4,20 +4,33 @@ PaddleRec
个性化推荐
-------
推荐系统在当前的互联网服务中正在发挥越来越大的作用,目前大部分电子商务系统、社交网络,广告推荐,搜索引擎,都不同程度的使用了各种形式的个性化推荐技术,帮助用户快速找到他们想要的信息。
在工业可用的推荐系统中,推荐策略一般会被划分为多个模块串联执行。以新闻推荐系统为例,存在多个可以使用深度学习技术的环节,例如新闻的自动化标注,个性化新闻召回,个性化匹配与排序等。PaddlePaddle对推荐算法的训练提供了完整的支持,并提供了多种模型配置供用户选择。
| 模型 | 应用场景 | 简介 |
| :------------------------------------------------: | :---------------------------------------------------: | :----------------------------------------------------------: |
|
[
GRU4Rec
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/gru4rec
)
| Session-based 推荐, 图网络推荐 | 首次将RNN(GRU)运用于session-based推荐,核心思想是在一个session中,用户点击一系列item的行为看做一个序列,用来训练RNN模型 |
|
[
TagSpace
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/tagspace
)
| 标签推荐 | Tagspace模型学习文本及标签的embedding表示,应用于工业级的标签推荐,具体应用场景有feed新闻标签推荐。|
|
[
SequenceSemanticRetrieval
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ssr
)
| 召回 | 解决了 GRU4Rec 模型无法预测训练数据集中不存在的项目,比如新闻推荐的问题。它由两个部分组成:一个是匹配模型部分,另一个是检索部分 |
|
[
Word2Vec
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/word2vec
)
| 词向量 | 训练得到词的向量表示、广泛应用于NLP、推荐等任务场景。 |
|
[
Multiview-Simnet
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/multiview_simnet
)
| 排序 | 多视角Simnet模型是可以融合用户以及推荐项目的多个视角的特征并进行个性化匹配学习的一体化模型。这类模型在很多工业化的场景中都会被使用到,比如百度的Feed产品中 |
|
[
GraphNeuralNetwork
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/gnn
)
| 召回 | SR-GNN,全称为Session-based Recommendations with Graph Neural Network(GNN)。使用GNN进行会话序列建模。 |
|
[
DeepInterestNetwork
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/din
)
| 排序 | DIN,全称为Deep Interest Network。特点为对历史序列建模的过程中结合了预估目标的信息。 |
|
[
DeepFM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/deepfm
)
| 推荐系统 | DeepFM,全称Factorization-Machine based Neural Network。经典的CTR推荐算法,网络由DNN和FM两部分组成。 |
|
[
XDeepFM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/xdeepfm
)
| 推荐系统 | xDeepFM,全称extreme Factorization Machine。对DeepFM和DCN的改进,提出CIN(Compressed Interaction Network),使用vector-wise等级的显示特征交叉。 |
|
[
DCN
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/dcn
)
| 推荐系统 | 全称Deep & Cross Network。提出一种新的交叉网络(cross network),在每个层上明确地应用特征交叉。 |
推荐系统在当前的互联网服务中正在发挥越来越大的作用,目前大部分电子商务系统、社交网络,广告推荐,搜索引擎,信息流,都不同程度的使用了各种形式的个性化推荐技术,帮助用户快速找到他们想要的信息。
在工业可用的推荐系统中,推荐策略一般会被划分为多个模块串联执行。以新闻推荐系统为例,存在多个可以使用深度学习技术的环节,例如新闻的内容理解--标签标注,个性化新闻召回,个性化匹配与排序,融合等。PaddlePaddle对推荐算法的训练提供了完整的支持,并提供了多种模型配置供用户选择。
PaddleRec全景图:
![
paddlerec
](
./img/paddlerec.png
)
| 任务场景 | 模型 | 简介 |
| :------------------------:| :------------------------------------------------: | :----------------------------------------------------------: |
| 内容理解 |
[
TagSpace
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/tagspace
)
|Tagspace模型学习文本及标签的embedding表示,应用于工业级的标签推荐,具体应用场景有feed新闻标签推荐。|
| 内容理解 |
[
TextClassification
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/text_classification
)
|文本分类,具体应用场景有feed新闻标签分类。|
| 召回 |
[
Word2Vec
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/word2vec
)
|训练得到词的向量表示、广泛应用于NLP、推荐等任务场景。 |
| 召回 |
[
GNN
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/gnn
)
|SR-GNN,全称为Session-based Recommendations with Graph Neural Network(GNN)。使用GNN进行会话序列建模。 |
| 召回 |
[
TDM-Variant
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/tdm
)
|全称为Tree-based Deep Model for Recommender Systems。层次化建模及检索 |
| 召回 |
[
GRU4Rec
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/gru4rec
)
|Session-based 推荐, 首次将RNN(GRU)运用于session-based推荐,核心思想是在一个session中,用户点击一系列item的行为看做一个序列,用来训练RNN模型 |
| 召回 |
[
SSR
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ssr
)
|全称为SequenceSemanticRetrieval, 解决了 GRU4Rec 模型无法预测训练数据集中不存在的项目,比如新闻推荐的问题。它由两个部分组成:一个是匹配模型部分,另一个是检索部分 |
| 匹配 |
[
Multiview-Simnet
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/multiview_simnet
)
|多视角Simnet模型是可以融合用户以及推荐项目的多个视角的特征并进行个性化匹配学习的一体化模型。这类模型在很多工业化的场景中都会被使用到,比如百度的Feed产品中 |
| 匹配 |
[
DSSM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/dssm
)
|全称:Deep Structured Semantic Model深度语义匹配模型 |
| 排序 |
[
DNN
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/dnn
)
|经典的CTR预估算法。|
| 排序 |
[
Wide_Deep
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/wide_deep
)
|经典的CTR预估算法。|
| 排序 |
[
DeepFM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/deepfm
)
|DeepFM,全称Factorization-Machine based Neural Network。经典的CTR预估算法,网络由DNN和FM两部分组成。 |
| 排序 |
[
XDeepFM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/xdeepfm
)
|xDeepFM,全称extreme Factorization Machine。对DeepFM和DCN的改进,提出CIN(Compressed Interaction Network),使用vector-wise等级的显示特征交叉。 |
| 排序 |
[
DCN
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/dcn
)
|全称Deep & Cross Network。提出一种新的交叉网络(cross network),在每个层上明确地应用特征交叉。 |
| 排序 |
[
DIN
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr/din
)
|DIN,全称为Deep Interest Network。特点为对历史序列建模的过程中结合了预估目标的信息。 |
| 融合- 多任务 |
[
ESMM
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/multi_task/esmm
)
|ESMM,全称为Entire Space Multi-task Model。提出一种新的CVR预估模型。 |
| 融合- 多任务 |
[
Share_bottom
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/multi_task/share_bottom
)
|多任务学习的基本框架,其特点是对于不同的任务,底层的参数和网络结构是共享的。 |
| 融合- 多任务 |
[
MMoE
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/multi_task/mmoe
)
|MMOE, 全称为Multi-grate Mixture-of-Experts,可以刻画任务相关性。 |
PaddleRec/img/paddlerec.png
0 → 100644
浏览文件 @
1370b9b6
149.2 KB
PaddleRec/text_classifi
ler
/README.md
→
PaddleRec/text_classifi
cation
/README.md
浏览文件 @
1370b9b6
文件已移动
PaddleRec/text_classifi
ler
/net.py
→
PaddleRec/text_classifi
cation
/net.py
浏览文件 @
1370b9b6
文件已移动
PaddleRec/text_classifi
ler
/train.py
→
PaddleRec/text_classifi
cation
/train.py
浏览文件 @
1370b9b6
文件已移动
PaddleRec/word2vec/README.md
浏览文件 @
1370b9b6
...
...
@@ -20,7 +20,7 @@
## 介绍
本例实现了skip-gram模式的word2vector模型。
**目前模型库下模型均要求使用PaddlePaddle 1.6及以上版本或适当的develop版本。**
**目前模型库下模型均要求使用PaddlePaddle 1.6及以上版本或适当的develop版本。
若要使用shuffle_batch功能,则需使用PaddlePaddle 1.7及以上版本。
**
同时推荐用户参考
[
IPython Notebook demo
](
https://aistudio.baidu.com/aistudio/projectDetail/124377
)
...
...
@@ -102,6 +102,7 @@ OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_t
```
bash
sh cluster_train.sh
```
若需要开启shuffle_batch功能,需在命令中加入
`--with_shuffle_batch`
。单机模拟分布式多机训练,需更改
`cluster_train.sh`
文件,在各个节点的启动命令中加入
`--with_shuffle_batch`
。
## 预测
测试集下载命令如下
...
...
PaddleRec/word2vec/cluster_train.py
浏览文件 @
1370b9b6
...
...
@@ -10,7 +10,7 @@ import paddle
import
paddle.fluid
as
fluid
import
six
import
reader
from
net
import
skip_gram_word2vec
from
net
import
skip_gram_word2vec
,
skip_gram_word2vec_shuffle_batch
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
logger
=
logging
.
getLogger
(
"fluid"
)
...
...
@@ -100,6 +100,12 @@ def parse_args():
type
=
int
,
default
=
1
,
help
=
'The num of trianers, (default: 1)'
)
parser
.
add_argument
(
'--with_shuffle_batch'
,
action
=
'store_true'
,
required
=
False
,
default
=
False
,
help
=
'negative samples come from shuffle_batch op or not , (default: False)'
)
return
parser
.
parse_args
()
...
...
@@ -134,11 +140,7 @@ def convert_python_to_tensor(weight, batch_size, sample_reader):
return
__reader__
def
train_loop
(
args
,
train_program
,
reader
,
data_loader
,
loss
,
trainer_id
,
weight
):
data_loader
.
set_batch_generator
(
convert_python_to_tensor
(
weight
,
args
.
batch_size
,
reader
.
train
()))
def
train_loop
(
args
,
train_program
,
data_loader
,
loss
,
trainer_id
):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -207,14 +209,26 @@ def train(args):
filelist
,
0
,
1
)
logger
.
info
(
"dict_size: {}"
.
format
(
word2vec_reader
.
dict_size
))
np_power
=
np
.
power
(
np
.
array
(
word2vec_reader
.
id_frequencys
),
0.75
)
id_frequencys_pow
=
np_power
/
np_power
.
sum
()
if
args
.
with_shuffle_batch
:
loss
,
data_loader
=
skip_gram_word2vec_shuffle_batch
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
data_loader
.
set_sample_generator
(
word2vec_reader
.
train
(),
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
else
:
np_power
=
np
.
power
(
np
.
array
(
word2vec_reader
.
id_frequencys
),
0.75
)
id_frequencys_pow
=
np_power
/
np_power
.
sum
()
loss
,
data_loader
=
skip_gram_word2vec
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
loss
,
data_loader
=
skip_gram_word2vec
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
data_loader
.
set_batch_generator
(
convert_python_to_tensor
(
id_frequencys_pow
,
args
.
batch_size
,
word2vec_reader
.
train
()))
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
fluid
.
layers
.
exponential_decay
(
...
...
@@ -241,8 +255,8 @@ def train(args):
elif
args
.
role
==
"trainer"
:
print
(
"run trainer"
)
train_loop
(
args
,
t
.
get_trainer_program
(),
word2vec_reader
,
data_loader
,
loss
,
args
.
trainer_id
,
id_frequencys_pow
)
t
.
get_trainer_program
(),
data_loader
,
loss
,
args
.
trainer_id
)
if
__name__
==
'__main__'
:
...
...
PaddleRec/word2vec/net.py
浏览文件 @
1370b9b6
...
...
@@ -20,6 +20,89 @@ import numpy as np
import
paddle.fluid
as
fluid
def
skip_gram_word2vec_shuffle_batch
(
dict_size
,
embedding_size
,
is_sparse
=
False
,
neg_num
=
5
):
words
=
[]
input_word
=
fluid
.
data
(
name
=
"input_word"
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
true_word
=
fluid
.
data
(
name
=
'true_label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
words
.
append
(
input_word
)
words
.
append
(
true_word
)
data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
64
,
feed_list
=
words
,
iterable
=
False
)
init_width
=
0.5
/
embedding_size
input_emb
=
fluid
.
embedding
(
input
=
words
[
0
],
is_sparse
=
is_sparse
,
size
=
[
dict_size
,
embedding_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
'emb'
,
initializer
=
fluid
.
initializer
.
Uniform
(
-
init_width
,
init_width
)))
true_emb_w
=
fluid
.
embedding
(
input
=
words
[
1
],
is_sparse
=
is_sparse
,
size
=
[
dict_size
,
embedding_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
'emb_w'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.0
)))
true_emb_b
=
fluid
.
embedding
(
input
=
words
[
1
],
is_sparse
=
is_sparse
,
size
=
[
dict_size
,
1
],
param_attr
=
fluid
.
ParamAttr
(
name
=
'emb_b'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.0
)))
input_emb
=
fluid
.
layers
.
squeeze
(
input
=
input_emb
,
axes
=
[
1
])
true_emb_w
=
fluid
.
layers
.
squeeze
(
input
=
true_emb_w
,
axes
=
[
1
])
true_emb_b
=
fluid
.
layers
.
squeeze
(
input
=
true_emb_b
,
axes
=
[
1
])
# add shuffle_batch after embedding.
neg_emb_w_list
=
[]
for
i
in
range
(
neg_num
):
neg_emb_w_list
.
append
(
fluid
.
contrib
.
layers
.
shuffle_batch
(
true_emb_w
))
# shuffle true_word
neg_emb_w
=
fluid
.
layers
.
concat
(
neg_emb_w_list
,
axis
=
0
)
neg_emb_w_re
=
fluid
.
layers
.
reshape
(
neg_emb_w
,
shape
=
[
-
1
,
neg_num
,
embedding_size
])
neg_emb_b_list
=
[]
for
i
in
range
(
neg_num
):
neg_emb_b_list
.
append
(
fluid
.
contrib
.
layers
.
shuffle_batch
(
true_emb_b
))
# shuffle true_word
neg_emb_b
=
fluid
.
layers
.
concat
(
neg_emb_b_list
,
axis
=
0
)
neg_emb_b_vec
=
fluid
.
layers
.
reshape
(
neg_emb_b
,
shape
=
[
-
1
,
neg_num
])
true_logits
=
fluid
.
layers
.
elementwise_add
(
fluid
.
layers
.
reduce_sum
(
fluid
.
layers
.
elementwise_mul
(
input_emb
,
true_emb_w
),
dim
=
1
,
keep_dim
=
True
),
true_emb_b
)
input_emb_re
=
fluid
.
layers
.
reshape
(
input_emb
,
shape
=
[
-
1
,
1
,
embedding_size
])
neg_matmul
=
fluid
.
layers
.
matmul
(
input_emb_re
,
neg_emb_w_re
,
transpose_y
=
True
)
neg_matmul_re
=
fluid
.
layers
.
reshape
(
neg_matmul
,
shape
=
[
-
1
,
neg_num
])
neg_logits
=
fluid
.
layers
.
elementwise_add
(
neg_matmul_re
,
neg_emb_b_vec
)
#nce loss
label_ones
=
fluid
.
layers
.
fill_constant_batch_size_like
(
true_logits
,
shape
=
[
-
1
,
1
],
value
=
1.0
,
dtype
=
'float32'
)
label_zeros
=
fluid
.
layers
.
fill_constant_batch_size_like
(
true_logits
,
shape
=
[
-
1
,
neg_num
],
value
=
0.0
,
dtype
=
'float32'
)
true_xent
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
true_logits
,
label_ones
)
neg_xent
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
neg_logits
,
label_zeros
)
cost
=
fluid
.
layers
.
elementwise_add
(
fluid
.
layers
.
reduce_sum
(
true_xent
,
dim
=
1
),
fluid
.
layers
.
reduce_sum
(
neg_xent
,
dim
=
1
))
avg_cost
=
fluid
.
layers
.
reduce_mean
(
cost
)
return
avg_cost
,
data_loader
def
skip_gram_word2vec
(
dict_size
,
embedding_size
,
is_sparse
=
False
,
neg_num
=
5
):
words
=
[]
...
...
PaddleRec/word2vec/train.py
浏览文件 @
1370b9b6
...
...
@@ -10,7 +10,7 @@ import paddle
import
paddle.fluid
as
fluid
import
six
import
reader
from
net
import
skip_gram_word2vec
from
net
import
skip_gram_word2vec
,
skip_gram_word2vec_shuffle_batch
import
utils
import
sys
...
...
@@ -84,6 +84,12 @@ def parse_args():
required
=
False
,
default
=
False
,
help
=
'print speed or not , (default: False)'
)
parser
.
add_argument
(
'--with_shuffle_batch'
,
action
=
'store_true'
,
required
=
False
,
default
=
False
,
help
=
'negative samples come from shuffle_batch op or not , (default: False)'
)
parser
.
add_argument
(
'--enable_ce'
,
action
=
'store_true'
,
help
=
'If set, run the task with continuous evaluation logs.'
)
...
...
@@ -121,10 +127,7 @@ def convert_python_to_tensor(weight, batch_size, sample_reader):
return
__reader__
def
train_loop
(
args
,
train_program
,
reader
,
data_loader
,
loss
,
trainer_id
,
weight
):
data_loader
.
set_batch_generator
(
convert_python_to_tensor
(
weight
,
args
.
batch_size
,
reader
.
train
()))
def
train_loop
(
args
,
train_program
,
data_loader
,
loss
,
trainer_id
):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -211,14 +214,26 @@ def train(args):
filelist
,
0
,
1
)
logger
.
info
(
"dict_size: {}"
.
format
(
word2vec_reader
.
dict_size
))
np_power
=
np
.
power
(
np
.
array
(
word2vec_reader
.
id_frequencys
),
0.75
)
id_frequencys_pow
=
np_power
/
np_power
.
sum
()
loss
,
data_loader
=
skip_gram_word2vec
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
if
args
.
with_shuffle_batch
:
loss
,
data_loader
=
skip_gram_word2vec_shuffle_batch
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
data_loader
.
set_sample_generator
(
word2vec_reader
.
train
(),
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
else
:
np_power
=
np
.
power
(
np
.
array
(
word2vec_reader
.
id_frequencys
),
0.75
)
id_frequencys_pow
=
np_power
/
np_power
.
sum
()
loss
,
data_loader
=
skip_gram_word2vec
(
word2vec_reader
.
dict_size
,
args
.
embedding_size
,
is_sparse
=
args
.
is_sparse
,
neg_num
=
args
.
nce_num
)
data_loader
.
set_batch_generator
(
convert_python_to_tensor
(
id_frequencys_pow
,
args
.
batch_size
,
word2vec_reader
.
train
()))
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
fluid
.
layers
.
exponential_decay
(
...
...
@@ -232,11 +247,10 @@ def train(args):
# do local training
logger
.
info
(
"run local training"
)
main_program
=
fluid
.
default_main_program
()
train_loop
(
args
,
main_program
,
word2vec_reader
,
data_loader
,
loss
,
0
,
id_frequencys_pow
)
train_loop
(
args
,
main_program
,
data_loader
,
loss
,
0
)
if
__name__
==
'__main__'
:
utils
.
check_version
()
args
=
parse_args
()
utils
.
check_version
(
args
.
with_shuffle_batch
)
train
(
args
)
PaddleRec/word2vec/utils.py
浏览文件 @
1370b9b6
...
...
@@ -27,7 +27,7 @@ def prepare_data(file_dir, dict_path, batch_size):
return
vocab_size
,
reader
,
i2w
def
check_version
():
def
check_version
(
with_shuffle_batch
=
False
):
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
...
...
@@ -37,7 +37,10 @@ def check_version():
"Please make sure the version is good with your code."
\
try
:
fluid
.
require_version
(
'1.6.0'
)
if
with_shuffle_batch
:
fluid
.
require_version
(
'1.7.0'
)
else
:
fluid
.
require_version
(
'1.6.0'
)
except
Exception
as
e
:
logger
.
error
(
err
)
sys
.
exit
(
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录