Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
78add057
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
78add057
编写于
10月 13, 2022
作者:
zhouweiwei2014
提交者:
GitHub
10月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Zero-Dim] support 0D for paddle.transpose/reshape/stack/tile/unsqueeze (#46555)
上级
19438131
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
303 addition
and
2 deletion
+303
-2
paddle/fluid/operators/transpose_op.h
paddle/fluid/operators/transpose_op.h
+4
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+2
-2
paddle/phi/kernels/cpu/transpose_kernel.cc
paddle/phi/kernels/cpu/transpose_kernel.cc
+3
-0
paddle/phi/kernels/gpu/transpose_kernel.cu
paddle/phi/kernels/gpu/transpose_kernel.cu
+4
-0
paddle/phi/kernels/impl/tile_kernel_impl.h
paddle/phi/kernels/impl/tile_kernel_impl.h
+8
-0
python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
...paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
+10
-0
python/paddle/fluid/tests/unittests/npu/test_transpose_op_npu.py
...paddle/fluid/tests/unittests/npu/test_transpose_op_npu.py
+6
-0
python/paddle/fluid/tests/unittests/test_reshape_op.py
python/paddle/fluid/tests/unittests/test_reshape_op.py
+76
-0
python/paddle/fluid/tests/unittests/test_stack_op.py
python/paddle/fluid/tests/unittests/test_stack_op.py
+29
-0
python/paddle/fluid/tests/unittests/test_tile_op.py
python/paddle/fluid/tests/unittests/test_tile_op.py
+51
-0
python/paddle/fluid/tests/unittests/test_transpose_op.py
python/paddle/fluid/tests/unittests/test_transpose_op.py
+25
-0
python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py
+54
-0
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+24
-0
python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py
...paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py
+7
-0
未找到文件。
paddle/fluid/operators/transpose_op.h
浏览文件 @
78add057
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
@@ -32,6 +33,9 @@ inline void TransCompute(const int dim,
...
@@ -32,6 +33,9 @@ inline void TransCompute(const int dim,
phi
::
DenseTensor
*
out
,
phi
::
DenseTensor
*
out
,
const
std
::
vector
<
int
>&
axis
)
{
const
std
::
vector
<
int
>&
axis
)
{
switch
(
dim
)
{
switch
(
dim
)
{
case
0
:
phi
::
Copy
<
DeviceContext
>
(
dev_ctx
,
in
,
dev_ctx
.
GetPlace
(),
false
,
out
);
break
;
case
1
:
case
1
:
phi
::
funcs
::
Transpose
<
DeviceContext
,
T
,
1
>
trans1
;
phi
::
funcs
::
Transpose
<
DeviceContext
,
T
,
1
>
trans1
;
trans1
(
dev_ctx
,
in
,
out
,
axis
);
trans1
(
dev_ctx
,
in
,
out
,
axis
);
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
78add057
...
@@ -3713,7 +3713,7 @@ void TileInferMeta(const MetaTensor& x,
...
@@ -3713,7 +3713,7 @@ void TileInferMeta(const MetaTensor& x,
repeat_times_data
.
size
()));
repeat_times_data
.
size
()));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
repeat_times_data
.
size
(),
repeat_times_data
.
size
(),
1
,
0
,
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The size of the shape of input 'repeat_times' for tile op "
"The size of the shape of input 'repeat_times' for tile op "
"must be positive integers, but the value received is %d."
,
"must be positive integers, but the value received is %d."
,
...
@@ -3746,7 +3746,7 @@ void TileInferMeta(const MetaTensor& x,
...
@@ -3746,7 +3746,7 @@ void TileInferMeta(const MetaTensor& x,
}
}
out
->
set_dims
(
phi
::
make_ddim
(
out_shape
));
out
->
set_dims
(
phi
::
make_ddim
(
out_shape
));
if
(
out_
shape
[
0
]
==
x_dims
[
0
]
)
{
if
(
out_
rank
>
0
&&
(
out_shape
[
0
]
==
x_dims
[
0
])
)
{
out
->
share_lod
(
x
);
out
->
share_lod
(
x
);
}
}
out
->
set_dtype
(
x
.
dtype
());
out
->
set_dtype
(
x
.
dtype
());
...
...
paddle/phi/kernels/cpu/transpose_kernel.cc
浏览文件 @
78add057
...
@@ -35,6 +35,9 @@ void TransposeKernel(const Context& ctx,
...
@@ -35,6 +35,9 @@ void TransposeKernel(const Context& ctx,
}
}
int
rank
=
axis
.
size
();
int
rank
=
axis
.
size
();
switch
(
rank
)
{
switch
(
rank
)
{
case
0
:
phi
::
Copy
<
Context
>
(
ctx
,
x
,
ctx
.
GetPlace
(),
false
,
out
);
break
;
case
1
:
case
1
:
funcs
::
Transpose
<
Context
,
T
,
1
>
trans1
;
funcs
::
Transpose
<
Context
,
T
,
1
>
trans1
;
trans1
(
ctx
,
x
,
out
,
axis
);
trans1
(
ctx
,
x
,
out
,
axis
);
...
...
paddle/phi/kernels/gpu/transpose_kernel.cu
浏览文件 @
78add057
...
@@ -35,6 +35,10 @@ void TransposeKernel(const Context& ctx,
...
@@ -35,6 +35,10 @@ void TransposeKernel(const Context& ctx,
if
(
out
->
numel
()
==
0
)
{
if
(
out
->
numel
()
==
0
)
{
return
;
return
;
}
}
if
(
axis
.
size
()
==
0
)
{
phi
::
Copy
<
Context
>
(
ctx
,
x
,
ctx
.
GetPlace
(),
false
,
out
);
return
;
}
paddle
::
operators
::
TransposeGPUKernelDriver
<
T
>
(
ctx
,
x
,
axis
,
out
);
paddle
::
operators
::
TransposeGPUKernelDriver
<
T
>
(
ctx
,
x
,
axis
,
out
);
}
}
...
...
paddle/phi/kernels/impl/tile_kernel_impl.h
浏览文件 @
78add057
...
@@ -54,6 +54,10 @@ void Tile(const Context& dev_ctx,
...
@@ -54,6 +54,10 @@ void Tile(const Context& dev_ctx,
vec_x_dims
.
size
(),
vec_x_dims
.
size
(),
repeat_times
.
size
()));
repeat_times
.
size
()));
if
(
Rank
==
0
)
{
phi
::
Copy
<
DeviceContext
>
(
dev_ctx
,
x
,
dev_ctx
.
GetPlace
(),
false
,
out
);
return
;
}
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
Rank
>
bcast_dims
;
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
Rank
>
bcast_dims
;
for
(
size_t
i
=
0
;
i
<
repeat_times
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
repeat_times
.
size
();
++
i
)
{
bcast_dims
[
i
]
=
repeat_times
[
i
];
bcast_dims
[
i
]
=
repeat_times
[
i
];
...
@@ -71,6 +75,7 @@ void Tile(const Context& dev_ctx,
...
@@ -71,6 +75,7 @@ void Tile(const Context& dev_ctx,
auto
eigen_out
=
EigenTensor
<
T
,
Rank
>::
From
(
*
out
,
out_dims
);
auto
eigen_out
=
EigenTensor
<
T
,
Rank
>::
From
(
*
out
,
out_dims
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
&
place
=
*
dev_ctx
.
eigen_device
();
// use 32-bit index to speed up
// use 32-bit index to speed up
bool
use_32bit_index
=
eigen_out
.
size
()
<
Eigen
::
NumTraits
<
int
>::
highest
();
bool
use_32bit_index
=
eigen_out
.
size
()
<
Eigen
::
NumTraits
<
int
>::
highest
();
if
(
use_32bit_index
)
{
if
(
use_32bit_index
)
{
...
@@ -93,6 +98,9 @@ void TileKernel(const Context& dev_ctx,
...
@@ -93,6 +98,9 @@ void TileKernel(const Context& dev_ctx,
rank
=
std
::
max
(
rank
,
repeat_times_size
);
rank
=
std
::
max
(
rank
,
repeat_times_size
);
switch
(
rank
)
{
switch
(
rank
)
{
case
0
:
Tile
<
Context
,
T
,
0
>
(
dev_ctx
,
x
,
repeat_times_data
,
out
);
break
;
case
1
:
case
1
:
Tile
<
Context
,
T
,
1
>
(
dev_ctx
,
x
,
repeat_times_data
,
out
);
Tile
<
Context
,
T
,
1
>
(
dev_ctx
,
x
,
repeat_times_data
,
out
);
break
;
break
;
...
...
python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
浏览文件 @
78add057
...
@@ -78,5 +78,15 @@ class TestCase2(TestBase):
...
@@ -78,5 +78,15 @@ class TestCase2(TestBase):
self
.
attrs
=
{
"perm"
:
[
4
,
0
,
2
,
3
,
1
]}
self
.
attrs
=
{
"perm"
:
[
4
,
0
,
2
,
3
,
1
]}
class
TestCase_ZeroDim
(
TestBase
):
def
set_data_feed
(
self
):
data
=
np
.
random
.
uniform
(
size
=
[])
self
.
feed_fp32
=
{
"x"
:
data
.
astype
(
np
.
float32
)}
def
set_op_attrs
(
self
):
self
.
attrs
=
{
"perm"
:
[]}
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/npu/test_transpose_op_npu.py
浏览文件 @
78add057
...
@@ -54,6 +54,12 @@ class TestTransposeOp(OpTest):
...
@@ -54,6 +54,12 @@ class TestTransposeOp(OpTest):
self
.
check_grad_with_place
(
self
.
place
,
[
'X'
],
'Out'
)
self
.
check_grad_with_place
(
self
.
place
,
[
'X'
],
'Out'
)
class
TestCase_ZeroDim
(
TestTransposeOp
):
def
init_shape_axis
(
self
):
self
.
shape
=
()
self
.
axis
=
()
class
TestCase0
(
TestTransposeOp
):
class
TestCase0
(
TestTransposeOp
):
def
init_shape_axis
(
self
):
def
init_shape_axis
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_reshape_op.py
浏览文件 @
78add057
...
@@ -46,6 +46,30 @@ class TestReshapeOp(OpTest):
...
@@ -46,6 +46,30 @@ class TestReshapeOp(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOp_ZeroDim1
(
OpTest
):
def
init_data
(
self
):
self
.
ori_shape
=
()
self
.
new_shape
=
(
1
)
self
.
infered_shape
=
(
1
)
class
TestReshapeOp_ZeroDim2
(
OpTest
):
def
init_data
(
self
):
self
.
ori_shape
=
(
1
)
self
.
new_shape
=
()
self
.
infered_shape
=
()
class
TestReshapeOp_ZeroDim3
(
OpTest
):
def
init_data
(
self
):
self
.
ori_shape
=
()
self
.
new_shape
=
(
-
1
)
self
.
infered_shape
=
(
1
)
class
TestReshapeBF16Op
(
OpTest
):
class
TestReshapeBF16Op
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -526,6 +550,58 @@ class TestReshapeZeroTensor(unittest.TestCase):
...
@@ -526,6 +550,58 @@ class TestReshapeZeroTensor(unittest.TestCase):
zero_tensor
.
reshape
([
2
,
3
])
zero_tensor
.
reshape
([
2
,
3
])
class
TestReshapeAPI_ZeroDim
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
reshape
(
x
,
[
1
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
])
out
=
paddle
.
reshape
(
x
,
[
-
1
,
1
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
,
1
])
paddle
.
enable_static
()
def
test_static
(
self
):
main_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
fluid
.
Program
()):
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
reshape
(
x
,
[
-
1
])
fluid
.
backward
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
block
=
prog
.
global_block
()
x_grad
=
block
.
var
(
fluid
.
framework
.
grad_var_name
(
x
.
name
))
out_grad
=
block
.
var
(
fluid
.
framework
.
grad_var_name
(
out
.
name
))
# Test compile shape
self
.
assertEqual
(
x
.
shape
,
())
self
.
assertEqual
(
out
.
shape
,
(
1
,
))
self
.
assertEqual
(
x_grad
.
shape
,
())
self
.
assertEqual
(
out_grad
.
shape
,
(
1
,
))
exe
=
fluid
.
Executor
()
result
=
exe
.
run
(
main_prog
,
fetch_list
=
[
x
,
out
,
x_grad
,
out_grad
])
# Test runtime shape
self
.
assertEqual
(
result
[
0
].
shape
,
())
self
.
assertEqual
(
result
[
1
].
shape
,
(
1
,
))
self
.
assertEqual
(
result
[
2
].
shape
,
())
self
.
assertEqual
(
result
[
3
].
shape
,
(
1
,
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_stack_op.py
浏览文件 @
78add057
...
@@ -19,6 +19,8 @@ import paddle.fluid as fluid
...
@@ -19,6 +19,8 @@ import paddle.fluid as fluid
from
op_test
import
OpTest
,
convert_float_to_uint16
from
op_test
import
OpTest
,
convert_float_to_uint16
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
paddle
.
enable_static
()
class
TestStackOpBase
(
OpTest
):
class
TestStackOpBase
(
OpTest
):
...
@@ -99,6 +101,12 @@ class TestStackOp6(TestStackOpBase):
...
@@ -99,6 +101,12 @@ class TestStackOp6(TestStackOpBase):
self
.
axis
=
3
self
.
axis
=
3
class
TestStackOp_ZeroDim
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
input_dim
=
()
class
TestStackBF16Op
(
OpTest
):
class
TestStackBF16Op
(
OpTest
):
def
initDefaultParameters
(
self
):
def
initDefaultParameters
(
self
):
...
@@ -293,5 +301,26 @@ class TestStackOpWithNegativeShape(unittest.TestCase):
...
@@ -293,5 +301,26 @@ class TestStackOpWithNegativeShape(unittest.TestCase):
rtol
=
1e-05
)
rtol
=
1e-05
)
class
TestStackAPI_ZeroDim
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
x1
=
paddle
.
rand
([])
x2
=
paddle
.
rand
([])
x1
.
stop_gradient
=
False
x2
.
stop_gradient
=
False
out
=
paddle
.
stack
([
x1
,
x2
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
2
])
self
.
assertEqual
(
x1
.
grad
.
shape
,
[])
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
2
])
paddle
.
enable_static
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_tile_op.py
浏览文件 @
78add057
...
@@ -46,6 +46,27 @@ class TestTileOpRank1(OpTest):
...
@@ -46,6 +46,27 @@ class TestTileOpRank1(OpTest):
self
.
check_grad
([
'X'
],
'Out'
)
self
.
check_grad
([
'X'
],
'Out'
)
class
TestTileOpRank_ZeroDim1
(
TestTileOpRank1
):
def
init_data
(
self
):
self
.
ori_shape
=
[]
self
.
repeat_times
=
[]
class
TestTileOpRank_ZeroDim2
(
TestTileOpRank1
):
def
init_data
(
self
):
self
.
ori_shape
=
[]
self
.
repeat_times
=
[
2
]
class
TestTileOpRank_ZeroDim3
(
TestTileOpRank1
):
def
init_data
(
self
):
self
.
ori_shape
=
[]
self
.
repeat_times
=
[
2
,
3
]
# with dimension expanding
# with dimension expanding
class
TestTileOpRank2Expanding
(
TestTileOpRank1
):
class
TestTileOpRank2Expanding
(
TestTileOpRank1
):
...
@@ -338,6 +359,36 @@ class TestTileTripleGradCheck(unittest.TestCase):
...
@@ -338,6 +359,36 @@ class TestTileTripleGradCheck(unittest.TestCase):
self
.
func
(
p
)
self
.
func
(
p
)
class
TestTileAPI_ZeroDim
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
tile
(
x
,
[])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
out
=
paddle
.
tile
(
x
,
[
3
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
3
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
3
])
out
=
paddle
.
tile
(
x
,
[
2
,
3
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
2
,
3
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
2
,
3
])
paddle
.
enable_static
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_transpose_op.py
浏览文件 @
78add057
...
@@ -127,6 +127,13 @@ class TestCase9(TestTransposeOp):
...
@@ -127,6 +127,13 @@ class TestCase9(TestTransposeOp):
self
.
axis
=
(
6
,
1
,
3
,
5
,
0
,
2
,
4
,
7
)
self
.
axis
=
(
6
,
1
,
3
,
5
,
0
,
2
,
4
,
7
)
class
TestCase_ZeroDim
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
()
self
.
axis
=
()
class
TestAutoTuneTransposeOp
(
OpTest
):
class
TestAutoTuneTransposeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -601,6 +608,24 @@ class TestTransposeTripleGradCheck(unittest.TestCase):
...
@@ -601,6 +608,24 @@ class TestTransposeTripleGradCheck(unittest.TestCase):
self
.
func
(
p
)
self
.
func
(
p
)
class
TestTransposeAPI_ZeroDim
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
transpose
(
x
,
[])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
paddle
.
enable_static
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py
浏览文件 @
78add057
...
@@ -89,6 +89,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
...
@@ -89,6 +89,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self
.
new_shape
=
(
10
,
1
,
1
,
2
,
5
,
1
)
self
.
new_shape
=
(
10
,
1
,
1
,
2
,
5
,
1
)
class
TestUnsqueezeOp_ZeroDim1
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
-
1
,
)
self
.
new_shape
=
(
1
)
class
TestUnsqueezeOp_ZeroDim2
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
-
1
,
1
)
self
.
new_shape
=
(
1
,
1
)
class
TestUnsqueezeOp_ZeroDim3
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
0
,
1
,
2
)
self
.
new_shape
=
(
1
,
1
,
1
)
# axes is a list(with tensor)
# axes is a list(with tensor)
class
TestUnsqueezeOp_AxesTensorList
(
OpTest
):
class
TestUnsqueezeOp_AxesTensorList
(
OpTest
):
...
@@ -284,5 +308,35 @@ class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
...
@@ -284,5 +308,35 @@ class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
self
.
unsqueeze
=
paddle
.
unsqueeze_
self
.
unsqueeze
=
paddle
.
unsqueeze_
class
TestUnsqueezeAPI_ZeroDim
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
unsqueeze
(
x
,
[
-
1
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
])
out
=
paddle
.
unsqueeze
(
x
,
[
-
1
,
1
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
,
1
])
out
=
paddle
.
unsqueeze
(
x
,
[
0
,
1
,
2
])
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
,
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
,
1
,
1
])
paddle
.
enable_static
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
浏览文件 @
78add057
...
@@ -115,6 +115,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
...
@@ -115,6 +115,30 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self
.
new_shape
=
(
10
,
1
,
1
,
2
,
5
,
1
)
self
.
new_shape
=
(
10
,
1
,
1
,
2
,
5
,
1
)
class
TestUnsqueezeOp_ZeroDim1
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
-
1
,
)
self
.
new_shape
=
(
1
)
class
TestUnsqueezeOp_ZeroDim2
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
-
1
,
1
)
self
.
new_shape
=
(
1
,
1
)
class
TestUnsqueezeOp_ZeroDim3
(
TestUnsqueezeOp
):
def
init_test_case
(
self
):
self
.
ori_shape
=
()
self
.
axes
=
(
0
,
1
,
2
)
self
.
new_shape
=
(
1
,
1
,
1
)
class
API_TestUnsqueeze
(
unittest
.
TestCase
):
class
API_TestUnsqueeze
(
unittest
.
TestCase
):
def
test_out
(
self
):
def
test_out
(
self
):
...
...
python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py
浏览文件 @
78add057
...
@@ -60,6 +60,13 @@ class TestXPUTransposeOp(XPUOpTest):
...
@@ -60,6 +60,13 @@ class TestXPUTransposeOp(XPUOpTest):
self
.
axis
=
(
1
,
0
)
self
.
axis
=
(
1
,
0
)
class
TestCase_ZeroDim
(
TestXPUTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
()
self
.
axis
=
()
class
TestCase0
(
TestXPUTransposeOp
):
class
TestCase0
(
TestXPUTransposeOp
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录