magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 02584fe2, authored June 20, 2020 by wangdongxu

    fix perchannel num_channels not set bug and adjust quant.py params order

Parent: bd0c5384
Showing 10 changed files with 336 additions and 473 deletions (+336, -473)
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu                  +9   -26
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh                 +2   -3
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc     +2   -25
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h      +0   -5
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc       +2   -24
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h        +0   -5
mindspore/nn/layer/quant.py                                                 +138 -144
mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py               +36  -45
mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py                 +25  -42
mindspore/ops/operations/_quant_ops.py                                      +122 -154
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu (+9, -26)

```diff
@@ -23,35 +23,24 @@
 #include "device/gpu/cuda_common.h"

 __global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
-                                                 float *output_max, const float min, const float max,
-                                                 const float decay, const float symmetric) {
+                                                 float *output_max, const float min, const float max,
+                                                 const float decay) {
   output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
   output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
   output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
   output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
-  if (symmetric) {
-    output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
-    output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
-  }
   return;
 }

-__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
-                                          const float symmetric) {
+__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) {
   output_min[0] = min > 0 ? 0 : min;
   output_max[0] = max < 0 ? 0 : max;
-  if (symmetric) {
-    output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
-    output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
-  }
   return;
 }

 __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
-                                            float *output_max, int channels, int per_channel_nums, bool ema,
-                                            float ema_decay, bool symmetric) {
+                                            float *output_max, int channels, int per_channel_nums, bool ema,
+                                            float ema_decay) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
     thrust::pair<float *, float *> sum =
       thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
@@ -64,27 +53,21 @@ __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, floa
     }
     output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
     output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
-    if (symmetric) {
-      output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
-      output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
-    }
   }
   return;
 }

 void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
-                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
-                         const bool symmetric, cudaStream_t cuda_stream) {
+                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
+                         cudaStream_t cuda_stream) {
   int per_channel_num = total_num / channel_num;
   UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
+    input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay);
   return;
 }

 void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
-                       const int total_num, const float ema_decay, const bool ema, const bool symmetric,
-                       cudaStream_t cuda_stream) {
+                       const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
   float minel = 0.f;
   float maxel = 0.f;
   auto policy = thrust::cuda::par.on(cuda_stream);
@@ -96,9 +79,9 @@ void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *
   if (ema) {
     UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
-                                                               maxel, ema_decay, symmetric);
+                                                               maxel, ema_decay);
   } else {
-    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
+    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel);
   }
   return;
 }
```
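For orientation, the per-layer path above (CalMinMaxPerLayer plus the two kernels) reduces the whole tensor to one scalar min and max and, when ema is enabled, blends them with the running values before clamping the range so it always contains zero. A minimal NumPy sketch of that logic, as an illustration only (it mirrors the kernels shown above rather than reproducing the GPU code exactly):

```python
import numpy as np

def per_layer_minmax_update(x, running_min, running_max, ema=False, ema_decay=0.999):
    """Illustrative CPU version of the per-layer min/max update (not the CUDA path)."""
    cur_min, cur_max = float(x.min()), float(x.max())
    if ema:
        # Same blend as UpdateInputMinMaxPerLayerWithEMA: decay weights the current batch.
        new_min = ema_decay * cur_min + (1 - ema_decay) * float(running_min)
        new_max = ema_decay * cur_max + (1 - ema_decay) * float(running_max)
    else:
        new_min, new_max = cur_min, cur_max
    # Keep zero inside the quantization range.
    return min(new_min, 0.0), max(new_max, 0.0)
```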
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh (+2, -3)

```diff
@@ -21,10 +21,9 @@
 void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
-                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
-                         const bool symmetric, cudaStream_t cuda_stream);
+                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
+                         cudaStream_t cuda_stream);

 void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
-                       const int size, const float ema_decay, const bool ema, const bool symmetric,
-                       cudaStream_t cuda_stream);
+                       const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
```
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc (+2, -25)

```diff
@@ -24,16 +24,7 @@
 namespace mindspore {
 namespace kernel {
 MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
-    : input_size_(0),
-      num_bits_(0),
-      quant_min_(0),
-      quant_max_(0),
-      quant_num_(1),
-      ema_(false),
-      ema_decay_(0),
-      num_channels_(0),
-      narrow_range_(false),
-      symmetric_(false) {}
+    : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {}

 const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
@@ -54,22 +45,8 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
   }

-  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
   ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
   ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
-  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
-
-  if (num_bits_ <= 2 || num_bits_ >= 16) {
-    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
-  }
-
-  // quant min and max
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
-  if (narrow_range_) {
-    quant_min_++;
-  }
-
   // init size
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@@ -110,7 +87,7 @@ bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inpu
   // calculate the input min and max according by the parameter ema and ema_decay.
-  CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
-                      ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
+                      ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
   return true;
 }
```
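The launch above is where the corrected num_channels_ value matters: CalMinMaxPerChannel divides the flattened input into channel_num contiguous slices of total_num / channel_num elements and computes one (min, max) pair per slice. A rough NumPy sketch of that partitioning, assuming the channel dimension is outermost in the flattened buffer (EMA blending and the CUDA specifics are omitted):

```python
import numpy as np

def per_channel_minmax(x_flat, channel_num):
    """Illustrative sketch of the per-channel split done by CalMinMaxPerChannel."""
    per_channel_num = x_flat.size // channel_num
    x = x_flat.reshape(channel_num, per_channel_num)
    # One range per channel, clamped so every range contains zero.
    ch_min = np.minimum(x.min(axis=1), 0.0)
    ch_max = np.maximum(x.max(axis=1), 0.0)
    return ch_min, ch_max
```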
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h (+0, -5)

```diff
@@ -44,15 +44,10 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;

-  int num_bits_;
-  float quant_min_;
-  float quant_max_;
   int quant_num_;
   bool ema_;
   float ema_decay_;
   int num_channels_;
-  bool narrow_range_;
-  bool symmetric_;
 };
 }  // namespace kernel
 }  // namespace mindspore
```
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc (+2, -24)

```diff
@@ -24,15 +24,7 @@
 namespace mindspore {
 namespace kernel {
 MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
-    : input_size_(0),
-      num_bits_(0),
-      quant_min_(0),
-      quant_max_(0),
-      quant_num_(1),
-      ema_(false),
-      ema_decay_(0),
-      narrow_range_(false),
-      symmetric_(false) {}
+    : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {}

 const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
@@ -51,22 +43,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
   }

-  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
   ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
   ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
-  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
-
-  if (num_bits_ <= 2 || num_bits_ >= 16) {
-    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
-  }
-
-  // quant min and max
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
-  if (narrow_range_) {
-    quant_min_++;
-  }
-
   // init size
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@@ -104,7 +82,7 @@ bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs
     MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
   }

-  CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
-                    reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
   return true;
```
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h (+0, -5)

```diff
@@ -44,14 +44,9 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;

-  int num_bits_;
-  float quant_min_;
-  float quant_max_;
   int quant_num_;
   bool ema_;
   float ema_decay_;
-  bool narrow_range_;
-  bool symmetric_;
 };
 }  // namespace kernel
 }  // namespace mindspore
```
mindspore/nn/layer/quant.py (+138, -144)

This diff is collapsed on the original page and is not shown here.
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py → mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py (renamed, +36, -45)

```diff
 # Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,20 +21,15 @@ from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \
+minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_min_max_per_channel_update.so") \
+    .binfile_name("minmax_update_perchannel.so") \
     .compute_cost(10) \
-    .kernel_name("fake_quant_min_max_per_channel_update") \
+    .kernel_name("minmax_update_perchannel") \
     .partial_flag(True) \
     .attr("ema", "optional", "bool", "all") \
     .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "min", None, "required", None) \
@@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel
     .get_op_info()


-@op_info_register(fake_quant_min_max_per_channel_update_op_info)
-def _fake_quant_min_max_per_channel_update_tbe():
-    """FakeQuantPerChannelUpdate TBE register"""
+@op_info_register(minmax_update_perchannel_op_info)
+def _minmax_update_perchannel_tbe():
+    """MinMaxUpdatePerChannel TBE register"""
     return


-@fusion_manager.register("fake_quant_min_max_per_channel_update")
-def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max,
-                                                  training, channel_axis,
-                                                  kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerChannelUpdate compute"""
+@fusion_manager.register("minmax_update_perchannel")
+def minmax_update_perchannel_compute(x, min_val, max_val, ema, ema_decay, channel_axis):
+    """MinMaxUpdatePerChannel compute"""
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)

     if not ema:
         ema_decay = 0.0
-    if training:
-        # CalMinMax
-        x_min = te.lang.cce.reduce_min(x, axis=axis)
-        x_max = te.lang.cce.reduce_max(x, axis=axis)
-        x_min = te.lang.cce.broadcast(x_min, shape_min)
-        x_max = te.lang.cce.broadcast(x_max, shape_min)
-        min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay),
-                                   te.lang.cce.vmuls(x_min, (1 - ema_decay)))
-        max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay),
-                                   te.lang.cce.vmuls(x_max, (1 - ema_decay)))
-        min_val = te.lang.cce.vmins(min_val, 0)
-        max_val = te.lang.cce.vmaxs(max_val, 0)
+
+    # CalMinMax
+    if channel_axis == 0:
+        axis = [1, 2, 3, 4]
+    else:
+        axis = [0, 2, 3]
+    x_min = te.lang.cce.reduce_min(x, axis=axis)
+    x_max = te.lang.cce.reduce_max(x, axis=axis)
+    x_min = te.lang.cce.broadcast(x_min, shape_min)
+    x_max = te.lang.cce.broadcast(x_max, shape_min)
+    min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay),
+                               te.lang.cce.vmuls(x_min, (1 - ema_decay)))
+    max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay),
+                               te.lang.cce.vmuls(x_max, (1 - ema_decay)))
+    min_val = te.lang.cce.vmins(min_val, 0)
+    max_val = te.lang.cce.vmaxs(max_val, 0)

     return [min_val, max_val]


-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
-def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
-                                          ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
-                                          kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
+def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
+                             ema, ema_decay, channel_axis,
+                             kernel_name="minmax_update_perchannel"):
+    """MinMaxUpdatePerChannel op"""
     x_shape = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
@@ -108,21 +105,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)

-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
+    if channel_axis == 0:
+        shape_c = min_val.get("ori_shape")
     else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
-
-    shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
+        shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
     input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
-    res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
-                                                             ema, ema_decay, quant_min, quant_max,
-                                                             training, channel_axis, kernel_name)
+    res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
+                                                ema, ema_decay, channel_axis)

     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
```
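The renamed per-channel compute now picks its reduction axes from channel_axis alone: with the channel on axis 0 it reduces axes [1, 2, 3, 4], otherwise it reduces [0, 2, 3], so only the channel dimension survives. A hedged NumPy illustration of that axis choice (the example shape is an assumption, not taken from the diff):

```python
import numpy as np

def channel_minmax(x, channel_axis):
    """Reduce every axis except the channel axis, mirroring the axis lists in
    minmax_update_perchannel_compute (illustration only)."""
    axis = (1, 2, 3, 4) if channel_axis == 0 else (0, 2, 3)
    return x.min(axis=axis), x.max(axis=axis)

# Assumed NCHW activation: channel_axis=1 leaves one (min, max) pair per channel.
x = np.random.rand(8, 16, 5, 5).astype(np.float32)
ch_min, ch_max = channel_minmax(x, channel_axis=1)
assert ch_min.shape == (16,)
```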
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py → mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py (renamed, +25, -42)

```diff
@@ -22,20 +22,15 @@ from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
+minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_minmax_update.so") \
+    .binfile_name("minmax_update_perlayer.so") \
     .compute_cost(10) \
-    .kernel_name("fake_quant_minmax_update") \
+    .kernel_name("minmax_update_perlayer") \
     .partial_flag(True) \
     .attr("ema", "optional", "bool", "all") \
     .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "min", None, "required", None) \
     .input(2, "max", None, "required", None) \
@@ -46,15 +41,14 @@ fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
     .get_op_info()


-@op_info_register(fake_quant_minmax_update_op_info)
-def _fake_quant_minmax_update_tbe():
+@op_info_register(minmax_update_perlayer_op_info)
+def _minmax_update_perlayer_tbe():
     """MinMaxUpdatePerLayer TBE register"""
     return


-@fusion_manager.register("fake_quant_minmax_update")
-def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
-                                     kernel_name="fake_quant_minmax_update"):
+@fusion_manager.register("minmax_update_perlayer")
+def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
     """MinMaxUpdatePerLayer compute"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
@@ -62,28 +56,27 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
     max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)

     if not ema:
         ema_decay = 0.0
-    if training:
-        # CalMinMax
-        axis = tuple(range(len(shape)))
-        x_min = te.lang.cce.reduce_min(x, axis=axis)
-        x_max = te.lang.cce.reduce_max(x, axis=axis)
-        x_min = te.lang.cce.broadcast(x_min, shape_min)
-        x_max = te.lang.cce.broadcast(x_max, shape_min)
-        min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay),
-                                   te.lang.cce.vmuls(x_min, (1 - ema_decay)))
-        max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay),
-                                   te.lang.cce.vmuls(x_max, (1 - ema_decay)))
-        min_val = te.lang.cce.vmins(min_val, 0)
-        max_val = te.lang.cce.vmaxs(max_val, 0)
+
+    # CalMinMax
+    axis = tuple(range(len(shape)))
+    x_min = te.lang.cce.reduce_min(x, axis=axis)
+    x_max = te.lang.cce.reduce_max(x, axis=axis)
+    x_min = te.lang.cce.broadcast(x_min, shape_min)
+    x_max = te.lang.cce.broadcast(x_max, shape_min)
+    min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay),
+                               te.lang.cce.vmuls(x_min, (1 - ema_decay)))
+    max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay),
+                               te.lang.cce.vmuls(x_max, (1 - ema_decay)))
+    min_val = te.lang.cce.vmins(min_val, 0)
+    max_val = te.lang.cce.vmaxs(max_val, 0)

     return [min_val, max_val]


-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
-def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
-                             ema, ema_decay, symmetric, narrow_range, training, num_bits,
-                             kernel_name="fake_quant_minmax_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
+def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
+                           ema, ema_decay,
+                           kernel_name="minmax_update_perlayer"):
+    """MinMaxUpdatePerLayer op"""
     input_shape = x.get("shape")
     input_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
-
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
-    res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
-                                                ema, ema_decay, quant_min, quant_max, training, kernel_name)
+    res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)

     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
```
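One detail worth calling out in the per-layer compute: when ema is False the function simply forces ema_decay to 0.0, so the blend decay * running + (1 - decay) * current collapses to the current batch statistics. A tiny worked example of that degenerate case (values are made up):

```python
# With ema disabled the compute sets ema_decay = 0.0, so the running value drops out.
running_min, batch_min = -4.0, -2.5   # assumed example values
ema_decay = 0.0
new_min = ema_decay * running_min + (1 - ema_decay) * batch_min
assert new_min == batch_min           # -2.5: only the current batch min remains
```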
mindspore/ops/operations/_quant_ops.py (+122, -154)

```diff
@@ -21,12 +21,12 @@ from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ...common import dtype as mstype

-__all__ = ["FakeQuantPerLayer",
+__all__ = ["MinMaxUpdatePerLayer",
+           "MinMaxUpdatePerChannel",
+           "FakeQuantPerLayer",
            "FakeQuantPerLayerGrad",
            "FakeQuantPerChannel",
            "FakeQuantPerChannelGrad",
-           "MinMaxUpdatePerLayer",
-           "MinMaxUpdatePerChannel",
            "BatchNormFold",
            "BatchNormFoldGrad",
            "CorrectionMul",
@@ -38,10 +38,128 @@ __all__ = ["FakeQuantPerLayer",
            "BatchNormFoldGradD",
            "BatchNormFold2_D",
            "BatchNormFold2GradD",
            "BatchNormFold2GradReduce"]
+
+
+class MinMaxUpdatePerLayer(PrimitiveWithInfer):
+    r"""
+    Update min and max per layer.
+
+    Args:
+        ema (bool): Use EMA algorithm update value min and max. Default: False.
+        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
+
+    Inputs:
+        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
+        - **min** (Tensor) : Value of the min range of the input data x.
+        - **max** (Tensor) : Value of the max range of the input data x.
+
+    Outputs:
+        - Tensor: Simulate quantize tensor of x.
+
+    Examples:
+        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+        >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
+    """
+    support_quant_bit = [4, 7, 8]
+
+    @prim_attr_register
+    def __init__(self, ema=False, ema_decay=0.999):
+        """init FakeQuantMinMaxPerLayerUpdate OP"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
+        if ema and not ema_decay:
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        return min_shape, max_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        return min_type, max_type
+
+
+class MinMaxUpdatePerChannel(PrimitiveWithInfer):
+    r"""
+    Update min and max per channel.
+
+    Args:
+        ema (bool): Use EMA algorithm update value min and max. Default: False.
+        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
+        channel_axis (int): Channel asis for per channel compute. Default: 1.
+
+    Inputs:
+        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
+        - **min** (Tensor) : Value of the min range of the input data x.
+        - **max** (Tensor) : Value of the max range of the input data x.
+
+    Outputs:
+        - Tensor: Simulate quantize tensor of x.
+
+    Examples:
+        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
+    """
+    support_quant_bit = [4, 7, 8]
+
+    @prim_attr_register
+    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
+        """init FakeQuantPerChannelUpdate OP for Ascend"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
+        if ema and not ema_decay:
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.channel_axis = validator.check_integer('channel axis', channel_axis, 0, Rel.GE, self.name)
+        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
+        return min_shape, max_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
+        return min_type, max_type
+
+
 class FakeQuantPerLayer(PrimitiveWithInfer):
     r"""
     Simulate the quantize and dequantize operations in training time.
@@ -832,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
     def infer_dtype(self, dout_type, x_type):
         validator.check("dout type", dout_type, "x type", x_type)
         return dout_type, dout_type
-
-
-class MinMaxUpdatePerLayer(PrimitiveWithInfer):
-    r"""
-    Update min and max value for fake quant per layer op.
-
-    Args:
-        num_bits (int) : Number bits for quantization aware. Default: 8.
-        ema (bool): Use EMA algorithm update value min and max. Default: False.
-        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Training the network or not. Default: True.
-
-    Inputs:
-        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
-        - **min** (Tensor) : Value of the min range of the input data x.
-        - **max** (Tensor) : Value of the max range of the input data x.
-
-    Outputs:
-        - Tensor: Simulate quantize tensor of x.
-
-    Examples:
-        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
-        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
-        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
-        >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
-    """
-    support_quant_bit = [4, 7, 8]
-
-    @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
-                 training=True):
-        """init MinMaxUpdatePerLayer OP"""
-        if context.get_context('device_target') == "Ascend":
-            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
-        if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
-        if ema and not ema_decay:
-            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type('training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
-
-    def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
-        return min_shape, max_shape
-
-    def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
-        return min_type, max_type
-
-
-class MinMaxUpdatePerChannel(PrimitiveWithInfer):
-    r"""
-    Update min and max value for fake quant per layer op.
-
-    Args:
-        num_bits (int) : Number bits for quantization aware. Default: 8.
-        ema (bool): Use EMA algorithm update value min and max. Default: False.
-        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Training the network or not. Default: True.
-        channel_axis (int): Channel asis for per channel compute. Default: 1.
-
-    Inputs:
-        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
-        - **min** (Tensor) : Value of the min range of the input data x.
-        - **max** (Tensor) : Value of the max range of the input data x.
-
-    Outputs:
-        - Tensor: Simulate quantize tensor of x.
-
-    Examples:
-        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
-        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
-        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
-        >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
-    """
-    support_quant_bit = [4, 7, 8]
-
-    @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
-                 training=True, channel_axis=1):
-        """init MinMaxUpdatePerChannel OP for Ascend"""
-        if context.get_context('device_target') == "Ascend":
-            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
-        if num_bits not in self.support_quant_bit:
-            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
-        if ema and not ema_decay:
-            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type('training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
-        self.channel_axis = validator.check_integer('channel axis', channel_axis, 0, Rel.GE, self.name)
-        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
-
-    def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
-        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
-        return min_shape, max_shape
-
-    def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
-        return min_type, max_type
```
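Taken together, the new primitives are constructed with only ema/ema_decay (plus channel_axis for the per-channel op) and return the updated (min_up, max_up) pair. A hedged usage sketch based on the signatures above; it has not been verified against a running MindSpore build, and note that the docstring examples in the diff still show the older num_bits keyword:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _quant_ops as Q

x = Tensor(np.random.rand(3, 16, 5, 5).astype(np.float32), mstype.float32)

# Per-layer: one scalar range for the whole tensor.
min_t = Tensor(np.array([-6.0], np.float32), mstype.float32)
max_t = Tensor(np.array([6.0], np.float32), mstype.float32)
min_up, max_up = Q.MinMaxUpdatePerLayer(ema=True, ema_decay=0.999)(x, min_t, max_t)

# Per-channel: one (min, max) pair per channel along channel_axis.
ch_min = Tensor(np.random.uniform(-1, 1, 16).astype(np.float32), mstype.float32)
ch_max = Tensor(np.random.uniform(-1, 1, 16).astype(np.float32), mstype.float32)
ch_min_up, ch_max_up = Q.MinMaxUpdatePerChannel(ema=True, ema_decay=0.999, channel_axis=1)(x, ch_min, ch_max)
```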