Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
59c7aea5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
59c7aea5
编写于
2月 10, 2022
作者:
Z
zhangbo9674
提交者:
GitHub
2月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[bf16] add bf16 kernel: squeeze & unsqueeze & stack (#39402)
* add squeeze unsqueeze stack * add unittest * add cpu kernel
上级
e8ac7fc3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
153 addition
and
23 deletion
+153
-23
paddle/fluid/operators/squeeze_op.cc
paddle/fluid/operators/squeeze_op.cc
+12
-4
paddle/fluid/operators/squeeze_op.cu.cc
paddle/fluid/operators/squeeze_op.cu.cc
+5
-0
paddle/fluid/operators/stack_op.cc
paddle/fluid/operators/stack_op.cc
+13
-10
paddle/fluid/operators/stack_op.cu
paddle/fluid/operators/stack_op.cu
+4
-2
paddle/fluid/operators/unsqueeze_op.cc
paddle/fluid/operators/unsqueeze_op.cc
+12
-4
paddle/fluid/operators/unsqueeze_op.cu.cc
paddle/fluid/operators/unsqueeze_op.cu.cc
+6
-0
python/paddle/fluid/tests/unittests/test_squeeze_op.py
python/paddle/fluid/tests/unittests/test_squeeze_op.py
+28
-1
python/paddle/fluid/tests/unittests/test_stack_op.py
python/paddle/fluid/tests/unittests/test_stack_op.py
+45
-1
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+28
-1
未找到文件。
paddle/fluid/operators/squeeze_op.cc
浏览文件 @
59c7aea5
...
...
@@ -393,7 +393,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
SqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
squeeze_grad
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
@@ -406,7 +408,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
squeeze2
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
@@ -419,7 +423,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
squeeze2_grad
,
...
...
@@ -433,4 +439,6 @@ REGISTER_OP_CPU_KERNEL(
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
paddle/fluid/operators/squeeze_op.cu.cc
浏览文件 @
59c7aea5
...
...
@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
...
...
@@ -35,6 +36,7 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
...
...
@@ -48,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze2
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
Squeeze2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
int8_t
>
,
...
...
@@ -62,6 +65,8 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
Squeeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int8_t
>
,
...
...
paddle/fluid/operators/stack_op.cc
浏览文件 @
59c7aea5
...
...
@@ -173,13 +173,16 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops
::
StackGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
);
REGISTER_OP_CPU_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
int
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
int
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
int
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
int
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
paddle/fluid/operators/stack_op.cu
浏览文件 @
59c7aea5
...
...
@@ -196,10 +196,12 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL
(
stack
,
ops
::
StackGPUKernel
<
float
>
,
ops
::
StackGPUKernel
<
double
>
,
ops
::
StackGPUKernel
<
int
>
,
ops
::
StackGPUKernel
<
int64_t
>
,
ops
::
StackGPUKernel
<
plat
::
float16
>
);
ops
::
StackGPUKernel
<
plat
::
float16
>
,
ops
::
StackGPUKernel
<
plat
::
bfloat16
>
);
REGISTER_OP_CUDA_KERNEL
(
stack_grad
,
ops
::
StackGradGPUKernel
<
float
>
,
ops
::
StackGradGPUKernel
<
double
>
,
ops
::
StackGradGPUKernel
<
int
>
,
ops
::
StackGradGPUKernel
<
int64_t
>
,
ops
::
StackGradGPUKernel
<
plat
::
float16
>
);
ops
::
StackGradGPUKernel
<
plat
::
float16
>
,
ops
::
StackGradGPUKernel
<
plat
::
bfloat16
>
);
paddle/fluid/operators/unsqueeze_op.cc
浏览文件 @
59c7aea5
...
...
@@ -366,7 +366,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
unsqueeze_grad
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
@@ -379,7 +381,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
unsqueeze2
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
...
...
@@ -391,7 +395,9 @@ REGISTER_OP_CPU_KERNEL(
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
unsqueeze2_grad
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
@@ -404,4 +410,6 @@ REGISTER_OP_CPU_KERNEL(
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
paddle/fluid/operators/unsqueeze_op.cu.cc
浏览文件 @
59c7aea5
...
...
@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
unsqueeze
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
...
...
@@ -36,6 +37,8 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int8_t
>
,
...
...
@@ -50,6 +53,7 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
...
...
@@ -65,6 +69,8 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
Unsqueeze2GradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
...
...
python/paddle/fluid/tests/unittests/test_squeeze_op.py
浏览文件 @
59c7aea5
...
...
@@ -20,7 +20,8 @@ import numpy as np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
,
Program
,
program_guard
from
op_test
import
OpTest
from
op_test
import
OpTest
,
convert_float_to_uint16
import
paddle.fluid.core
as
core
paddle
.
enable_static
()
...
...
@@ -49,6 +50,32 @@ class TestSqueezeOp(OpTest):
self
.
attrs
=
{
"axes"
:
self
.
axes
}
class
TestSqueezeBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"squeeze"
self
.
dtype
=
np
.
uint16
self
.
init_test_case
()
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)
out
=
x
.
reshape
(
self
.
new_shape
)
self
.
inputs
=
{
"X"
:
convert_float_to_uint16
(
x
)}
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
convert_float_to_uint16
(
out
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
def
init_test_case
(
self
):
self
.
ori_shape
=
(
1
,
3
,
1
,
40
)
self
.
axes
=
(
0
,
2
)
self
.
new_shape
=
(
3
,
40
)
def
init_attrs
(
self
):
self
.
attrs
=
{
"axes"
:
self
.
axes
}
# Correct: There is mins axis.
class
TestSqueezeOp1
(
TestSqueezeOp
):
def
init_test_case
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_stack_op.py
浏览文件 @
59c7aea5
...
...
@@ -16,7 +16,8 @@ import numpy as np
import
unittest
import
paddle
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
from
op_test
import
OpTest
,
convert_float_to_uint16
import
paddle.fluid.core
as
core
class
TestStackOpBase
(
OpTest
):
...
...
@@ -90,6 +91,49 @@ class TestStackOp6(TestStackOpBase):
self
.
axis
=
3
class
TestStackBF16Op
(
OpTest
):
def
initDefaultParameters
(
self
):
self
.
num_inputs
=
4
self
.
input_dim
=
(
5
,
6
,
7
)
self
.
axis
=
0
self
.
dtype
=
np
.
uint16
def
initParameters
(
self
):
pass
def
get_x_names
(
self
):
x_names
=
[]
for
i
in
range
(
self
.
num_inputs
):
x_names
.
append
(
'x{}'
.
format
(
i
))
return
x_names
def
setUp
(
self
):
self
.
initDefaultParameters
()
self
.
initParameters
()
self
.
op_type
=
'stack'
self
.
x
=
[]
for
i
in
range
(
self
.
num_inputs
):
self
.
x
.
append
(
np
.
random
.
random
(
size
=
self
.
input_dim
).
astype
(
np
.
float32
))
out
=
np
.
stack
(
self
.
x
,
axis
=
self
.
axis
)
tmp
=
[]
x_names
=
self
.
get_x_names
()
for
i
in
range
(
self
.
num_inputs
):
tmp
.
append
((
x_names
[
i
],
convert_float_to_uint16
(
self
.
x
[
i
])))
self
.
inputs
=
{
'X'
:
tmp
}
self
.
outputs
=
{
'Y'
:
convert_float_to_uint16
(
out
)}
self
.
attrs
=
{
'axis'
:
self
.
axis
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
(
self
.
get_x_names
(),
'Y'
)
class
TestStackAPIWithLoDTensorArray
(
unittest
.
TestCase
):
"""
Test stack api when the input(x) is a LoDTensorArray.
...
...
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
浏览文件 @
59c7aea5
...
...
@@ -19,7 +19,8 @@ import numpy as np
import
paddle
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
from
op_test
import
OpTest
,
convert_float_to_uint16
import
paddle.fluid.core
as
core
paddle
.
enable_static
()
...
...
@@ -48,6 +49,32 @@ class TestUnsqueezeOp(OpTest):
self
.
attrs
=
{
"axes"
:
self
.
axes
}
class
TestUnsqueezeBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
init_test_case
()
self
.
op_type
=
"unsqueeze"
self
.
dtype
=
np
.
uint16
x
=
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)
out
=
x
.
reshape
(
self
.
new_shape
)
self
.
inputs
=
{
"X"
:
convert_float_to_uint16
(
x
)}
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
convert_float_to_uint16
(
out
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
def
init_test_case
(
self
):
self
.
ori_shape
=
(
3
,
40
)
self
.
axes
=
(
1
,
2
)
self
.
new_shape
=
(
3
,
1
,
1
,
40
)
def
init_attrs
(
self
):
self
.
attrs
=
{
"axes"
:
self
.
axes
}
# Correct: Single input index.
class
TestUnsqueezeOp1
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录