Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ae0ea541
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ae0ea541
编写于
1月 24, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix unit test
上级
ca017719
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
215 addition
and
57 deletion
+215
-57
paddle/operators/layer_norm_op.cc
paddle/operators/layer_norm_op.cc
+7
-4
python/paddle/v2/fluid/tests/test_layer_norm_op.py
python/paddle/v2/fluid/tests/test_layer_norm_op.py
+208
-53
未找到文件。
paddle/operators/layer_norm_op.cc
浏览文件 @
ae0ea541
...
...
@@ -233,13 +233,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
if
(
d_x
)
{
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
d_x_map
=
EigenMatrixMapRowMajor
<
T
>
(
d_x
->
data
<
T
>
(),
left
,
right
);
auto
triple_product
=
[](
T
ele
)
{
return
ele
*
ele
*
ele
;
};
auto
neg_inv_std
=
[](
T
ele
)
{
return
T
(
-
1.0
)
*
std
::
sqrt
(
1
/
ele
);
};
auto
triple_product
=
[](
T
ele
)
{
return
ele
*
ele
;
};
auto
neg_inv_std
=
[](
T
ele
)
{
return
-
std
::
sqrt
(
1
/
ele
);
};
auto
inv_std_scale_func
=
[
scale_data
](
T
ele
)
{
return
std
::
sqrt
(
1
/
ele
)
*
scale_data
;
};
auto
neg_inv_std_scale_func
=
[
scale_data
](
T
ele
)
{
return
T
(
-
1.0
)
*
std
::
sqrt
(
1
/
ele
)
*
scale_data
;
return
-
std
::
sqrt
(
1
/
ele
)
*
scale_data
;
};
// dy_dx
auto
dx_end
=
var_map
.
unaryExpr
(
inv_std_scale_func
)
...
...
@@ -260,10 +260,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
auto
dvar_end
=
var_map
.
unaryExpr
(
neg_inv_std
)
.
unaryExpr
(
triple_product
)
.
cwiseProduct
(
dvar_end_0
);
auto
dx_var
=
(
1.0
f
/
right
)
*
auto
dx_var
=
(
T
(
1.0
)
/
right
)
*
(
x_map
-
mean_map
.
replicate
(
1
,
right
))
.
cwiseProduct
(
dvar_end
.
replicate
(
1
,
right
));
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
d_x_map
=
dx_end
+
dx_mean
+
dx_var
;
}
}
...
...
python/paddle/v2/fluid/tests/test_layer_norm_op.py
浏览文件 @
ae0ea541
...
...
@@ -15,66 +15,221 @@
import
unittest
import
numpy
as
np
from
operator
import
mul
from
op_test
import
OpTest
import
paddle.v2.fluid.core
as
core
from
paddle.v2.fluid.op
import
Operator
from
paddle.v2.fluid.framework
import
grad_var_name
def
layer_norm_naive
(
x
,
scale
,
beta
,
epsilon
):
n
,
c
,
h
,
w
=
x
.
shape
mean
=
np
.
mean
(
x
,
axis
=
(
1
,
2
,
3
))
var
=
np
.
var
(
x
,
axis
=
(
1
,
2
,
3
))
+
epsilon
output
=
scale
*
np
.
divide
((
x
-
mean
.
reshape
([
n
,
1
,
1
,
1
])),
(
np
.
sqrt
(
var
)).
reshape
([
n
,
1
,
1
,
1
]))
+
beta
def
get_backward_op
(
scope
,
op
,
no_grad_set
):
backward_op
=
core
.
Operator
.
backward
(
op
,
no_grad_set
)
for
input
in
backward_op
.
input_vars
():
var
=
scope
.
var
(
input
)
var
.
get_tensor
()
for
output
in
backward_op
.
output_vars
():
var
=
scope
.
var
(
output
)
var
.
get_tensor
()
return
backward_op
def
_reference_layer_norm_naive
(
x
,
scale
,
beta
,
epsilon
):
old_shape
=
x
.
shape
N
=
x
.
shape
[
0
]
D
=
reduce
(
mul
,
old_shape
,
1
)
/
N
x
.
shape
=
[
N
,
D
]
mean
=
np
.
mean
(
x
,
axis
=
1
)
var
=
np
.
var
(
x
,
axis
=
1
)
+
epsilon
output
=
scale
*
np
.
divide
((
x
-
mean
.
reshape
([
N
,
1
])),
(
np
.
sqrt
(
var
)).
reshape
([
N
,
1
]))
+
beta
output
.
shape
=
old_shape
return
output
,
mean
,
var
def
_reference_layer_norm_grad
(
x
,
grad_y
,
scale
,
mean
,
var
,
epsilon
):
x_shape
=
x
.
shape
N
=
x_shape
[
0
]
D
=
reduce
(
mul
,
x_shape
,
1
)
/
N
grad_y
.
shape
=
[
N
,
D
]
x
.
shape
=
[
N
,
D
]
grad_offset
=
np
.
sum
(
grad_y
)
mean
.
shape
=
[
N
,
1
]
var
.
shape
=
[
N
,
1
]
grad_scale
=
np
.
sum
(((
x
-
mean
)
*
np
.
sqrt
(
1
/
var
))
*
grad_y
)
dx_end
=
np
.
sqrt
(
1.0
/
var
)
*
grad_y
d_mean_0
=
np
.
sum
(
-
np
.
sqrt
(
1.0
/
var
)
*
grad_y
,
axis
=
1
).
reshape
([
N
,
1
])
d_mean_1
=
np
.
sum
(
-
1.0
/
var
*
(
x
-
mean
)
*
grad_y
,
axis
=
1
).
reshape
(
[
N
,
1
])
*
(
-
1.0
/
D
*
np
.
sqrt
(
1.0
/
var
)
*
np
.
sum
(
x
-
mean
,
axis
=
1
).
reshape
([
N
,
1
])).
reshape
([
N
,
1
])
d_mean
=
1.0
/
D
*
(
d_mean_0
+
d_mean_1
)
d_std
=
np
.
sum
(
-
1.0
/
var
*
(
x
-
mean
)
*
grad_y
,
axis
=
1
).
reshape
([
N
,
1
])
*
(
1.0
/
D
*
np
.
sqrt
(
1.0
/
var
).
reshape
([
N
,
1
])
*
(
x
-
mean
))
grad_x
=
scale
*
(
dx_end
+
d_mean
+
d_std
)
grad_y
.
shape
=
x_shape
x
.
shape
=
x_shape
return
grad_x
,
grad_scale
,
grad_offset
def
create_or_get_tensor
(
scope
,
var_name
,
var
,
place
):
tensor
=
scope
.
var
(
var_name
).
get_tensor
()
if
var
is
not
None
:
assert
isinstance
(
var
,
np
.
ndarray
)
tensor
.
set_lod
([[]])
tensor
.
set_dims
(
var
.
shape
)
tensor
.
set
(
var
,
place
)
return
tensor
def
set_output_grad
(
scope
,
outputs
,
place
,
feed_dict
=
None
):
def
__set_tensor__
(
name
,
data
=
None
):
out_tensor
=
scope
.
find_var
(
name
).
get_tensor
()
grad_tensor
=
scope
.
var
(
grad_var_name
(
name
)).
get_tensor
()
out_dtype
=
out_tensor
.
dtype
()
if
data
is
None
:
if
out_dtype
==
core
.
DataType
.
FP64
:
data
=
np
.
ones
(
out_tensor
.
shape
(),
dtype
=
np
.
float64
)
elif
out_dtype
==
core
.
DataType
.
FP32
:
data
=
np
.
ones
(
out_tensor
.
shape
(),
dtype
=
np
.
float32
)
else
:
raise
ValueError
(
"Not supported data type "
+
str
(
out_dtype
))
grad_tensor
.
set
(
data
,
place
)
for
output
in
outputs
:
data
=
None
if
output
in
feed_dict
:
data
=
feed_dict
[
output
]
__set_tensor__
(
output
,
data
)
class
TestLayerNormdOp
(
OpTest
):
def
setUp
(
self
):
self
.
init_test_case
()
input
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
input
,
'Scale'
:
np
.
array
([
self
.
scale
]).
astype
(
"float32"
),
'Bias'
:
np
.
array
([
self
.
bias
]).
astype
(
"float32"
)
}
output
,
mean
,
var
=
layer_norm_naive
(
input
,
self
.
scale
,
self
.
bias
,
self
.
epsilon
)
self
.
outputs
=
{
'Y'
:
output
,
'Mean'
:
mean
,
'Variance'
:
var
}
def
test_check_output
(
self
):
self
.
check_output
()
# def test_check_grad(self):
# self.check_grad(
# ['Scale', 'Bias', 'X'], ['Y', 'Mean', 'Variance'],
# max_relative_error=0.02)
def
test_check_grad_no_x
(
self
):
self
.
check_grad
(
[
'Scale'
,
'Bias'
],
[
'Y'
,
'Mean'
,
'Variance'
],
max_relative_error
=
0.02
,
no_grad_set
=
set
([
'X'
]))
# def test_check_grad_no_scale(self):
# self.check_grad(
# ['Bias','X'],
# 'Y',
# max_relative_error=0.02,
# no_grad_set=set(['Scale']))
#
# def test_check_grad_no_bias(self):
# self.check_grad(
# ['Scale','X'],
# 'Y',
# max_relative_error=0.02,
# no_grad_set=set(['Bias']))
def
init_test_case
(
self
):
self
.
op_type
=
"layer_norm"
self
.
input_size
=
[
2
,
3
,
4
,
5
]
self
.
scale
=
0.21
self
.
bias
=
0.1
self
.
epsilon
=
0.00001
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
).
reshape
(
np_array
.
shape
),
np_array
,
atol
=
atol
),
msg
)
def
__assert_grad_close
(
self
,
tensor
,
np_array
,
name
,
place
,
max_relative_error
=
0.02
):
a
=
np
.
array
(
tensor
).
reshape
(
np_array
.
shape
)
b
=
np_array
abs_a
=
np
.
abs
(
a
)
abs_a
[
abs_a
<
1e-5
]
=
1
diff_mat
=
np
.
abs
(
a
-
b
)
/
abs_a
max_diff
=
np
.
max
(
diff_mat
)
def
err_msg
():
offset
=
np
.
argmax
(
diff_mat
>
max_relative_error
)
return
(
"%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d, %f, %f"
)
%
(
"Gradient Check On %s"
%
str
(
place
),
name
,
max_diff
,
max_relative_error
,
offset
,
a
.
flatten
()[
offset
],
b
.
flatten
()[
offset
])
self
.
assertLessEqual
(
max_diff
,
max_relative_error
,
err_msg
())
def
test_forward_backward
(
self
):
def
test_with_place
(
place
,
shape
):
# attr
epsilon
=
0.00001
x_shape
=
shape
scale_shape
=
[
1
]
x_val
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
np
.
float32
)
scale_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
bias_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
# run forward
y_out
,
saved_mean
,
var_ref
=
_reference_layer_norm_naive
(
x_val
,
scale_val
,
bias_val
,
epsilon
)
# for gradient test
# y_grad = np.ones(x_shape).astype(np.float32) * 0.00277778
y_grad
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
np
.
float32
)
x_grad_ref
,
scale_grad_ref
,
bias_grad_ref
=
_reference_layer_norm_grad
(
x_val
,
y_grad
,
scale_val
,
saved_mean
,
var_ref
,
epsilon
)
scope
=
core
.
Scope
()
# create input
x_tensor
=
create_or_get_tensor
(
scope
,
"X"
,
x_val
,
place
)
scale_tensor
=
create_or_get_tensor
(
scope
,
"Scale"
,
scale_val
,
place
)
bias_tensor
=
create_or_get_tensor
(
scope
,
"Bias"
,
bias_val
,
place
)
# create output
y_tensor
=
create_or_get_tensor
(
scope
,
"Y"
,
None
,
place
)
mean_tensor
=
create_or_get_tensor
(
scope
,
"Mean"
,
None
,
place
)
variance_tensor
=
create_or_get_tensor
(
scope
,
"Variance"
,
None
,
place
)
layer_norm_op
=
Operator
(
"layer_norm"
,
# inputs
X
=
"X"
,
Scale
=
"Scale"
,
Bias
=
"Bias"
,
# outputs
Y
=
"Y"
,
Mean
=
"Mean"
,
Variance
=
"Variance"
,
# attrs
epsilon
=
epsilon
)
layer_norm_op
.
run
(
scope
,
place
)
# check forward result
if
isinstance
(
place
,
core
.
CUDAPlace
):
atol
=
5e-2
else
:
atol
=
1e-4
self
.
__assert_close
(
y_tensor
,
y_out
,
"Y"
,
atol
)
self
.
__assert_close
(
mean_tensor
,
saved_mean
,
"Mean"
,
atol
)
self
.
__assert_close
(
variance_tensor
,
var_ref
,
"Variance"
,
atol
)
# run backward
layer_norm_op_grad
=
get_backward_op
(
scope
,
layer_norm_op
,
set
())
set_output_grad
(
scope
,
[
"Y"
,
"Mean"
,
"Variance"
],
place
,
feed_dict
=
{
"Y"
:
y_grad
})
layer_norm_op_grad
.
run
(
scope
,
place
)
x_grad_tensor
=
create_or_get_tensor
(
scope
,
grad_var_name
(
"X"
),
None
,
place
)
scale_grad_tensor
=
create_or_get_tensor
(
scope
,
grad_var_name
(
"Scale"
),
None
,
place
)
bias_grad_tensor
=
create_or_get_tensor
(
scope
,
grad_var_name
(
"Bias"
),
None
,
place
)
# check gradient output
self
.
__assert_grad_close
(
x_grad_tensor
,
x_grad_ref
,
"x_grad"
,
place
)
self
.
__assert_grad_close
(
scale_grad_tensor
,
scale_grad_ref
,
"scale_grad"
,
place
)
self
.
__assert_grad_close
(
bias_grad_tensor
,
bias_grad_ref
,
"bias_grad"
,
place
)
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compile_gpu
()
and
core
.
op_support_gpu
(
"layer_norm"
):
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
test_with_place
(
place
,
[
2
,
3
,
4
,
5
])
test_with_place
(
place
,
[
2
,
3
])
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录