Commit 89186edc
Authored Nov 02, 2021 by Megvii Engine Team
fix(dnn): correct reduce/argmxx/fakequant calculation with nan
GitOrigin-RevId: 7e78bdae9106186c5d1a1b8ee2ab337ed69db21b
Parent: 68cdabd2
Showing 9 changed files with 219 additions and 14 deletions (+219 -14)
dnn/src/common/argmxx_helper.h                                +66  -0
dnn/src/common/reduce_helper.h                                +44  -0
dnn/src/cuda/fake_quant/kern.cuh                              +10  -2
dnn/src/naive/argmxx/opr_impl.cpp                              +6  -2
dnn/src/naive/reduce/opr_impl.cpp                             +18  -8
dnn/test/cuda/fake_quant.cpp                                  +51  -1
dnn/test/cuda/reduce.cpp                                      +14  -0
imperative/python/test/unit/functional/test_math.py            +5  -1
imperative/python/test/unit/quantization/test_fake_quant.py    +5  -0
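The three operator families are all fixed with the same convention: a NaN that takes part in a min/max-style comparison must propagate to the result instead of being silently dropped, since every IEEE-754 comparison involving NaN evaluates to false. A minimal standalone sketch of the combine rule the patch adopts (illustrative code, not the MegDNN API):

#include <cassert>
#include <cmath>
#include <limits>

// NaN-propagating min, mirroring the patched MinOp<..., dt_float32>::apply.
static float min_combine(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
}

int main() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    assert(min_combine(1.f, 2.f) == 1.f);
    assert(std::isnan(min_combine(nan, 2.f)));  // NaN lhs short-circuits
    assert(std::isnan(min_combine(1.f, nan)));  // lhs < NaN is false, so the NaN rhs is returned
    return 0;
}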
dnn/src/common/argmxx_helper.h
@@ -78,6 +78,72 @@ struct ArgmxxOp {
     const wtype INIT;
 };
 
+template <bool is_max>
+struct ArgmxxOp<dt_float32, is_max> {
+    using stype_ = dt_float32;
+    struct wtype {
+        stype_ key;
+        dt_int32 val;
+        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val) : key(key), val(val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
+            this->key = rhs.key;
+            this->val = rhs.val;
+            return *this;
+        }
+    };
+    MEGDNN_HOST MEGDNN_DEVICE ArgmxxOp(
+            stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
+            : src(src),
+              dst(dst),
+              A(A),
+              B(B),
+              C(C),
+              INIT(wtype(
+                      is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
+                      0)) {}
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        wtype res;
+        res.key = src[idx];
+        res.val = idx / C % B;
+        return res;
+    }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
+        dst[idx] = val.val;
+    }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        if (isnan(lhs.key))
+#else
+        if (std::isnan(lhs.key))
+#endif
+            return lhs;
+        if (is_max) {
+            if (lhs.key > rhs.key)
+                return lhs;
+            else
+                return rhs;
+        } else {
+            if (lhs.key < rhs.key)
+                return lhs;
+            else
+                return rhs;
+        }
+    }
+    stype_* src;
+    dt_int32* dst;
+    uint32_t A, B, C;
+    const wtype INIT;
+};
+
 }  // namespace argmxx
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
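The specialization reduces (key, index) pairs, so once a NaN key is present, the reported argmax/argmin index is the NaN's position. A hedged sequential sketch of the same combine rule (the kv struct and the scan loop are illustrative, not MegDNN code):

#include <cassert>
#include <cmath>
#include <limits>

// (key, index) pair, standing in for ArgmxxOp<dt_float32, true>::wtype.
struct kv {
    float key;
    int val;  // index of key in the source array
};

static kv argmax_combine(kv lhs, kv rhs) {
    if (std::isnan(lhs.key))
        return lhs;  // a NaN already in the accumulator sticks
    return lhs.key > rhs.key ? lhs : rhs;  // a NaN rhs also wins: lhs.key > NaN is false
}

int main() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    const float src[] = {1.f, 8.f, nan, 4.f};
    kv best{src[0], 0};
    for (int i = 1; i < 4; ++i)
        best = argmax_combine(best, kv{src[i], i});
    assert(best.val == 2 && std::isnan(best.key));  // the NaN's index is reported
    return 0;
}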
dnn/src/common/reduce_helper.h
@@ -119,6 +119,28 @@ struct MinOp {
             : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MinOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct MaxOp {
     typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MaxOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
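These specializations feed the device-side reduction, which typically combines elements in tree order rather than in one long sequential fold; a useful property of the rule above is that a NaN survives every combine it takes part in, so the answer does not depend on reduction order. A small sketch under that assumption:

#include <cassert>
#include <cmath>

// Same rule as the patched MaxOp<..., dt_float32>::apply, host-side.
static float max_combine(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
}

int main() {
    const float nan = std::nanf("");
    const float a[] = {0.5f, nan, -1.f, 2.f};
    // Sequential fold: the NaN survives to the end.
    float acc = a[0];
    for (int i = 1; i < 4; ++i)
        acc = max_combine(acc, a[i]);
    assert(std::isnan(acc));
    // Pairwise (tree) order gives the same result.
    assert(std::isnan(max_combine(max_combine(a[0], a[1]), max_combine(a[2], a[3]))));
    return 0;
}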
dnn/src/cuda/fake_quant/kern.cuh
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
+        if (isnan(x)) {
+            output[idx] = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output[idx] = (x - zero_point) * scale;
     }
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
-        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
+        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
     }
 
 #if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
     __device__ void operator()(
             uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
+        if (isnan(x)) {
+            output = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output = (x - zero_point) * scale;
     }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
             uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
             ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
-        grad = x <= qmax && x >= qmin ? diff : 0.0;
+        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
     }
 
 #if MEGDNN_CC_HOST
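The early return matters because NaN would not survive the clamp on its own: fminf and fmaxf return the non-NaN operand, so round(NaN / scale) + zero_point would be clamped to a finite qmax and the NaN would silently disappear. A host-side sketch of the patched forward path (the qmin/qmax/scale/zero_point values here are illustrative, not taken from a real QParams):

#include <cassert>
#include <cmath>

static float fake_quant_fwd(float input, float scale, float zero_point, float qmin, float qmax) {
    float x = std::round(input / scale) + zero_point;
    if (std::isnan(x))
        return NAN;  // without this branch, the clamp below turns NaN into qmax
    x = std::fmax(std::fmin(x, qmax), qmin);
    return (x - zero_point) * scale;
}

int main() {
    assert(std::isnan(fake_quant_fwd(NAN, 4.f, 1.f, -128.f, 127.f)));
    assert(fake_quant_fwd(8.f, 4.f, 1.f, -128.f, 127.f) == 8.f);  // round(2) + 1 = 3 -> (3 - 1) * 4
    return 0;
}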
dnn/src/naive/argmxx/opr_impl.cpp
@@ -26,14 +26,18 @@ struct traits;
 template <>
 struct traits<true> {
     static const float init;
-    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs > rhs;
+    }
 };
 
 const float traits<true>::init = std::numeric_limits<float>::lowest();
 
 template <>
 struct traits<false> {
     static const float init;
-    static float better_than(float lhs, float rhs) { return lhs < rhs; }
+    static float better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs < rhs;
+    }
 };
 
 const float traits<false>::init = std::numeric_limits<float>::max();
dnn/src/naive/reduce/opr_impl.cpp
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
 template <typename ctype>
 struct Trait<Mode::MIN, ctype> {
-    static const ctype INIT;
-
     static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();
+template <>
+struct Trait<Mode::MIN, dt_float32> {
+    using ctype = dt_float32;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <typename ctype>
 struct Trait<Mode::MAX, ctype> {
-    static const ctype INIT;
-
     static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();
+template <>
+struct Trait<Mode::MAX, dt_float32> {
+    using ctype = dt_float32;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <Mode mode, typename ctype>
 void reduce_fwd(
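The generic Trait kept for the other dtypes shows exactly why the specialization is needed: with the plain comparison, once the running value is NaN, x < y is false and the next element simply replaces it. A two-assert demonstration contrasting the two rules:

#include <cassert>
#include <cmath>

// The unspecialized rule, as in the generic Trait<Mode::MIN, ctype> above.
static float min_generic(float lhs, float rhs) {
    return lhs < rhs ? lhs : rhs;
}

// The dt_float32 rule added by this commit.
static float min_nan_aware(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
}

int main() {
    const float nan = std::nanf("");
    assert(min_generic(nan, 2.f) == 2.f);         // the NaN accumulator is silently dropped
    assert(std::isnan(min_nan_aware(nan, 2.f)));  // the fix keeps it
    return 0;
}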
dnn/test/cuda/fake_quant.cpp
@@ -21,7 +21,9 @@ using namespace fake_quant;
 TEST_F(CUDA, FAKE_QUANT) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
-    std::unique_ptr<RNG> rng;
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                 .set_dtype(2, dtype)
                 .set_dtype(3, dtype)
                 .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -53,12 +66,25 @@ TEST_F(CUDA, FAKE_QUANT) {
                  {scale_shape, dtype::Float32()},
                  {zeropoint_shape, dtype::Float32()},
                  ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
 
 TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 .set_dtype(4, dtype)
                 .execs(TensorShapeArray{
                         ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs(TensorShapeArray{
+                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                  {scale_shape, dtype::Float32()},
                  {zeropoint_shape, dtype::Float32()},
                  ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout, ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
dnn/test/cuda/reduce.cpp
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
     // very large reduce
     checker.execs({{1, 4194304, 1}, {}});
 
+    // inputs have nan
+    {
+        const auto nan = std::numeric_limits<float>::quiet_NaN();
+        UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
+        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
+        for (auto mode : {Mode::MIN, Mode::MAX}) {
+            checker.set_param({mode, 1});
+            checker.execs({{2, 64, 32}, {}});
+        }
+        checker.set_allow_invalid_check(false);
+    }
+    checker.set_rng(0, &rng);
+
     auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                      Reduce::DataType data_type) {
         for (int32_t axis : {0, 1, 2, 3}) {
imperative/python/test/unit/functional/test_math.py
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
     data2_shape = (2, 9, 12)
     data1 = np.random.random(data1_shape).astype(np.float32)
     data2 = np.random.random(data2_shape).astype(np.float32)
-    cases = [{"input": data1}, {"input": data2}]
+    cases = [
+        {"input": data1},
+        {"input": data2},
+        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
+    ]
 
     if opr not in (F.argmin, F.argmax):
         # test default axis
imperative/python/test/unit/quantization/test_fake_quant.py
@@ -143,6 +143,11 @@ def test_fakequant():
         assert np.allclose(x.grad.numpy(), x1.grad.numpy())
         assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
 
+        # test nan
+        x = F.full((1, 32, 3, 3), np.nan)
+        y = fake_quant_tensor(x, qparams).numpy()
+        assert np.isnan(y).all()
+
     zero_point = tensor([1.0], dtype=np.float32)
     scale = tensor([4.0], dtype=np.float32)
     run(zero_point, scale)