Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
652ab6c3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
652ab6c3
编写于
4年前
作者:
C
chenzomi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add test case for aware quantizaiton
上级
60958d6b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
456 addition
and
44 deletion
+456
-44
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
...spore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
+7
-7
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
.../ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
+7
-7
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h
+7
-7
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
...e/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
+8
-6
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
+7
-7
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
...e/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
+9
-7
mindspore/nn/layer/activation.py
mindspore/nn/layer/activation.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+1
-1
tests/st/ops/gpu/test_batchnorm_fold2_op.py
tests/st/ops/gpu/test_batchnorm_fold2_op.py
+89
-0
tests/st/ops/gpu/test_batchnorm_fold_grad_op.py
tests/st/ops/gpu/test_batchnorm_fold_grad_op.py
+96
-0
tests/st/ops/gpu/test_batchnorm_fold_op.py
tests/st/ops/gpu/test_batchnorm_fold_op.py
+116
-0
tests/st/ops/gpu/test_conv2d_op.py
tests/st/ops/gpu/test_conv2d_op.py
+1
-1
tests/st/ops/gpu/test_correction_mul_grad_op.py
tests/st/ops/gpu/test_correction_mul_grad_op.py
+55
-0
tests/st/ops/gpu/test_correction_mul_op.py
tests/st/ops/gpu/test_correction_mul_op.py
+52
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -38,14 +38,14 @@ class BatchNormFold2GpuKernel : public GpuKernel {
~
BatchNormFold2GpuKernel
()
override
{
DestroyResource
();
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
if
(
is_null_input_
)
{
return
true
;
}
...
...
@@ -66,7 +66,7 @@ class BatchNormFold2GpuKernel : public GpuKernel {
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
InitResource
();
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
...
...
@@ -98,9 +98,9 @@ class BatchNormFold2GpuKernel : public GpuKernel {
}
protected:
void
InitResource
()
{
cudnn_handle_
=
device
::
gpu
::
GPUDeviceManager
::
GetInstance
().
GetCudnnHandle
();
}
void
InitResource
()
override
{
cudnn_handle_
=
device
::
gpu
::
GPUDeviceManager
::
GetInstance
().
GetCudnnHandle
();
}
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
size_t
input_size
=
batch_size_
*
channel_
*
height_
*
width_
*
sizeof
(
T
);
size_t
weight_size
=
channel_
*
sizeof
(
T
);
input_size_list_
.
push_back
(
input_size
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -38,14 +38,14 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
~
BatchNormFold2GradGpuKernel
()
override
{
DestroyResource
();
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
if
(
is_null_input_
)
{
return
true
;
}
...
...
@@ -88,7 +88,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
InitResource
();
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
...
...
@@ -120,9 +120,9 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
}
protected:
void
InitResource
()
{
cudnn_handle_
=
device
::
gpu
::
GPUDeviceManager
::
GetInstance
().
GetCudnnHandle
();
}
void
InitResource
()
override
{
cudnn_handle_
=
device
::
gpu
::
GPUDeviceManager
::
GetInstance
().
GetCudnnHandle
();
}
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
size_t
input_size
=
batch_size_
*
channel_
*
height_
*
width_
*
sizeof
(
T
);
size_t
weight_size
=
channel_
*
sizeof
(
T
);
size_t
workspace_size
=
batch_size_
*
channel_
*
sizeof
(
T
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -46,14 +46,14 @@ class BatchNormFoldGpuKernel : public GpuKernel {
~
BatchNormFoldGpuKernel
()
override
{
DestroyResource
();
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
(
void
)
workspace
;
auto
x
=
reinterpret_cast
<
T
*>
(
inputs
[
0
]
->
addr
);
auto
mean
=
reinterpret_cast
<
T
*>
(
inputs
[
1
]
->
addr
);
...
...
@@ -104,7 +104,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
InitResource
();
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
4
)
{
...
...
@@ -152,7 +152,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
}
protected:
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
// x, mean, variance, current_step
input_size_list_
.
push_back
(
input_size_
);
input_size_list_
.
push_back
(
output_size_
);
...
...
@@ -169,7 +169,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
workspace_size_list_
.
push_back
(
input_size_
);
}
void
InitResource
()
{
void
InitResource
()
override
{
handle_
=
device
::
gpu
::
GPUDeviceManager
::
GetInstance
().
GetCudnnHandle
();
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnCreateTensorDescriptor
(
&
x_desc_
),
"Create x desc failed"
);
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnCreateTensorDescriptor
(
&
scale_bias_mean_var_desc_
),
"Create para desc failed"
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -42,11 +42,12 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
width_
(
0
)
{}
~
BatchNormFoldGradGpuKernel
()
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
(
void
)
workspace
;
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
T
*
d_batch_mean
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
...
...
@@ -92,7 +93,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
6
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but BatchNormFoldGrad GpuKernel OP needs 6 input."
;
...
...
@@ -128,7 +130,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
}
protected:
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
input_size_list_
.
push_back
(
channel_size_
);
input_size_list_
.
push_back
(
channel_size_
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -30,11 +30,11 @@ class CorrectionMulGpuKernel : public GpuKernel {
CorrectionMulGpuKernel
()
:
batch_size_
(
0
),
channel_
(
0
),
height_
(
0
),
width_
(
0
)
{}
~
CorrectionMulGpuKernel
()
override
{
DestroyResource
();
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
auto
*
weight
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
*
gamma
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
auto
*
running_std
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
...
...
@@ -44,7 +44,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
InitResource
();
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
...
...
@@ -69,7 +69,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
}
protected:
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
size_t
input_size
=
batch_size_
*
channel_
*
height_
*
width_
*
sizeof
(
T
);
size_t
weight_size
=
batch_size_
*
sizeof
(
T
);
input_size_list_
.
push_back
(
input_size
);
// weight
...
...
@@ -79,7 +79,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
output_size_list_
.
push_back
(
input_size
);
workspace_size_list_
.
push_back
(
workspace_size
);
}
void
InitResource
()
{}
void
InitResource
()
override
{}
private:
void
DestroyResource
()
noexcept
{}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
浏览文件 @
652ab6c3
...
...
@@ -30,11 +30,12 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
CorrectionMulGradGpuKernel
()
:
batch_size_
(
0
),
channel_
(
0
),
height_
(
0
),
width_
(
0
)
{}
~
CorrectionMulGradGpuKernel
()
override
{
DestroyResource
();
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
auto
*
d_out
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
*
weight
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
auto
*
gamma
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
...
...
@@ -49,7 +50,8 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
InitResource
();
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
...
...
@@ -74,7 +76,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
}
protected:
void
InitSizeLists
()
{
void
InitSizeLists
()
override
{
size_t
input_size
=
batch_size_
*
channel_
*
height_
*
width_
*
sizeof
(
T
);
size_t
weight_size
=
batch_size_
*
sizeof
(
T
);
input_size_list_
.
push_back
(
input_size
);
// d_out
...
...
@@ -85,7 +87,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
output_size_list_
.
push_back
(
weight_size
);
// d_gamma
workspace_size_list_
.
push_back
(
input_size
);
// tmp d_out * weight
}
void
InitResource
()
{}
void
InitResource
()
override
{}
private:
void
DestroyResource
()
noexcept
{}
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/layer/activation.py
浏览文件 @
652ab6c3
...
...
@@ -369,7 +369,7 @@ class HSigmoid(Cell):
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \f
t
ac{2 * x_{i} + 5}{10})),
\text{hsigmoid}(x_{i}) = max(0, min(1, \f
r
ac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
...
...
This diff is collapsed.
Click to expand it.
mindspore/ops/operations/nn_ops.py
浏览文件 @
652ab6c3
...
...
@@ -319,7 +319,7 @@ class HSigmoid(PrimitiveWithInfer):
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \f
t
ac{2 * x_{i} + 5}{10})),
\text{hsigmoid}(x_{i}) = max(0, min(1, \f
r
ac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
...
...
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_batchnorm_fold2_op.py
0 → 100644
浏览文件 @
652ab6c3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
mindspore.context
as
context
context
.
set_context
(
device_target
=
'GPU'
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
P
.
BatchNormFold2
(
100000
)
@
ms_function
def
construct
(
self
,
x
,
beta
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
running_mean
,
current_step
):
return
self
.
op
(
x
,
beta
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
running_mean
,
current_step
)
class
Net_gnd
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net_gnd
,
self
).
__init__
()
self
.
conv_mul
=
P
.
ConvMul
(
freeze_bn
=
100000
)
self
.
correct_add
=
P
.
CorrectionAdd
(
freeze_bn
=
100000
)
self
.
add_fold
=
P
.
AddFold
()
@
ms_function
def
construct
(
self
,
x
,
beta
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
running_mean
,
current_step
):
out
=
self
.
conv_mul
(
x
,
batch_std
,
running_std
,
current_step
)
out
=
self
.
correct_add
(
out
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
running_mean
,
current_step
)
out
=
self
.
add_fold
(
out
,
beta
,
gamma
,
batch_std
,
batch_mean
)
return
out
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnrom_fold2
():
net
=
Net
()
c
=
64
freeze_bn
=
100000
x
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
3
,
c
,
32
,
32
]).
astype
(
'float32'
)
beta
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
gamma
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
batch_std
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
batch_mean
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
running_std
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
running_mean
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
0
]).
astype
(
'int32'
)
output
=
net
(
Tensor
(
x
),
Tensor
(
beta
),
Tensor
(
gamma
),
Tensor
(
batch_std
),
Tensor
(
batch_mean
),
Tensor
(
running_std
),
Tensor
(
running_mean
),
Tensor
(
current_step
))
expect
=
(
x
+
beta
.
reshape
(
-
1
,
1
,
1
)
-
(
gamma
*
running_mean
/
running_std
).
reshape
(
-
1
,
1
,
1
)
if
current_step
>=
freeze_bn
else
x
*
(
running_std
/
batch_std
).
reshape
(
-
1
,
1
,
1
)
+
(
beta
-
gamma
*
batch_mean
/
batch_std
).
reshape
(
-
1
,
1
,
1
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-6
diff
=
output
.
asnumpy
()
-
expect
assert
np
.
all
(
diff
<
error
)
assert
np
.
all
(
diff
>
error
*
-
1
)
current_step
=
np
.
array
([
100000
]).
astype
(
'int32'
)
output
=
net
(
Tensor
(
x
),
Tensor
(
beta
),
Tensor
(
gamma
),
Tensor
(
batch_std
),
Tensor
(
batch_mean
),
Tensor
(
running_std
),
Tensor
(
running_mean
),
Tensor
(
current_step
))
expect
=
(
x
+
beta
.
reshape
(
-
1
,
1
,
1
)
-
(
gamma
*
running_mean
/
running_std
).
reshape
(
-
1
,
1
,
1
)
if
current_step
>=
freeze_bn
else
x
*
(
batch_std
/
running_std
).
reshape
(
-
1
,
1
,
1
)
+
(
beta
-
gamma
*
batch_mean
/
batch_std
).
reshape
(
-
1
,
1
,
1
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-6
diff
=
output
.
asnumpy
()
-
expect
assert
np
.
all
(
diff
<
error
)
assert
np
.
all
(
diff
>
error
*
-
1
)
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_batchnorm_fold_grad_op.py
0 → 100644
浏览文件 @
652ab6c3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
mindspore.context
as
context
context
.
set_context
(
device_target
=
'GPU'
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
P
.
BatchNormFoldGrad
(
freeze_bn
=
10
)
@
ms_function
def
construct
(
self
,
d_batch_mean
,
d_batch_std
,
x
,
batch_mean
,
batch_std
,
current_step
):
dx
=
self
.
op
(
d_batch_mean
,
d_batch_std
,
x
,
batch_mean
,
batch_std
,
current_step
)
return
dx
def
np_result
(
d_batch_mean
,
d_batch_std
,
x
,
batch_mean
,
batch_std
):
n
=
x
.
shape
[
0
]
*
x
.
shape
[
2
]
*
x
.
shape
[
3
]
dx
=
d_batch_mean
.
reshape
(
1
,
-
1
,
1
,
1
)
/
n
+
d_batch_std
.
reshape
(
1
,
-
1
,
1
,
1
)
*
(
x
-
batch_mean
.
reshape
(
1
,
-
1
,
1
,
1
))
/
batch_std
.
reshape
(
1
,
-
1
,
1
,
1
)
/
n
return
dx
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold_grad1
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
3
,
c
,
32
,
32
]).
astype
(
'float32'
)
d_batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
d_batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
0
]).
astype
(
'int32'
)
dx
=
net
(
Tensor
(
d_batch_mean
),
Tensor
(
d_batch_std
),
Tensor
(
x
),
Tensor
(
batch_mean
),
Tensor
(
batch_std
),
Tensor
(
current_step
))
expect
=
np_result
(
d_batch_mean
,
d_batch_std
,
x
,
batch_mean
,
batch_std
)
assert
np
.
allclose
(
dx
.
asnumpy
(),
expect
,
rtol
=
1.e-7
,
atol
=
1.e-7
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold_grad2
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
1
,
c
,
256
,
256
]).
astype
(
'float32'
)
d_batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
d_batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
0
]).
astype
(
'int32'
)
dx
=
net
(
Tensor
(
d_batch_mean
),
Tensor
(
d_batch_std
),
Tensor
(
x
),
Tensor
(
batch_mean
),
Tensor
(
batch_std
),
Tensor
(
current_step
))
expect
=
np_result
(
d_batch_mean
,
d_batch_std
,
x
,
batch_mean
,
batch_std
)
assert
np
.
allclose
(
dx
.
asnumpy
(),
expect
,
rtol
=
1.e-7
,
atol
=
1.e-7
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold_grad_freeze
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
3
,
c
,
32
,
32
]).
astype
(
'float32'
)
d_batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
d_batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
10
]).
astype
(
'int32'
)
dx
=
net
(
Tensor
(
d_batch_mean
),
Tensor
(
d_batch_std
),
Tensor
(
x
),
Tensor
(
batch_mean
),
Tensor
(
batch_std
),
Tensor
(
current_step
))
expect
=
np
.
zeros_like
(
x
)
assert
np
.
allclose
(
dx
.
asnumpy
(),
expect
,
rtol
=
1.e-7
,
atol
=
1.e-7
)
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_batchnorm_fold_op.py
0 → 100644
浏览文件 @
652ab6c3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
mindspore.context
as
context
context
.
set_context
(
device_target
=
'GPU'
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
P
.
BatchNormFold
(
freeze_bn
=
10
)
@
ms_function
def
construct
(
self
,
x
,
mean
,
variance
,
current_step
):
a
,
b
,
c
,
d
=
self
.
op
(
x
,
mean
,
variance
,
current_step
)
return
a
,
b
,
c
,
d
def
np_result
(
x
,
mean
,
var
,
momentum
,
epsilon
):
np_mean
=
x
.
mean
(
axis
=
(
0
,
2
,
3
))
np_var
=
x
.
var
(
axis
=
(
0
,
2
,
3
))
n
=
x
.
shape
[
0
]
*
x
.
shape
[
2
]
*
x
.
shape
[
3
]
mean_update
=
momentum
*
np_mean
+
(
1
-
momentum
)
*
mean
var_update
=
momentum
*
np_var
*
n
/
(
n
-
1
)
+
(
1
-
momentum
)
*
var
np_var
=
np
.
sqrt
(
np_var
+
epsilon
)
delay_mean
=
mean
.
copy
()
delay_std
=
np
.
sqrt
(
var
+
epsilon
)
return
np_mean
,
np_var
,
mean_update
,
var_update
,
delay_mean
,
delay_std
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
3
,
c
,
32
,
32
]).
astype
(
'float32'
)
mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
variance
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
0
]).
astype
(
'int32'
)
ms_mean
=
Tensor
(
mean
)
ms_var
=
Tensor
(
variance
)
batch_mean
,
batch_var
,
delay_mean
,
delay_std
=
net
(
Tensor
(
x
),
ms_mean
,
ms_var
,
Tensor
(
current_step
))
expect1
,
expect2
,
expect3
,
expect4
,
expect5
,
expect6
=
np_result
(
x
,
mean
,
variance
,
0.9
,
1e-12
)
assert
np
.
allclose
(
batch_mean
.
asnumpy
(),
expect1
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
batch_var
.
asnumpy
(),
expect2
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
ms_mean
.
asnumpy
(),
expect3
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
ms_var
.
asnumpy
(),
expect4
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_mean
.
asnumpy
(),
expect5
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_std
.
asnumpy
(),
expect6
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold2
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
3
,
c
,
512
,
512
]).
astype
(
'float32'
)
mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
variance
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
0
]).
astype
(
'int32'
)
ms_mean
=
Tensor
(
mean
)
ms_var
=
Tensor
(
variance
)
batch_mean
,
batch_var
,
delay_mean
,
delay_std
=
net
(
Tensor
(
x
),
ms_mean
,
ms_var
,
Tensor
(
current_step
))
expect1
,
expect2
,
expect3
,
expect4
,
expect5
,
expect6
=
np_result
(
x
,
mean
,
variance
,
0.9
,
1e-12
)
assert
np
.
allclose
(
batch_mean
.
asnumpy
(),
expect1
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
batch_var
.
asnumpy
(),
expect2
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
ms_mean
.
asnumpy
(),
expect3
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_mean
.
asnumpy
(),
expect5
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_std
.
asnumpy
(),
expect6
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_batchnorm_fold_freeze
():
net
=
Net
()
c
=
64
x
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
3
,
c
,
32
,
32
]).
astype
(
'float32'
)
mean
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
variance
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
c
]).
astype
(
'float32'
)
current_step
=
np
.
array
([
10
]).
astype
(
'int32'
)
ms_mean
=
Tensor
(
mean
)
ms_var
=
Tensor
(
variance
)
batch_mean
,
batch_var
,
delay_mean
,
delay_std
=
net
(
Tensor
(
x
),
ms_mean
,
ms_var
,
Tensor
(
current_step
))
expect1
,
expect2
,
expect3
,
expect4
,
expect5
,
expect6
=
np_result
(
x
,
mean
,
variance
,
0.9
,
1e-12
)
assert
np
.
allclose
(
batch_mean
.
asnumpy
(),
np
.
zeros_like
(
mean
),
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
batch_var
.
asnumpy
(),
np
.
ones_like
(
mean
),
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
ms_mean
.
asnumpy
(),
mean
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
ms_var
.
asnumpy
(),
variance
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_mean
.
asnumpy
(),
expect5
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
assert
np
.
allclose
(
delay_std
.
asnumpy
(),
expect6
,
rtol
=
1.e-7
,
atol
=
1.e-5
)
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_conv2d_op.py
浏览文件 @
652ab6c3
...
...
@@ -14,10 +14,10 @@
# ============================================================================
import
pytest
import
numpy
as
np
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
import
numpy
as
np
import
mindspore.context
as
context
...
...
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_correction_mul_grad_op.py
0 → 100644
浏览文件 @
652ab6c3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
os
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
mindspore.context
as
context
context
.
set_context
(
device_target
=
'GPU'
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op_w
=
P
.
CorrectionMulGrad
()
@
ms_function
def
construct
(
self
,
dy
,
x
,
batch_std
,
running_std
):
dx
,
d_batch_std
=
self
.
op_w
(
dy
,
x
,
batch_std
,
running_std
)
return
dx
,
d_batch_std
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_correction_mul_grad
():
net
=
Net
()
co
,
ci
,
h
,
w
=
64
,
1
,
32
,
32
dout
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
size
=
[
co
,
ci
,
h
,
w
]).
astype
(
'float32'
)
x
=
np
.
random
.
uniform
(
1
,
1
,
size
=
[
co
,
ci
,
h
,
w
]).
astype
(
'float32'
)
batch_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
co
]).
astype
(
'float32'
)
running_std
=
np
.
random
.
uniform
(
1
,
10
,
size
=
[
co
]).
astype
(
'float32'
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
batch_std
),
Tensor
(
running_std
))
expect
=
[
0
,
0
]
expect
[
0
]
=
(
dout
*
np
.
reshape
(
batch_std
/
running_std
,
(
co
,
1
,
1
,
1
)))
expect
[
1
]
=
(
np
.
sum
(
dout
*
x
,
(
1
,
2
,
3
))
/
running_std
)
for
i
,
v
in
enumerate
(
output
):
assert
(
np
.
allclose
(
output
[
i
].
asnumpy
(),
expect
[
i
],
rtol
=
1.e-5
,
atol
=
1.e-5
))
This diff is collapsed.
Click to expand it.
tests/st/ops/gpu/test_correction_mul_op.py
0 → 100644
浏览文件 @
652ab6c3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore.common.api
import
ms_function
import
mindspore.context
as
context
context
.
set_context
(
device_target
=
'GPU'
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
P
.
CorrectionMul
()
@
ms_function
def
construct
(
self
,
x
,
batch_var
,
moving_var
):
return
self
.
op
(
x
,
batch_var
,
moving_var
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_correction_mul
():
net
=
Net
()
co
=
64
x
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
co
,
64
,
32
,
32
]).
astype
(
'float32'
)
bv
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
co
]).
astype
(
'float32'
)
mv
=
np
.
random
.
uniform
(
1
,
2
,
size
=
[
co
]).
astype
(
'float32'
)
output
=
net
(
Tensor
(
x
),
Tensor
(
bv
),
Tensor
(
mv
))
expect
=
x
*
np
.
reshape
(
bv
,
(
co
,
1
,
1
,
1
))
/
np
.
reshape
(
mv
,
(
co
,
1
,
1
,
1
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
()
-
expect
assert
np
.
all
(
diff
<
error
)
assert
np
.
all
(
diff
>
error
*
-
1
)
assert
(
output
.
shape
()
==
expect
.
shape
)
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部