Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
aaed4745
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aaed4745
编写于
11月 03, 2017
作者:
W
wangmeng28
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/develop' into fix-4388
上级
e65960ed
8b30e2ab
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
352 addition
and
16 deletion
+352
-16
benchmark/IntelOptimizedPaddle.md
benchmark/IntelOptimizedPaddle.md
+1
-1
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+3
-1
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+2
-1
paddle/scripts/travis/build_doc.sh
paddle/scripts/travis/build_doc.sh
+2
-2
python/paddle/v2/framework/initializer.py
python/paddle/v2/framework/initializer.py
+130
-1
python/paddle/v2/framework/layers.py
python/paddle/v2/framework/layers.py
+2
-6
python/paddle/v2/framework/nets.py
python/paddle/v2/framework/nets.py
+2
-1
python/paddle/v2/framework/tests/test_evaluator.py
python/paddle/v2/framework/tests/test_evaluator.py
+1
-0
python/paddle/v2/framework/tests/test_initializer.py
python/paddle/v2/framework/tests/test_initializer.py
+107
-0
python/paddle/v2/framework/tests/test_recommender_system.py
python/paddle/v2/framework/tests/test_recommender_system.py
+3
-3
python/paddle/v2/framework/tests/test_understand_sentiment_conv.py
...ddle/v2/framework/tests/test_understand_sentiment_conv.py
+99
-0
未找到文件。
benchmark/IntelOptimizedPaddle.md
浏览文件 @
aaed4745
...
...
@@ -23,7 +23,7 @@ On each machine, we will test and compare the performance of training on single
## Benchmark Model
### Server
Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148
M
CPU @ 2.40GHz
Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Input image size - 3
* 224 *
224, Time: images/second
...
...
paddle/operators/lookup_table_op.h
浏览文件 @
aaed4745
...
...
@@ -90,11 +90,13 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto
*
d_output_data
=
d_output
->
data
<
T
>
();
auto
*
d_table_data
=
d_table
->
mutable_data
<
T
>
(
context
.
GetPlace
());
memset
(
d_table_data
,
0
,
d_table
->
numel
()
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids
->
numel
();
++
i
)
{
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
);
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
d_table_data
[
ids_data
[
i
]
*
D
+
j
]
=
d_output_data
[
i
*
D
+
j
];
d_table_data
[
ids_data
[
i
]
*
D
+
j
]
+
=
d_output_data
[
i
*
D
+
j
];
}
}
}
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
aaed4745
...
...
@@ -42,7 +42,8 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
string
>
(
"pooltype"
,
"(int, default AVERAGE) the pooling pooltype of SequencePoolOp."
)
.
SetDefault
(
"AVERAGE"
);
.
SetDefault
(
"AVERAGE"
)
.
InEnum
({
"AVERAGE"
,
"SUM"
,
"SQRT"
,
"LAST"
,
"FIRST"
,
"MAX"
});
AddComment
(
R"DOC(
SequencePoolOp pools features of all time-steps of each instance.
...
...
paddle/scripts/travis/build_doc.sh
浏览文件 @
aaed4745
...
...
@@ -53,8 +53,8 @@ function deploy_docs() {
set
+e
rm
-rf
${
DIR
}
/doc
${
DIR
}
/doc_cn
set
-e
mv
../doc/cn/html
${
DIR
}
/doc_cn
mv
../doc/en/html
${
DIR
}
/doc
cp
-r
../doc/cn/html
${
DIR
}
/doc_cn
cp
-r
../doc/en/html
${
DIR
}
/doc
git add
.
}
...
...
python/paddle/v2/framework/initializer.py
浏览文件 @
aaed4745
import
paddle.v2.framework.framework
as
framework
import
numpy
as
np
__all__
=
[
'ConstantInitializer'
,
'UniformInitializer'
]
__all__
=
[
'ConstantInitializer'
,
'UniformInitializer'
,
'NormalInitializer'
,
'XavierInitializer'
]
class
Initializer
(
object
):
...
...
@@ -20,6 +24,41 @@ class Initializer(object):
"""
raise
NotImplementedError
()
def
_compute_fans
(
self
,
var
):
"""Compute the fan_in and the fan_out for layers
This method computes the fan_in and the fan_out
for neural network layers, if not specified. It is
not possible to perfectly estimate fan_in and fan_out.
This method will estimate it correctly for matrix multiply and
convolutions.
Args:
var: variable for which fan_in and fan_out have to be computed
Returns:
tuple of two integers (fan_in, fan_out)
"""
shape
=
var
.
shape
if
not
shape
or
len
(
shape
)
==
0
:
fan_in
=
fan_out
=
1
elif
len
(
shape
)
==
1
:
fan_in
=
fan_out
=
shape
[
0
]
elif
len
(
shape
)
==
2
:
# This is the case for simple matrix multiply
fan_in
=
shape
[
0
]
fan_out
=
shape
[
1
]
else
:
# Assume this to be a convolutional kernel
# In PaddlePaddle, the shape of the kernel is like:
# [num_filters, num_filter_channels, ...] where the remaining
# dimensions are the filter_size
receptive_field_size
=
np
.
prod
(
shape
[
2
:])
fan_in
=
shape
[
1
]
*
receptive_field_size
fan_out
=
shape
[
0
]
*
receptive_field_size
return
(
fan_in
,
fan_out
)
class
ConstantInitializer
(
Initializer
):
"""Implements the constant initializer
...
...
@@ -156,3 +195,93 @@ class NormalInitializer(Initializer):
})
var
.
op
=
op
return
op
class
XavierInitializer
(
Initializer
):
"""Implements the Xavier initializer
This class implements the Xavier weight initializer from the paper
Understanding the difficulty of training deep feedforward neural
networks[1] by Xavier Glorot and Yoshua Bengio.
This initializer is designed to keep the scale of the gradients
approximately same in all the layers. In case of Uniform distribution,
the range is [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
In case of Normal distribution, the mean is 0 and the standard deviation
is sqrt(2/ (fan_in + fan_out)).
References:
[1] Understanding the difficulty of training deep feedforward neural
networks. International conference on artificial intelligence and
statistics.
(http://proceedings.mlr.press/v9/glorot10a.html)
"""
def
__init__
(
self
,
uniform
=
True
,
fan_in
=
None
,
fan_out
=
None
,
seed
=
0
):
"""Constructor for XavierInitializer
Args:
uniform: whether to use uniform or normal distribution
fan_in: fan_in for Xavier initialization. If None, it is
inferred from the variable.
fan_out: fan_out for Xavier initialization. If None, it is
inferred from the variable.
seed: random seed
Note: It is recommended to set fan_in and fan_out to None for
most cases.
"""
assert
uniform
is
not
None
assert
seed
is
not
None
super
(
XavierInitializer
,
self
).
__init__
()
self
.
_uniform
=
uniform
self
.
_fan_in
=
fan_in
self
.
_fan_out
=
fan_out
self
.
_seed
=
seed
def
__call__
(
self
,
var
,
block
):
"""Add xavier initialization ops for a variable
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
Returns:
the initialization op
"""
assert
isinstance
(
var
,
framework
.
Variable
)
assert
isinstance
(
block
,
framework
.
Block
)
f_in
,
f_out
=
self
.
_compute_fans
(
var
)
# If fan_in and fan_out are passed, use them
fan_in
=
f_in
if
self
.
_fan_in
is
None
else
self
.
_fan_in
fan_out
=
f_out
if
self
.
_fan_out
is
None
else
self
.
_fan_out
if
self
.
_uniform
:
limit
=
np
.
sqrt
(
6.0
/
float
(
fan_in
+
fan_out
))
op
=
block
.
prepend_op
(
type
=
"uniform_random"
,
outputs
=
{
"Out"
:
var
},
attrs
=
{
"shape"
:
var
.
shape
,
"data_type"
:
int
(
var
.
data_type
),
"min"
:
-
limit
,
"max"
:
limit
,
"seed"
:
self
.
_seed
})
else
:
std
=
np
.
sqrt
(
2.0
/
float
(
fan_in
+
fan_out
))
op
=
block
.
prepend_op
(
type
=
"gaussian_random"
,
outputs
=
{
"Out"
:
var
},
attrs
=
{
"shape"
:
var
.
shape
,
"data_type"
:
int
(
var
.
data_type
),
"mean"
:
0.0
,
"std"
:
std
,
"seed"
:
self
.
_seed
})
var
.
op
=
op
return
op
python/paddle/v2/framework/layers.py
浏览文件 @
aaed4745
...
...
@@ -278,6 +278,7 @@ def sequence_conv(input,
num_filters
,
filter_size
=
3
,
filter_stride
=
1
,
act
=
None
,
padding
=
None
,
bias_attr
=
None
,
param_attr
=
None
,
...
...
@@ -304,7 +305,7 @@ def sequence_conv(input,
outputs
=
{
"Out"
:
pre_bias
},
attrs
=
{
'contextStride'
:
filter_stride
,
'contextStart'
:
0
,
'contextStart'
:
-
int
(
filter_size
/
2
)
,
'contextLength'
:
filter_size
})
pre_act
=
helper
.
append_bias_op
(
pre_bias
)
...
...
@@ -364,11 +365,6 @@ def conv2d(input,
def
sequence_pool
(
input
,
pool_type
,
**
kwargs
):
ENUM_POOL_TYPE
=
set
([
"MAX"
,
"AVG"
,
"SQRT"
,
"LAST"
,
"FIRST"
])
if
pool_type
.
upper
()
not
in
ENUM_POOL_TYPE
:
raise
ValueError
(
"Unknown pool_type: '%s'. It can only be %s."
,
str
(
pool_type
),
" "
.
join
(
ENUM_POOL_TYPE
))
helper
=
LayerHelper
(
'sequence_pool'
,
input
=
input
,
**
kwargs
)
dtype
=
helper
.
input_dtype
()
pool_out
=
helper
.
create_tmp_variable
(
dtype
)
...
...
python/paddle/v2/framework/nets.py
浏览文件 @
aaed4745
...
...
@@ -109,6 +109,7 @@ def sequence_conv_pool(input,
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
act
=
act
,
program
=
program
,
init_program
=
init_program
)
...
...
python/paddle/v2/framework/tests/test_evaluator.py
浏览文件 @
aaed4745
...
...
@@ -60,4 +60,5 @@ class TestEvaluator(unittest.TestCase):
if
__name__
==
'__main__'
:
exit
(
0
)
unittest
.
main
()
python/paddle/v2/framework/tests/test_initializer.py
浏览文件 @
aaed4745
import
numpy
as
np
import
unittest
import
paddle.v2.framework.framework
as
framework
...
...
@@ -116,5 +117,111 @@ class TestNormalInitializer(unittest.TestCase):
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
123
)
class
TestXavierInitializer
(
unittest
.
TestCase
):
def
test_uniform_xavier_initializer
(
self
):
"""Test Xavier initializer with uniform distribution on
for matrix multiply.
"""
program
=
framework
.
Program
()
block
=
program
.
global_block
()
param
=
block
.
create_parameter
(
dtype
=
"float32"
,
shape
=
[
5
,
10
],
lod_level
=
0
,
name
=
"param"
,
initializer
=
initializer
.
XavierInitializer
())
self
.
assertEqual
(
len
(
block
.
ops
),
1
)
init_op
=
block
.
ops
[
0
]
self
.
assertEqual
(
init_op
.
type
,
'uniform_random'
)
limit
=
np
.
sqrt
(
6.0
/
(
param
.
shape
[
0
]
+
param
.
shape
[
1
]))
self
.
assertAlmostEqual
(
init_op
.
attr
(
'min'
),
-
limit
,
delta
=
DELTA
)
self
.
assertAlmostEqual
(
init_op
.
attr
(
'max'
),
limit
,
delta
=
DELTA
)
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
0
)
def
test_uniform_xavier_initializer_conv
(
self
):
"""Test Xavier initializer with uniform distribution on
for convolutions.
"""
program
=
framework
.
Program
()
block
=
program
.
global_block
()
param
=
block
.
create_parameter
(
dtype
=
"float32"
,
shape
=
[
5
,
10
,
15
,
20
],
lod_level
=
0
,
name
=
"param"
,
initializer
=
initializer
.
XavierInitializer
())
self
.
assertEqual
(
len
(
block
.
ops
),
1
)
init_op
=
block
.
ops
[
0
]
self
.
assertEqual
(
init_op
.
type
,
'uniform_random'
)
receptive_field_size
=
float
(
15
*
20
)
limit
=
np
.
sqrt
(
6.0
/
(
(
param
.
shape
[
0
]
+
param
.
shape
[
1
])
*
receptive_field_size
))
self
.
assertAlmostEqual
(
init_op
.
attr
(
'min'
),
-
limit
,
delta
=
DELTA
)
self
.
assertAlmostEqual
(
init_op
.
attr
(
'max'
),
limit
,
delta
=
DELTA
)
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
0
)
def
test_normal_xavier_initializer
(
self
):
"""Test Xavier initializer with normal distribution on
for matrix multiply.
"""
program
=
framework
.
Program
()
block
=
program
.
global_block
()
param
=
block
.
create_parameter
(
dtype
=
"float32"
,
shape
=
[
5
,
10
],
lod_level
=
0
,
name
=
"param"
,
initializer
=
initializer
.
XavierInitializer
(
uniform
=
False
))
self
.
assertEqual
(
len
(
block
.
ops
),
1
)
init_op
=
block
.
ops
[
0
]
self
.
assertEqual
(
init_op
.
type
,
'gaussian_random'
)
std
=
np
.
sqrt
(
2.0
/
(
param
.
shape
[
0
]
+
param
.
shape
[
1
]))
self
.
assertAlmostEqual
(
init_op
.
attr
(
'mean'
),
0.0
,
delta
=
DELTA
)
self
.
assertAlmostEqual
(
init_op
.
attr
(
'std'
),
std
,
delta
=
DELTA
)
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
0
)
def
test_normal_xavier_initializer_conv
(
self
):
"""Test Xavier initializer with normal distribution on
for convolutions.
"""
program
=
framework
.
Program
()
block
=
program
.
global_block
()
param
=
block
.
create_parameter
(
dtype
=
"float32"
,
shape
=
[
5
,
10
,
15
,
20
],
lod_level
=
0
,
name
=
"param"
,
initializer
=
initializer
.
XavierInitializer
(
uniform
=
False
))
self
.
assertEqual
(
len
(
block
.
ops
),
1
)
init_op
=
block
.
ops
[
0
]
self
.
assertEqual
(
init_op
.
type
,
'gaussian_random'
)
receptive_field_size
=
float
(
15
*
20
)
std
=
np
.
sqrt
(
2.0
/
(
(
param
.
shape
[
0
]
+
param
.
shape
[
1
])
*
receptive_field_size
))
self
.
assertAlmostEqual
(
init_op
.
attr
(
'mean'
),
0.0
,
delta
=
DELTA
)
self
.
assertAlmostEqual
(
init_op
.
attr
(
'std'
),
std
,
delta
=
DELTA
)
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
0
)
def
test_xavier_initializer_supplied_arguments
(
self
):
"""Test the Xavier initializer with supplied arguments
"""
program
=
framework
.
Program
()
block
=
program
.
global_block
()
block
.
create_parameter
(
dtype
=
"float32"
,
shape
=
[
5
,
10
],
lod_level
=
0
,
name
=
"param"
,
initializer
=
initializer
.
XavierInitializer
(
fan_in
=
12
,
fan_out
=
23
,
seed
=
134
))
self
.
assertEqual
(
len
(
block
.
ops
),
1
)
init_op
=
block
.
ops
[
0
]
self
.
assertEqual
(
init_op
.
type
,
'uniform_random'
)
limit
=
np
.
sqrt
(
6.0
/
(
12
+
23
))
self
.
assertAlmostEqual
(
init_op
.
attr
(
'min'
),
-
limit
,
delta
=
DELTA
)
self
.
assertAlmostEqual
(
init_op
.
attr
(
'max'
),
limit
,
delta
=
DELTA
)
self
.
assertEqual
(
init_op
.
attr
(
'seed'
),
134
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_recommender_system.py
浏览文件 @
aaed4745
...
...
@@ -243,7 +243,7 @@ def model():
def
main
():
cost
=
model
()
sgd_optimizer
=
optimizer
.
SGDOptimizer
(
learning_rate
=
0.2
)
opts
=
sgd_optimizer
.
minimize
(
cost
)
opts
=
sgd_optimizer
.
minimize
(
cost
,
init_program
=
init_program
)
block
=
program
.
block
(
0
)
if
use_gpu
:
...
...
@@ -305,8 +305,8 @@ def main():
feed
=
func_feed
(
feeding
,
data
),
fetch_list
=
[
cost
])
out
=
np
.
array
(
outs
[
0
])
if
out
[
0
]
<
5
.0
:
# if avg cost less than
10
.0, we think our code is good.
if
out
[
0
]
<
6
.0
:
# if avg cost less than
6
.0, we think our code is good.
exit
(
0
)
...
...
python/paddle/v2/framework/tests/test_understand_sentiment_conv.py
0 → 100644
浏览文件 @
aaed4745
import
paddle.v2
as
paddle
import
paddle.v2.framework.layers
as
layers
import
paddle.v2.framework.nets
as
nets
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.optimizer
as
optimizer
from
paddle.v2.framework.framework
import
Program
,
g_program
,
g_init_program
from
paddle.v2.framework.executor
import
Executor
import
numpy
as
np
def
convolution_net
(
input_dim
,
class_dim
=
2
,
emb_dim
=
32
,
hid_dim
=
32
):
data
=
layers
.
data
(
name
=
"words"
,
shape
=
[
1
],
data_type
=
"int64"
)
label
=
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
data_type
=
"int64"
)
emb
=
layers
.
embedding
(
input
=
data
,
size
=
[
input_dim
,
emb_dim
])
conv_3
=
nets
.
sequence_conv_pool
(
input
=
emb
,
num_filters
=
hid_dim
,
filter_size
=
3
,
act
=
"tanh"
,
pool_type
=
"sqrt"
)
conv_4
=
nets
.
sequence_conv_pool
(
input
=
emb
,
num_filters
=
hid_dim
,
filter_size
=
4
,
act
=
"tanh"
,
pool_type
=
"sqrt"
)
prediction
=
layers
.
fc
(
input
=
[
conv_3
,
conv_4
],
size
=
class_dim
,
act
=
"softmax"
)
cost
=
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
layers
.
mean
(
x
=
cost
)
adam_optimizer
=
optimizer
.
AdamOptimizer
(
learning_rate
=
0.002
)
opts
=
adam_optimizer
.
minimize
(
avg_cost
)
acc
=
layers
.
accuracy
(
input
=
prediction
,
label
=
label
)
return
avg_cost
,
acc
def
to_lodtensor
(
data
,
place
):
seq_lens
=
[
len
(
seq
)
for
seq
in
data
]
cur_len
=
0
lod
=
[
cur_len
]
for
l
in
seq_lens
:
cur_len
+=
l
lod
.
append
(
cur_len
)
flattened_data
=
np
.
concatenate
(
data
,
axis
=
0
).
astype
(
"int64"
)
flattened_data
=
flattened_data
.
reshape
([
len
(
flattened_data
),
1
])
res
=
core
.
LoDTensor
()
res
.
set
(
flattened_data
,
place
)
res
.
set_lod
([
lod
])
return
res
def
main
():
BATCH_SIZE
=
100
PASS_NUM
=
5
word_dict
=
paddle
.
dataset
.
imdb
.
word_dict
()
dict_dim
=
len
(
word_dict
)
class_dim
=
2
cost
,
acc
=
convolution_net
(
input_dim
=
dict_dim
,
class_dim
=
class_dim
)
train_data
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
imdb
.
train
(
word_dict
),
buf_size
=
1000
),
batch_size
=
BATCH_SIZE
)
place
=
core
.
CPUPlace
()
exe
=
Executor
(
place
)
exe
.
run
(
g_init_program
)
for
pass_id
in
xrange
(
PASS_NUM
):
for
data
in
train_data
():
tensor_words
=
to_lodtensor
(
map
(
lambda
x
:
x
[
0
],
data
),
place
)
label
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
label
=
label
.
reshape
([
BATCH_SIZE
,
1
])
tensor_label
=
core
.
LoDTensor
()
tensor_label
.
set
(
label
,
place
)
outs
=
exe
.
run
(
g_program
,
feed
=
{
"words"
:
tensor_words
,
"label"
:
tensor_label
},
fetch_list
=
[
cost
,
acc
])
cost_val
=
np
.
array
(
outs
[
0
])
acc_val
=
np
.
array
(
outs
[
1
])
print
(
"cost="
+
str
(
cost_val
)
+
" acc="
+
str
(
acc_val
))
if
cost_val
<
1.0
and
acc_val
>
0.7
:
exit
(
0
)
exit
(
1
)
if
__name__
==
'__main__'
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录