Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ed2bc194
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ed2bc194
编写于
3月 20, 2018
作者:
K
Kexin Zhao
提交者:
GitHub
3月 20, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9176 from kexinzhao/batch_norm_fp16
Add float16 support to batch norm operator
上级
cd07c0f0
6ec0f912
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
244 addition
and
25 deletion
+244
-25
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+23
-0
paddle/fluid/operators/batch_norm_op.cu.cc
paddle/fluid/operators/batch_norm_op.cu.cc
+31
-19
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+1
-0
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+1
-0
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+6
-3
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+182
-3
未找到文件。
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
ed2bc194
...
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
...
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"SavedVariance"
,
{
C
});
ctx
->
SetOutputDim
(
"SavedVariance"
,
{
C
});
ctx
->
ShareLoD
(
"X"
,
"Y"
);
ctx
->
ShareLoD
(
"X"
,
"Y"
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should both be float.
auto
bn_param_type
=
framework
::
proto
::
VarType
::
FP32
;
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Scale"
)
->
type
()),
"Scale input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Bias"
)
->
type
()),
"Bias input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Mean"
)
->
type
()),
"Mean input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
()),
"Variance input should be of float type"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
};
class
BatchNormOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
BatchNormOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/batch_norm_op.cu.cc
浏览文件 @
ed2bc194
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <cfloat>
#include <cfloat>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
...
@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
using
DataLayout
=
framework
::
DataLayout
;
using
DataLayout
=
framework
::
DataLayout
;
template
<
typename
T
>
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
using
BatchNormParamType
=
typename
CudnnDataType
<
T
>::
BatchNormParamType
;
void
ExtractNCWHD
(
const
framework
::
DDim
&
dims
,
const
DataLayout
&
data_layout
,
void
ExtractNCWHD
(
const
framework
::
DDim
&
dims
,
const
DataLayout
&
data_layout
,
int
*
N
,
int
*
C
,
int
*
H
,
int
*
W
,
int
*
D
)
{
int
*
N
,
int
*
C
,
int
*
H
,
int
*
W
,
int
*
D
)
{
...
@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
// Note: PERSISTENT not implemented for inference
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
mode_
));
bn_param_desc_
,
data_desc_
,
is_test
?
CUDNN_BATCHNORM_SPATIAL
:
mode_
));
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
...
@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory
// alloc memory
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean_out
->
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
());
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
variance_out
->
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
());
saved_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
saved_mean
->
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
());
saved_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
saved_variance
->
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
functor
;
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
BatchNormParamType
<
T
>>
functor
(
dev_ctx
,
saved_mean
,
0
);
functor
;
functor
(
dev_ctx
,
saved_variance
,
0
);
functor
(
dev_ctx
,
saved_mean
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
functor
(
dev_ctx
,
saved_variance
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
...
@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
data_desc_
,
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
T
>(),
bias
->
template
data
<
T
>(),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
T
>(),
est_var
->
template
data
<
T
>(),
epsilon
));
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
epsilon
));
}
else
{
}
else
{
// Run training mode.
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// obtain running mean and running inv var, and see if we need to
...
@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle
,
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
handle
,
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
T
>(),
bias
->
template
data
<
T
>(),
this_factor
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
mean_out
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
this_factor
,
variance_out
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
epsilon
,
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
saved_mean
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
())));
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())));
}
}
// clean when exit.
// clean when exit.
...
@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
batch_norm
,
batch_norm
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormKernel
<
p
addle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
BatchNormKernel
<
p
lat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
batch_norm_grad
,
batch_norm_grad
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
BatchNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/math/math_function.cc
浏览文件 @
ed2bc194
...
@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
...
@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy
(
n
,
alpha
,
x
,
1
,
y
,
1
);
cblas_daxpy
(
n
,
alpha
,
x
,
1
,
y
,
1
);
}
}
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
CPUDeviceContext
,
int
>;
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
ed2bc194
...
@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
...
@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&
alpha
,
x
,
1
,
y
,
1
));
&
alpha
,
x
,
1
,
y
,
1
));
}
}
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int
>;
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
ed2bc194
...
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
...
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
public:
public:
static
const
cudnnDataType_t
type
=
CUDNN_DATA_HALF
;
static
const
cudnnDataType_t
type
=
CUDNN_DATA_HALF
;
// The scaling param type is float for HALF and FLOAT tensors
// The scaling param type is float for HALF and FLOAT tensors
typedef
const
float
ScalingParamType
;
using
ScalingParamType
=
const
float
;
using
BatchNormParamType
=
float
;
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
v
=
1.0
;
static
ScalingParamType
v
=
1.0
;
return
&
v
;
return
&
v
;
...
@@ -101,7 +102,8 @@ template <>
...
@@ -101,7 +102,8 @@ template <>
class
CudnnDataType
<
float
>
{
class
CudnnDataType
<
float
>
{
public:
public:
static
const
cudnnDataType_t
type
=
CUDNN_DATA_FLOAT
;
static
const
cudnnDataType_t
type
=
CUDNN_DATA_FLOAT
;
typedef
const
float
ScalingParamType
;
using
ScalingParamType
=
const
float
;
using
BatchNormParamType
=
float
;
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
v
=
1.0
;
static
ScalingParamType
v
=
1.0
;
return
&
v
;
return
&
v
;
...
@@ -116,7 +118,8 @@ template <>
...
@@ -116,7 +118,8 @@ template <>
class
CudnnDataType
<
double
>
{
class
CudnnDataType
<
double
>
{
public:
public:
static
const
cudnnDataType_t
type
=
CUDNN_DATA_DOUBLE
;
static
const
cudnnDataType_t
type
=
CUDNN_DATA_DOUBLE
;
typedef
const
double
ScalingParamType
;
using
ScalingParamType
=
const
double
;
using
BatchNormParamType
=
double
;
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
*
kOne
()
{
static
ScalingParamType
v
=
1.0
;
static
ScalingParamType
v
=
1.0
;
return
&
v
;
return
&
v
;
...
...
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
浏览文件 @
ed2bc194
...
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
...
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
return
backward_op
return
backward_op
def
_reference_testing
(
x
,
scale
,
offset
,
mean
,
var
,
epsilon
,
data_format
):
x_shape
=
x
.
shape
if
len
(
x_shape
)
==
2
:
if
data_format
==
"NCHW"
:
x
=
np
.
reshape
(
x
,
(
x
.
shape
[
0
],
x
.
shape
[
1
],
1
,
1
))
else
:
x
=
np
.
reshape
(
x
,
(
x
.
shape
[
0
],
1
,
1
,
x
.
shape
[
1
]))
if
data_format
==
"NCHW"
:
n
,
c
,
h
,
w
=
x
.
shape
mean_tile
=
np
.
reshape
(
mean
,
(
1
,
c
,
1
,
1
))
mean_tile
=
np
.
tile
(
mean_tile
,
(
n
,
1
,
h
,
w
))
var_tile
=
np
.
reshape
(
var
,
(
1
,
c
,
1
,
1
))
var_tile
=
np
.
tile
(
var_tile
,
(
n
,
1
,
h
,
w
))
normalized
=
(
x
-
mean_tile
)
/
np
.
sqrt
(
var_tile
+
epsilon
)
scale_tile
=
np
.
reshape
(
scale
,
(
1
,
c
,
1
,
1
))
scale_tile
=
np
.
tile
(
scale_tile
,
(
n
,
1
,
h
,
w
))
offset_tile
=
np
.
reshape
(
offset
,
(
1
,
c
,
1
,
1
))
offset_tile
=
np
.
reshape
(
offset_tile
,
(
1
,
c
,
1
,
1
))
y
=
normalized
*
scale_tile
+
offset_tile
elif
data_format
==
"NHWC"
:
normalized
=
(
x
-
mean
)
/
np
.
sqrt
(
var
+
epsilon
)
y
=
normalized
*
scale
+
offset
else
:
raise
ValueError
(
"Unknown data order."
)
if
len
(
x_shape
)
==
2
:
y
=
np
.
reshape
(
y
,
x_shape
)
return
y
def
_reference_training
(
x
,
scale
,
offset
,
epsilon
,
data_format
):
def
_reference_training
(
x
,
scale
,
offset
,
epsilon
,
data_format
):
x_shape
=
x
.
shape
x_shape
=
x
.
shape
if
len
(
x_shape
)
==
2
:
if
len
(
x_shape
)
==
2
:
...
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
...
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
__set_tensor__
(
output
,
data
)
__set_tensor__
(
output
,
data
)
class
TestBatchNormOp
(
OpTest
):
class
TestBatchNormOpInference
(
OpTest
):
def
setUp
(
self
):
self
.
dtype
=
np
.
float32
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
def
test_python
(
self
):
def
check_with_place
(
self
,
place
,
data_layout
,
dtype
,
shape
):
epsilon
=
0.00001
if
len
(
shape
)
==
2
:
x_shape
=
shape
c
=
x_shape
[
1
]
else
:
n
,
h
,
w
,
c
=
shape
[
0
],
shape
[
1
],
shape
[
2
],
shape
[
3
]
if
data_layout
==
"NHWC"
:
x_shape
=
[
n
,
h
,
w
,
c
]
elif
data_layout
==
"NCHW"
:
x_shape
=
[
n
,
c
,
h
,
w
]
else
:
raise
ValueError
(
"Unknown data layout."
)
scale_shape
=
[
c
]
x_val
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
dtype
)
scale_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
bias_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
mean
=
np
.
zeros
(
scale_shape
).
astype
(
np
.
float32
)
variance
=
np
.
ones
(
scale_shape
).
astype
(
np
.
float32
)
y_out
=
_reference_testing
(
x_val
,
scale_val
,
bias_val
,
mean
,
variance
,
epsilon
,
data_layout
).
astype
(
dtype
)
scope
=
core
.
Scope
()
# create input
x_tensor
=
create_or_get_tensor
(
scope
,
"x_val"
,
OpTest
.
np_dtype_to_fluid_dtype
(
x_val
),
place
)
scale_tensor
=
create_or_get_tensor
(
scope
,
"scale_val"
,
OpTest
.
np_dtype_to_fluid_dtype
(
scale_val
),
place
)
bias_tensor
=
create_or_get_tensor
(
scope
,
"bias_val"
,
OpTest
.
np_dtype_to_fluid_dtype
(
bias_val
),
place
)
mean_tensor
=
create_or_get_tensor
(
scope
,
"mean"
,
OpTest
.
np_dtype_to_fluid_dtype
(
mean
),
place
)
variance_tensor
=
create_or_get_tensor
(
scope
,
"variance"
,
OpTest
.
np_dtype_to_fluid_dtype
(
variance
),
place
)
# create output
y_tensor
=
create_or_get_tensor
(
scope
,
"y_out"
,
None
,
place
)
saved_mean_tensor
=
create_or_get_tensor
(
scope
,
"saved_mean"
,
None
,
place
)
saved_variance_tensor
=
create_or_get_tensor
(
scope
,
"saved_variance"
,
None
,
place
)
mean_out_tensor
=
mean_tensor
variance_out_tensor
=
variance_tensor
batch_norm_op
=
Operator
(
"batch_norm"
,
# inputs
X
=
"x_val"
,
Scale
=
"scale_val"
,
Bias
=
"bias_val"
,
Mean
=
"mean"
,
Variance
=
"variance"
,
# outputs
Y
=
"y_out"
,
MeanOut
=
"mean"
,
VarianceOut
=
"variance"
,
SavedMean
=
"saved_mean"
,
SavedVariance
=
"saved_variance"
,
# attrs
is_test
=
True
,
data_layout
=
data_layout
,
epsilon
=
epsilon
)
batch_norm_op
.
run
(
scope
,
place
)
# check inference result
self
.
__assert_close
(
y_tensor
,
y_out
,
"inference output are different at "
+
str
(
place
)
+
", "
+
data_layout
+
", "
+
str
(
np
.
dtype
(
dtype
))
+
str
(
np
.
array
(
y_tensor
))
+
str
(
y_out
),
atol
=
1e-3
)
def
test_check_output
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"batch_norm"
):
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
for
data_format
in
[
"NCHW"
,
"NHWC"
]:
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
,
4
,
5
])
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
])
class
TestFP16BatchNormOpInference
(
TestBatchNormOpInference
):
def
setUp
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
places
=
[]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"batch_norm"
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
places
.
append
(
place
)
for
place
in
places
:
for
data_format
in
[
"NCHW"
,
"NHWC"
]:
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
,
4
,
5
])
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
])
class
TestBatchNormOpTraining
(
OpTest
):
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
def
test_python_testing
(
self
):
data_format
=
"NHWC"
epsilon
=
0.00001
n
,
h
,
w
,
c
=
2
,
3
,
4
,
5
x_shape
=
[
n
,
h
,
w
,
c
]
scale_shape
=
[
c
]
x_val
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
np
.
float32
)
scale_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
bias_val
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
mean
=
np
.
zeros
(
scale_shape
).
astype
(
np
.
float32
)
variance
=
np
.
ones
(
scale_shape
).
astype
(
np
.
float32
)
y_out
=
_reference_testing
(
x_val
,
scale_val
,
bias_val
,
mean
,
variance
,
epsilon
,
"NHWC"
)
# running N, C, H, W case
# should produce the same results
x_shape2
=
[
n
,
c
,
h
,
w
]
x_val2
=
np
.
transpose
(
x_val
,
(
0
,
3
,
1
,
2
))
y_out2
=
_reference_testing
(
x_val2
,
scale_val
,
bias_val
,
mean
,
variance
,
epsilon
,
"NCHW"
)
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans
=
np
.
transpose
(
y_out2
,
(
0
,
2
,
3
,
1
))
self
.
__assert_close
(
y_out
,
y_out2_trans
,
"inference output"
)
print
'python: NHWC, NCHW, inference checking passed'
def
test_python_training
(
self
):
data_format
=
"NHWC"
data_format
=
"NHWC"
epsilon
=
0.00001
epsilon
=
0.00001
momentum
=
0.9
momentum
=
0.9
...
@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
...
@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
# transfer (N, C, H, W) back to (N, H, W, C)
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans
=
np
.
transpose
(
y_out2
,
(
0
,
2
,
3
,
1
))
y_out2_trans
=
np
.
transpose
(
y_out2
,
(
0
,
2
,
3
,
1
))
self
.
__assert_close
(
y_out
,
y_out2_trans
,
"batch
variance
"
)
self
.
__assert_close
(
y_out
,
y_out2_trans
,
"batch
output
"
)
print
'python: NHWC, NCHW, forward checking passed'
print
'python: NHWC, NCHW, forward checking passed'
# test backward now
# test backward now
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录