BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 05c2b9ba (unverified), authored on Oct 12, 2022 by zhouweiwei2014; committed by GitHub on Oct 12, 2022.
[Zero-Dim] support input 0D Tensor for some unary api (#45992)
* [Zero-Dim] support input 0D Tensor for unary api
* fix CI
Parent: 95768115
Showing 28 changed files with 905 additions and 173 deletions (+905 −173).
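For context: a 0D tensor is a scalar tensor whose shape is [] (rank 0, one element). Before this commit, unary APIs generally treated a scalar input as shape [1]; afterwards they preserve the 0D shape through both forward and backward. A minimal sketch of the intended dygraph behavior (illustration only, not part of the commit):

    import paddle

    x = paddle.rand([])      # 0D (scalar) tensor, shape []
    x.stop_gradient = False
    out = paddle.exp(x)      # a unary API keeps the input's 0D shape
    out.backward()

    print(x.shape)           # []
    print(out.shape)         # []
    print(x.grad.shape)      # []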
Changed files:

paddle/fluid/framework/details/fetch_async_op_handle.cc (+29 −13)
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc (+2 −0)
paddle/fluid/framework/infershape_utils.cc (+3 −2)
paddle/fluid/framework/lod_tensor.cc (+29 −10)
paddle/fluid/framework/op_desc.cc (+2 −2)
paddle/phi/core/utils/dim.h (+8 −3)
paddle/phi/infermeta/multiary.cc (+7 −2)
paddle/phi/infermeta/multiary.h (+2 −2)
paddle/phi/infermeta/unary.cc (+2 −2)
paddle/phi/kernels/funcs/unsqueeze.h (+18 −0)
paddle/phi/kernels/onednn/reduce_kernel_impl.h (+1 −2)
paddle/phi/tests/core/test_ddim.cc (+66 −26)
python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py (+4 −2)
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py (+1 −0)
python/paddle/fluid/dygraph/math_op_patch.py (+1 −0)
python/paddle/fluid/layers/nn.py (+28 −19)
python/paddle/fluid/layers/utils.py (+0 −3)
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_inference_helper.py (+1 −0)
python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py (+1 −0)
python/paddle/fluid/tests/unittests/op_test.py (+31 −12)
python/paddle/fluid/tests/unittests/test_activation_op.py (+504 −52)
python/paddle/fluid/tests/unittests/test_fill_constant_op.py (+0 −6)
python/paddle/fluid/tests/unittests/test_full_op.py (+0 −6)
python/paddle/fluid/tests/unittests/test_mse_loss.py (+3 −3)
python/paddle/fluid/tests/unittests/test_randn_op.py (+0 −3)
python/paddle/fluid/tests/unittests/test_select_input_output_op.py (+5 −1)
python/paddle/fluid/tests/unittests/test_set_value_op.py (+0 −2)
python/paddle/fluid/tests/unittests/test_zero_dim_shape.py (+157 −0, new file)
paddle/fluid/framework/details/fetch_async_op_handle.cc

@@ -164,24 +164,32 @@ void FetchAsyncOpHandle::FetchMergedLodTensor(
     }
   }

-  bool find_first_dims = false;
-  for (auto *t : src_lodtensors) {
-    if (t->numel() && t->IsInitialized()) {
-      if (!find_first_dims) {
-        new_dim = t->dims();
-        find_first_dims = true;
-      } else {
-        new_dim[0] += t->dims()[0];
-      }
-    }
-  }
   // check src type, layout, dim, lod consistency
   for (size_t i = 1; i < src_lodtensors.size(); ++i) {
     CheckTensorAttrs(
         src_lodtensors[i], new_type, new_layout, check_dim, new_lod, offset_);
   }

+  auto rank = src_lodtensors[0]->dims().size();
+  // for 0D tensor, can't concat each tensor, so stack 0D and concat 1+D tensor
+  if (rank == 0) {
+    int src_lodtensor_size = src_lodtensors.size();
+    new_dim = phi::make_ddim(std::vector<int>({src_lodtensor_size}));
+  } else {
+    bool find_first_dims = false;
+    for (auto *t : src_lodtensors) {
+      if (t->numel() && t->IsInitialized()) {
+        if (!find_first_dims) {
+          new_dim = t->dims();
+          find_first_dims = true;
+        } else {
+          new_dim[0] += t->dims()[0];
+        }
+      }
+    }
+  }

   // set dst tensor
   dst_lodtensor->Resize(new_dim);
   dst_lodtensor->set_layout(src_lodtensors[0]->layout());

@@ -195,9 +203,17 @@ void FetchAsyncOpHandle::FetchMergedLodTensor(
   }

   // slice and memcpy
+  // for 0D tensor, can't concat each tensor, stack them; for 1+D tensor,
+  // concat them
   int begin = 0;
+  int end = 0;
   for (auto *src : src_lodtensors) {
-    int end = begin + src->dims()[0];
+    if (rank == 0) {
+      end = begin + 1;
+    } else {
+      end = begin + src->dims()[0];
+    }
     if (end == begin) {
       continue;
     }
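In other words, when partial results arrive from several places, one 0D result per place cannot be concatenated along axis 0, so the handler stacks them into a 1-D tensor of length #places. A hypothetical numpy sketch of the same shape logic:

    import numpy as np

    def merge_fetched(parts):
        # 0D partial results cannot be concatenated, so stack them;
        # 1+D partial results are concatenated along axis 0 as before.
        if parts[0].ndim == 0:
            return np.stack(parts)
        return np.concatenate(parts)

    print(merge_fetched([np.array(1.0), np.array(2.0)]).shape)      # (2,)
    print(merge_fetched([np.ones((2, 3)), np.ones((4, 3))]).shape)  # (6, 3)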
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc

@@ -16,6 +16,7 @@
 #include <string>

 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"

 namespace phi {

@@ -101,6 +102,7 @@ std::string ScaleLossGradOpHandle::LossGradName() const {
 void ScaleLossGradOpHandle::RunImpl() {
   platform::RecordEvent record_event(
       Name(), platform::TracerEventType::UserDefined, 2);

   RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true);
 }
paddle/fluid/framework/infershape_utils.cc

@@ -213,8 +213,9 @@ DDim CompatMetaTensor::dims() const {
   } else {
     auto* var = PADDLE_GET_CONST(VarDesc*, var_);

-    return var->GetShape().empty() ? phi::make_ddim({0UL})
-                                   : phi::make_ddim(var->GetShape());
+    return phi::make_ddim(var->GetShape());
+    // return var->GetShape().empty() ? phi::make_ddim({0UL}) :
+    // phi::make_ddim(var->GetShape());
   }
 }
paddle/fluid/framework/lod_tensor.cc

@@ -262,6 +262,16 @@ std::vector<LoDTensor> SplitLoDTensor(
       platform::errors::InvalidArgument(
           "Place number cannot be empty when splitting."));
   src.check_memory_size();

+  auto rank = src.dims().size();
+  // if rank is 0, just return #places.size() copies of src
+  if (rank == 0) {
+    LoDTensor dst;
+    framework::TensorCopy(src, src.place(), &dst);
+    std::vector<LoDTensor> ret;
+    ret.emplace_back(std::move(dst));
+    return ret;
+  }
+
   size_t batch_size = src.lod().empty()
                           ? static_cast<size_t>(src.dims()[0])
                           : src.lod()[0].size() - 1;

@@ -349,6 +359,7 @@ void MergeLoDTensor(LoDTensor *target,
   }

   LoD new_lod = lod_tensors[0]->lod();
+  auto rank = lod_tensors[0]->dims().size();

   for (size_t i = 1; i < lod_tensors.size(); ++i) {
     auto *t = lod_tensors[i];

@@ -369,16 +380,24 @@ void MergeLoDTensor(LoDTensor *target,
             "actual layout is %s.",
             DataLayoutToString(new_layout),
             DataLayoutToString(t->layout())));

-    PADDLE_ENFORCE_EQ(
-        phi::product(new_dim) / new_dim[0],
-        phi::product(t->dims()) / t->dims()[0],
-        platform::errors::InvalidArgument(
-            "LoDTensor dimension does not match, all dimensions except the "
-            "first dimension need to be equal,"
-            "but expected dimension is %s, actual dimension is %s.",
-            new_dim,
-            t->dims()));
-
-    new_dim[0] += t->dims()[0];
+    auto tensor_dims = t->dims();
+    PADDLE_ENFORCE_EQ(tensor_dims.size(),
+                      new_dim.size(),
+                      platform::errors::InvalidArgument(
+                          "dimensions of LoDTensor does not match"));
+    for (int j = 1; j < t->dims().size(); j++) {
+      PADDLE_ENFORCE_EQ(
+          tensor_dims[j],
+          new_dim[j],
+          platform::errors::InvalidArgument(
+              "LoDTensor.ddim[%d] should equal to %d, but is %d",
+              j,
+              new_dim[j],
+              tensor_dims[j]));
+    }
+
+    if (rank > 0) {
+      new_dim[0] += t->dims()[0];
+    }

     auto &lod = t->lod();
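SplitLoDTensor thus gains a rank-0 fast path (copy the scalar and return, since there is no batch dimension to slice), and MergeLoDTensor only accumulates dim 0 when rank > 0. A simplified Python sketch of the split-side intent, with np.array_split standing in for the LoD-aware slicing:

    import numpy as np

    def split_lod_tensor(src, num_places):
        if src.ndim == 0:
            # rank-0 fast path: copy the scalar, nothing to slice
            return [src.copy()]
        return np.array_split(src, num_places)

    print([t.shape for t in split_lod_tensor(np.array(3.14), 2)])  # [()]
    print([t.shape for t in split_lod_tensor(np.arange(8), 2)])    # [(4,), (4,)]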
paddle/fluid/framework/op_desc.cc

@@ -362,7 +362,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   DDim res;
   try {
     auto shape = var->GetShape();
-    res = shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
+    res = phi::make_ddim(shape);
   } catch (...) {
     VLOG(5) << "GetDim of variable " << name << " error";
     std::rethrow_exception(std::current_exception());

@@ -1258,7 +1258,7 @@ std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
   try {
     auto shapes = var->GetShapes();
     for (const auto &s : shapes) {
-      res.push_back(s.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(s));
+      res.push_back(phi::make_ddim(s));
     }
   } catch (...) {
     VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
paddle/phi/core/utils/dim.h

@@ -72,10 +72,15 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
 // Allows us to output a Dim
 template <int D>
 inline std::ostream& operator<<(std::ostream& os, const Dim<D>& d) {
-  os << d[0];
-  for (int i = 1; i < D; ++i) {
-    os << ", " << d[i];
+  if (D > 0) {
+    os << d[0];
+    for (int i = 1; i < D; ++i) {
+      os << ", " << d[i];
+    }
+  } else {
+    os << "";
   }
   return os;
 }
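The guard makes streaming a rank-0 Dim print as an empty string instead of reading d[0] out of bounds. The same formatting rule, restated in Python purely for illustration:

    def format_dim(d):
        # "2, 3, 4" for rank 3; "" for rank 0
        return ", ".join(str(v) for v in d)

    print(format_dim((2, 3, 4)))  # 2, 3, 4
    print(repr(format_dim(())))   # ''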
paddle/phi/infermeta/multiary.cc

@@ -305,9 +305,14 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
     if (x[i]->is_selected_rows() && x_dim.size() == 1) {
       continue;
     }
     // for zero-sized tensor
     if (phi::product(x_dim) == 0) {
       continue;
     }
+    // for 0D tensor
+    if (x_dim.size() == 0) {
+      continue;
+    }
     if (phi::product(in_dim) == 0) {
       in_dim = x_dim;
     } else {

@@ -2547,8 +2552,8 @@ void WarpctcInferMeta(const MetaTensor& logits,
                       const MetaTensor& labels_length,
                       int blank,
                       bool norm_by_times,
-                      MetaTensor* warpctcgrad,
-                      MetaTensor* loss) {
+                      MetaTensor* loss,
+                      MetaTensor* warpctcgrad) {
   auto logits_dims = logits.dims();
   int sequence_width = 0;
paddle/phi/infermeta/multiary.h

@@ -483,8 +483,8 @@ void WarpctcInferMeta(const MetaTensor& logits,
                       const MetaTensor& labels_length,
                       int blank,
                       bool norm_by_times,
-                      MetaTensor* warpctcgrad,
-                      MetaTensor* loss);
+                      MetaTensor* loss,
+                      MetaTensor* warpctcgrad);

 void WhereInferMeta(const MetaTensor& condition,
                     const MetaTensor& x,
paddle/phi/infermeta/unary.cc

@@ -2668,7 +2668,7 @@ DDim ReduceInferDim(const MetaTensor& x,
           x_rank,
           errors::InvalidArgument(
               "The reduce dim index %d should be in the "
-              "range [-dimension(X), dimension(X)] "
+              "range [-dimension(X), dimension(X) ) "
               "which dimesion = %d. But received dim index = %d.",
               i,
               x_rank,

@@ -2677,7 +2677,7 @@ DDim ReduceInferDim(const MetaTensor& x,
           -x_rank,
           errors::InvalidArgument(
               "The reduce dim index %d should be in the "
-              "range [-dimension(X), dimension(X)] "
+              "range [-dimension(X), dimension(X) ) "
               "which dimesion = %d. But received dim index = %d.",
               i,
               x_rank,
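The bracket change corrects the documented range: a reduce dim index d is valid for a rank-r input iff -r <= d < r, a half-open interval. For illustration:

    def is_valid_reduce_dim(d, rank):
        # valid iff d lies in [-rank, rank)
        return -rank <= d < rank

    print(is_valid_reduce_dim(2, 3))   # True
    print(is_valid_reduce_dim(3, 3))   # False, hence ")" instead of "]"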
paddle/phi/kernels/funcs/unsqueeze.h

@@ -36,6 +36,24 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,
     }
   } else {
     for (size_t i = 0; i < num_squeeze_dims; ++i) {
+      if (in_dims.size() == 0) {
+        PADDLE_ENFORCE_GE(
+            squeeze_dims[i],
+            -1,
+            phi::errors::InvalidArgument(
+                "For 0D Tensor, Each axis in Attr(axes) should be in the range "
+                "of [-1, 0]"
+                "But current axis is:%d, input tensor's shape = [%s]."));
+        PADDLE_ENFORCE_LE(
+            squeeze_dims[i],
+            0,
+            phi::errors::InvalidArgument(
+                "For 0D Tensor, Each axis in Attr(axes) should be in the range "
+                "of [-1, 0]"
+                "But current axis is:%d, input tensor's shape = [%s]."));
+        continue;
+      }
       int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
                                         : squeeze_dims[i];
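For a 0D input there is no real axis to squeeze, so any requested axis must be -1 or 0 and is then ignored. A sketch of the intended user-facing behavior (assuming a Paddle build that includes this change):

    import paddle

    x = paddle.rand([])                      # 0D input
    print(paddle.squeeze(x, axis=0).shape)   # [] -- axis in [-1, 0] is accepted
    print(paddle.squeeze(x, axis=-1).shape)  # []
    # paddle.squeeze(x, axis=1) would raise InvalidArgument under this check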
paddle/phi/kernels/onednn/reduce_kernel_impl.h

@@ -25,8 +25,7 @@ inline std::vector<int64_t> CalculateReducedDims(
     bool keep_dim) {
   if (keep_dim) return vectorize(output->dims());

-  if (reduce_all && reduce_dims.size() > 0)
-    return std::vector<int64_t>(input->dims().size(), 1);
+  if (reduce_all) return std::vector<int64_t>(input->dims().size(), 1);

   std::vector<int64_t> output_dims(vectorize(input->dims()));
   for (size_t i = 0; i < reduce_dims.size(); ++i) {
paddle/phi/tests/core/test_ddim.cc

@@ -21,18 +21,43 @@ namespace phi {
 namespace tests {

 TEST(DDim, Equality) {
+  // default construct ddim
+  phi::DDim default_ddim;
+  EXPECT_EQ(arity(default_ddim), 1);
+  EXPECT_EQ(default_ddim[0], 0);
+
+  // construct a zero-DDim
+  phi::DDim zero_ddim = phi::make_ddim({});
+  EXPECT_EQ(arity(zero_ddim), 0);
+  EXPECT_EQ(zero_ddim.size(), 0);
+  EXPECT_EQ(phi::product(zero_ddim), 1);
+
+  std::vector<int64_t> zero_vec;
+  phi::DDim zero_ddim1 = phi::make_ddim(zero_vec);
+  EXPECT_EQ(arity(zero_ddim1), 0);
+  EXPECT_EQ(zero_ddim1.size(), 0);
+  EXPECT_EQ(phi::product(zero_ddim1), 1);
+
+  // zero-DDim to vector
+  std::vector<int64_t> zero_ddim_vec = phi::vectorize(zero_ddim);
+  EXPECT_EQ(zero_ddim_vec.size(), size_t(0));
+
+  // reshape zero-DDim
+  std::vector<int> reshape_vec = {1};
+  phi::DDim reshape_ddim = zero_ddim.reshape(reshape_vec);
+  EXPECT_EQ(arity(reshape_ddim), 1);
+  EXPECT_EQ(reshape_ddim.size(), 1);
+  EXPECT_EQ(phi::product(reshape_ddim), 1);
+
+  // construct a DDim from an initialization list
   phi::DDim ddim = phi::make_ddim({9, 1, 5});
   EXPECT_EQ(ddim[0], 9);
   EXPECT_EQ(ddim[1], 1);
   EXPECT_EQ(ddim[2], 5);

-  // construct a DDim from a vector
-  std::vector<int64_t> vec({9, 1, 5});
-  phi::DDim vddim = phi::make_ddim(vec);
-  EXPECT_EQ(ddim[0], 9);
-  EXPECT_EQ(ddim[1], 1);
-  EXPECT_EQ(ddim[2], 5);
-
-  // arity of a DDim
-  EXPECT_EQ(phi::arity(ddim), 3);
-  EXPECT_EQ(ddim.size(), 3);
-
   // mutate a DDim
   ddim[1] = 2;

@@ -40,6 +65,13 @@ TEST(DDim, Equality) {
   ddim[0] = 6;
   EXPECT_EQ(ddim[0], 6);

+  // construct a DDim from a vector
+  std::vector<int64_t> vec({9, 1, 5});
+  phi::DDim vddim = phi::make_ddim(vec);
+  EXPECT_EQ(vddim[0], 9);
+  EXPECT_EQ(vddim[1], 1);
+  EXPECT_EQ(vddim[2], 5);
+
   // vectorize a DDim
   std::vector<int64_t> res_vec = phi::vectorize(vddim);
   EXPECT_EQ(res_vec[0], 9);

@@ -51,37 +83,45 @@ TEST(DDim, Equality) {
   EXPECT_EQ(res_vec[1], 2);
   EXPECT_EQ(res_vec[2], 1);

+  // arity of a DDim
+  EXPECT_EQ(phi::arity(ddim), 3);
+  EXPECT_EQ(ddim.size(), 3);
+
   // product of a DDim
   EXPECT_EQ(phi::product(vddim), 45);
   EXPECT_EQ(phi::product(phi::make_ddim({3, 2, 5, 3})), 90);

   // slice a DDim
   phi::DDim ddim2 = phi::make_ddim({1, 2, 3, 4, 5, 6});
-  phi::DDim ss = phi::slice_ddim(ddim2, 2, 5);
-  EXPECT_EQ(arity(ss), 3);
-  EXPECT_EQ(ss[0], 3);
-  EXPECT_EQ(ss[1], 4);
-  EXPECT_EQ(ss[2], 5);
-  phi::DDim ss2 = phi::slice_ddim(ddim2, 0, 6);
-  EXPECT_EQ(arity(ss2), 6);
-  EXPECT_EQ(ss2[0], 1);
-  EXPECT_EQ(ss2[1], 2);
-  EXPECT_EQ(ss2[2], 3);
-  EXPECT_EQ(ss2[3], 4);
-  EXPECT_EQ(ss2[4], 5);
-  EXPECT_EQ(ss2[5], 6);
+  phi::DDim slice_dim1 = phi::slice_ddim(ddim2, 2, 5);
+  EXPECT_EQ(arity(slice_dim1), 3);
+  EXPECT_EQ(slice_dim1[0], 3);
+  EXPECT_EQ(slice_dim1[1], 4);
+  EXPECT_EQ(slice_dim1[2], 5);
+
+  phi::DDim slice_dim2 = phi::slice_ddim(ddim2, 0, 6);
+  EXPECT_EQ(arity(slice_dim2), 6);
+  EXPECT_EQ(slice_dim2[0], 1);
+  EXPECT_EQ(slice_dim2[1], 2);
+  EXPECT_EQ(slice_dim2[2], 3);
+  EXPECT_EQ(slice_dim2[3], 4);
+  EXPECT_EQ(slice_dim2[4], 5);
+  EXPECT_EQ(slice_dim2[5], 6);
+
+  phi::DDim slice_dim3 = phi::slice_ddim(ddim2, 1, 1);
+  EXPECT_EQ(arity(slice_dim3), 0);
+  EXPECT_EQ(slice_dim3.size(), 0);
+  EXPECT_EQ(phi::product(slice_dim3), 1);
 }

 TEST(DDim, Print) {
   // print a DDim
-  std::stringstream ss;
+  std::stringstream ss1;
   phi::DDim ddim = phi::make_ddim({2, 3, 4});
-  ss << ddim;
-  EXPECT_EQ("2, 3, 4", ss.str());
+  ss1 << ddim;
+  EXPECT_EQ("2, 3, 4", ss1.str());
+
+  // print a zero-DDim
+  std::stringstream ss2;
+  phi::DDim zero_ddim = phi::make_ddim({});
+  ss2 << zero_ddim;
+  EXPECT_EQ("", ss2.str());
 }

 }  // namespace tests
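These tests pin the empty-shape conventions: a rank-0 DDim has arity 0 and product 1 (the empty product), matching the element count of a scalar. numpy follows the same convention:

    import numpy as np

    assert np.prod(()) == 1.0     # empty product is 1
    assert np.ones(()).size == 1  # a scalar tensor holds one element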
python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py

@@ -688,8 +688,10 @@ class HybridParallelInferenceHelper(object):
                 })
             else:
                 var_shape = list(var.shape)
-                var_shape[0] = self.micro_batch_size if var_shape[
-                    0] < 0 else var_shape[0]
+                print(var_name)
+                if len(var.shape) > 0:
+                    var_shape[0] = self.micro_batch_size if var_shape[
+                        0] < 0 else var_shape[0]
                 block._insert_op_without_sync(
                     index=index,
                     type='recv_v2',
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py

@@ -462,6 +462,7 @@ def convert_len(var):
         `shape_op` in var.block.
     """
     if isinstance(var, Variable):
+        assert var.ndim > 0, "len() of a 0D tensor is wrong"
         if var.type in [
                 core.VarDesc.VarType.LOD_TENSOR,
                 core.VarDesc.VarType.SELECTED_ROWS
python/paddle/fluid/dygraph/math_op_patch.py

@@ -144,6 +144,7 @@ def monkey_patch_math_varbase():
         return int(var.numpy().flatten()[0])

     def _len_(var):
+        assert var.ndim > 0, "len() of a 0D tensor is wrong"
         if var.type == core.VarDesc.VarType.VOCAB:
             return len(var.value().get_map_tensor())
         elif var.type == core.VarDesc.VarType.STRINGS:
python/paddle/fluid/layers/nn.py

@@ -208,6 +208,30 @@ OP_NAMEMAPPING = {
 }

+def _get_reduce_dim(dim, input):
+    """
+    Internal function for reduce_sum, reduce_mean, reduce_max, reduce_min, reduce_prod.
+    It computes the attribute reduce_all value based on axis.
+    """
+    if dim is not None and not isinstance(dim, list):
+        if isinstance(dim, (tuple, range)):
+            dim = list(dim)
+        elif isinstance(dim, int):
+            dim = [dim]
+        else:
+            raise TypeError(
+                "The type of dim must be int, list, tuple or range, but received {}"
+                .format(type(dim)))
+    if dim is None:
+        dim = []
+    if dim == [] or len(dim) == len(input.shape):
+        reduce_all = True
+    else:
+        reduce_all = False
+    return reduce_all, dim
+
+
 @dygraph_only
 def _elementwise_op_in_dygraph(x,
                                y,

@@ -4689,29 +4713,14 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
-    if dim is not None and not isinstance(dim, list):
-        dim = [dim]
+    reduce_all, dim = _get_reduce_dim(dim, input)

     if in_dygraph_mode():
-        reduce_all = True if dim == None or dim == [] or len(dim) == len(
-            input.shape) else False
-        dim = dim if dim != None and dim != [] else [0]
-        if reduce_all:
-            return _C_ops.sum(input, [], None, keep_dim)
-        else:
-            return _C_ops.sum(input, dim, None, keep_dim)
+        return _C_ops.sum(input, dim, None, keep_dim)
     elif _in_legacy_dygraph():
-        reduce_all = True if dim == None or dim == [] or len(dim) == len(
-            input.shape) else False
-        dim = dim if dim != None and dim != [] else [0]
         return _legacy_C_ops.reduce_sum(input, 'dim', dim, 'keep_dim', keep_dim,
                                         'reduce_all', reduce_all)
-    attrs = {
-        'dim': dim if dim != None and dim != [] else [0],
-        'keep_dim': keep_dim,
-        'reduce_all': True
-        if dim == None or dim == [] or len(dim) == len(input.shape) else False
-    }
+    attrs = {'dim': dim, 'keep_dim': keep_dim, 'reduce_all': reduce_all}
     check_variable_and_dtype(
         input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
         'reduce_sum')
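The refactor centralizes the reduce_all decision: reduce over everything when dim is None/empty or names every axis, which also makes 0D inputs (len(input.shape) == 0) reduce_all by construction. A pure-Python restatement of the contract (hypothetical helper, not the Paddle function itself):

    def get_reduce_dim(dim, ndim):
        if dim is None:
            dim = []
        elif isinstance(dim, int):
            dim = [dim]
        elif isinstance(dim, (tuple, range)):
            dim = list(dim)
        reduce_all = dim == [] or len(dim) == ndim
        return reduce_all, dim

    print(get_reduce_dim(None, 2))    # (True, [])
    print(get_reduce_dim(0, 2))       # (False, [0])
    print(get_reduce_dim([0, 1], 2))  # (True, [0, 1])
    print(get_reduce_dim(None, 0))    # (True, []) -- 0D input reduces all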
python/paddle/fluid/layers/utils.py

@@ -363,9 +363,6 @@ def get_shape_tensor_inputs(inputs, attrs, shape, op_type):
         shape = cast(shape, 'int32')
         inputs["ShapeTensor"] = shape
     elif isinstance(shape, (list, tuple)):
-        assert len(shape) > 0, (
-            "The size of 'shape' in" + op_type + " can't be zero, "
-            "but received %s." % len(shape))
         attrs["shape"] = _get_attr_shape(shape)
         if _contain_var(shape):
             inputs['ShapeTensorList'] = _get_shape_tensor(shape)
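With the size assertion gone, shape=[] becomes a legal way to request a 0D tensor from shape-taking creation ops (matching the removed AssertionError tests for fill_constant, full, and randn below). A sketch, assuming a build with this change:

    import paddle

    x = paddle.full(shape=[], fill_value=1.0, dtype="float32")
    print(x.shape)   # [] -- a 0D tensor
    print(float(x))  # 1.0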
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_inference_helper.py

@@ -82,6 +82,7 @@ class TestHybridParallelInferenceHelperClass(unittest.TestCase):
                 value=0,
                 force_cpu=False,
                 name="cond_int")
+            print(cond_int.shape)
             cond = layers.less_than(x=step_idx, y=max_len)
             while_op = layers.While(cond, is_test=True)
python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py

@@ -82,6 +82,7 @@ def len_with_selected_rows(place):
     # create selected_rows variable
     var = block.create_var(name="X",
                            dtype="float32",
+                           shape=[-1],
                            persistable=True,
                            type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
     # y is Variable(SelectedRows)
python/paddle/fluid/tests/unittests/op_test.py

@@ -505,6 +505,7 @@ class OpTest(unittest.TestCase):
             else:
                 tensor.set(self.inputs[var_name], place)
             feed_map[var_name] = tensor

         return feed_map

     def _append_ops(self, block):

@@ -1136,6 +1137,7 @@ class OpTest(unittest.TestCase):
                 continue
             else:
                 grad_feed_map[arg] = fwd_outs[i]._copy(p)

         return grad_feed_map

     def _get_need_run_ops(self, op_desc, fwd_op_desc=None):

@@ -1254,6 +1256,7 @@ class OpTest(unittest.TestCase):
                 build_strategy=build_strategy, places=place)
             program = compiled_program

         outs = exe.run(program,
                        feed=grad_feed_map,
                        fetch_list=grad_fetch_list,

@@ -1290,6 +1293,7 @@ class OpTest(unittest.TestCase):
             fwd_res, grad_op_desc, enable_inplace=True)

         self._compare_expect_and_actual_outputs(place,
                                                 expect_res[1],
                                                 expect_res[0],

@@ -1457,7 +1461,7 @@ class OpTest(unittest.TestCase):
                 # NOTE(zhiqiu): np.allclose([], [1.]) returns True
                 # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
                 if expect_np.size == 0:
-                    self.op_test.assertTrue(actual_np.size == 0)  # }}}
+                    self.op_test.assertTrue(actual_np.size == 0)
                 self._compare_numpy(name, actual_np, expect_np)
                 if isinstance(expect, tuple):
                     self._compare_list(name, actual, expect)

@@ -1663,7 +1667,6 @@ class OpTest(unittest.TestCase):
         if check_dygraph:
             # always enable legacy dygraph
             g_enable_legacy_dygraph()
             dygraph_checker = DygraphChecker(self, self.outputs)
             dygraph_checker.check()
             dygraph_outs = dygraph_checker.outputs

@@ -1830,15 +1833,29 @@ class OpTest(unittest.TestCase):
         # Therefore, it asserts np.abs(a - b) / (np.abs(a)*1e4) < max_relative_error,
         # which is the same as np.abs(a - b) / np.abs(a) < max_relative_error*1e4.
         abs_a = np.abs(a)
-        if self.dtype == np.float64 and \
-            self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
-            abs_a[abs_a < 1e-10] = 1e-3
-            abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
-            abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
-        elif self.is_bfloat16_op():
-            abs_a[abs_a < 1e-2] = 1
-        else:
-            abs_a[abs_a < 1e-3] = 1
+        if abs_a.ndim > 0:
+            if self.dtype == np.float64 and \
+                self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
+                abs_a[abs_a < 1e-10] = 1e-3
+                abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
+                abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
+            elif self.is_bfloat16_op():
+                abs_a[abs_a < 1e-2] = 1
+            else:
+                abs_a[abs_a < 1e-3] = 1
+        elif abs_a.ndim == 0:
+            if self.dtype == np.float64 and \
+                self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
+                if abs_a < 1e-10:
+                    abs_a = 1e-3
+                elif abs_a > 1e-10 and abs_a <= 1e-8:
+                    abs_a = abs_a * 1e4
+                elif abs_a > 1e-8 and abs_a <= 1e-6:
+                    abs_a = abs_a * 1e2
+            elif self.is_bfloat16_op():
+                abs_a = 1 if abs_a < 1e-2 else abs_a
+            else:
+                abs_a = 1 if abs_a < 1e-3 else abs_a

         diff_mat = np.abs(a - b) / abs_a
         max_diff = np.max(diff_mat)

@@ -1958,7 +1975,9 @@ class OpTest(unittest.TestCase):
         tensor_to_check = self.scope.find_var(input_to_check).get_tensor()
         tensor_size = six.moves.reduce(lambda a, b: a * b,
                                        tensor_to_check.shape(), 1)
-        if tensor_size < 100:
+        tensor_ndim = len(tensor_to_check.shape())
+        # for 0D Tensor, it's additional case for OP, so not raise error
+        if tensor_ndim > 0 and tensor_size < 100:
             self.__class__.input_shape_is_large = False

         if not type(output_names) is list:
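The ndim == 0 branch exists because the boolean-mask updates used for arrays do not apply to scalar values: a numpy float64 scalar reports ndim == 0 but supports neither subscripting nor mask assignment, so the threshold rewrite must happen by plain assignment. A small illustration of the failure mode:

    import numpy as np

    abs_a = np.abs(3e-11)            # np.float64 scalar, ndim == 0
    try:
        abs_a[abs_a < 1e-10] = 1e-3  # mask assignment needs an indexable array
    except TypeError:
        abs_a = 1e-3                 # the new scalar branch assigns directly
    print(abs_a)                     # 0.001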
python/paddle/fluid/tests/unittests/test_activation_op.py (+504 −52) — this diff is collapsed.
python/paddle/fluid/tests/unittests/test_fill_constant_op.py

@@ -437,12 +437,6 @@ class TestFillConstantOpError(unittest.TestCase):
             self.assertRaises(TypeError, test_shape_type)

-            # The argument shape's size of fill_constant_op must not be 0.
-            def test_shape_size():
-                fluid.layers.fill_constant(shape=[], dtype="float32", value=1)
-
-            self.assertRaises(AssertionError, test_shape_size)
-
             # The shape dtype of fill_constant_op must be int32 or int64.
             def test_shape_tensor_dtype():
                 shape = fluid.data(name="shape_tensor",
python/paddle/fluid/tests/unittests/test_full_op.py

@@ -175,12 +175,6 @@ class TestFullOpError(unittest.TestCase):
             self.assertRaises(TypeError, test_shape_type)

-            # The argument shape's size of full_op must not be 0.
-            def test_shape_size():
-                paddle.full(shape=[], dtype="float32", fill_value=1)
-
-            self.assertRaises(AssertionError, test_shape_size)
-
             # The shape dtype of full op must be int32 or int64.
             def test_shape_tensor_dtype():
                 shape = fluid.data(name="shape_tensor",
python/paddle/fluid/tests/unittests/test_mse_loss.py

@@ -30,8 +30,8 @@ class TestMseLoss(unittest.TestCase):
         sub = input_val - label_val
         np_result = np.mean(sub * sub)

-        input_var = layers.create_tensor(dtype="float32", name="input")
-        label_var = layers.create_tensor(dtype="float32", name="label")
+        input_var = fluid.data(name="input", shape=[-1, 3], dtype="float32")
+        label_var = fluid.data(name="label", shape=[-1, 3], dtype="float32")

         output = layers.mse_loss(input=input_var, label=label_var)
         for use_cuda in ([False, True]

@@ -54,7 +54,7 @@ class TestMseInvalidInput(unittest.TestCase):
         def test_invalid_input():
             input = [256, 3]
-            label = fluid.data(name='label', shape=[None, 3], dtype='float32')
+            label = fluid.data(name='label1', shape=[None, 3], dtype='float32')
             loss = fluid.layers.mse_loss(input, label)

         self.assertRaises(TypeError, test_invalid_input)
python/paddle/fluid/tests/unittests/test_randn_op.py

@@ -75,9 +75,6 @@ class TestRandnOpError(unittest.TestCase):
     def test_error(self):
         with program_guard(Program(), Program()):
-            # The argument shape's size of randn_op should not be 0.
-            self.assertRaises(AssertionError, paddle.randn, [])
-
             # The argument shape's type of randn_op should be list or tuple.
             self.assertRaises(TypeError, paddle.randn, 1)
python/paddle/fluid/tests/unittests/test_select_input_output_op.py

@@ -23,6 +23,8 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, program_guard
 from paddle.fluid.layers.control_flow import select_input, select_output

+paddle.enable_static()
+

 class TestSplitMergeSelectedVarOps(unittest.TestCase):

@@ -37,7 +39,9 @@ class TestSplitMergeSelectedVarOps(unittest.TestCase):
             outputs = []
             for i in range(branch_num):
                 out = program.current_block().create_var(
-                    dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR)
+                    dtype='float32',
+                    shape=[2],
+                    type=core.VarDesc.VarType.LOD_TENSOR)
                 outputs.append(out)
             select_output(x, outputs, mask)
python/paddle/fluid/tests/unittests/test_set_value_op.py

@@ -23,7 +23,6 @@ from paddle.fluid.layer_helper import LayerHelper
 from functools import reduce
-from paddle.fluid.framework import _test_eager_guard

 class TestSetValueBase(unittest.TestCase):
     def setUp(self):

@@ -1442,7 +1441,6 @@ class TestGradientTruncated(unittest.TestCase):
             # When `input.stop_gradient = True` and `value.stop_gradient = False`,
             # set_value_grad_op will not be run during backward.
             y, value = op(x)
             y2 = y + 1
             loss = paddle.fluid.layers.reduce_sum(y2)
             sgd = paddle.optimizer.Adam()
python/paddle/fluid/tests/unittests/test_zero_dim_shape.py (new file, mode 100644)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import numpy as np
import unittest

unary_api_list = [
    paddle.nn.functional.elu,
    paddle.nn.functional.gelu,
    paddle.nn.functional.hardsigmoid,
    paddle.nn.functional.hardswish,
    paddle.nn.functional.leaky_relu,
    paddle.nn.functional.log_sigmoid,
    paddle.nn.functional.relu,
    paddle.nn.functional.relu6,
    paddle.nn.functional.sigmoid,
    paddle.nn.functional.softplus,
    paddle.nn.functional.softshrink,
    paddle.nn.functional.softsign,
    paddle.nn.functional.swish,
    paddle.nn.functional.tanhshrink,
    paddle.nn.functional.thresholded_relu,
    paddle.stanh,
    paddle.nn.functional.celu,
    paddle.nn.functional.mish,
    paddle.nn.functional.silu,
    paddle.nn.functional.tanh,
    paddle.cosh,
    paddle.sinh,
    paddle.abs,
    paddle.acos,
    paddle.asin,
    paddle.atan,
    paddle.ceil,
    paddle.cos,
    paddle.exp,
    paddle.floor,
    paddle.log,
    paddle.log1p,
    paddle.reciprocal,
    paddle.round,
    paddle.sin,
    paddle.sqrt,
    paddle.square,
    paddle.tanh,
    paddle.acosh,
    paddle.asinh,
    paddle.atanh,
    paddle.expm1,
    paddle.log10,
    paddle.log2,
    paddle.tan,
]


# Use to test zero-dim in the whole API
class TestUnaryAPI(unittest.TestCase):

    def test_dygraph_unary(self):
        paddle.disable_static()
        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
        for api in unary_api_list:
            x = paddle.rand([])
            x.stop_gradient = False
            out = api(x)
            out.backward()

            self.assertEqual(x.shape, [])
            self.assertEqual(out.shape, [])
            self.assertEqual(x.grad.shape, [])
            self.assertEqual(out.grad.shape, [])

        paddle.enable_static()

    def test_static_unary(self):
        paddle.enable_static()
        for api in unary_api_list:
            main_prog = fluid.Program()
            with fluid.program_guard(main_prog, fluid.Program()):
                x = paddle.rand([])
                x.stop_gradient = False
                out = api(x)
                fluid.backward.append_backward(out)

                # ScaleLossGradOp / append_backward always set grad shape to [1]
                prog = paddle.static.default_main_program()
                block = prog.global_block()
                x_grad = block.var(fluid.framework.grad_var_name(x.name))
                out_grad = block.var(fluid.framework.grad_var_name(out.name))

                # Test compile shape, grad is always [1]
                self.assertEqual(x.shape, ())
                self.assertEqual(out.shape, ())

                exe = fluid.Executor()
                result = exe.run(main_prog,
                                 fetch_list=[x, out, x_grad, out_grad])

                # Test runtime shape
                self.assertEqual(result[0].shape, ())
                self.assertEqual(result[1].shape, ())
                self.assertEqual(result[3].shape, (1, ))

                # 0D will be stacked when 1+ place, due to it cannot be concated
                # for 1 place: [ x-place1 ]
                # for 1+ place: [ paddle.stack([x-place1, x_place2...]) ]
                if paddle.device.is_compiled_with_cuda():
                    places = [paddle.CUDAPlace(0)]
                    device_num = 1
                    expect_shape = ()
                else:
                    places = [paddle.CPUPlace()] * 4
                    device_num = 4
                    expect_shape = (device_num, )

                compiled_program = fluid.CompiledProgram(
                    main_prog).with_data_parallel(out.name, places=places)
                result = exe.run(compiled_program,
                                 fetch_list=[x, out, x_grad, out_grad],
                                 return_merged=True)

                # Test runtime parallel shape
                self.assertEqual(result[0].shape, expect_shape)
                self.assertEqual(result[1].shape, expect_shape)
                self.assertEqual(result[3].shape, (device_num, ))

                compiled_program = fluid.CompiledProgram(
                    main_prog).with_data_parallel(out.name, places=places)
                result = exe.run(compiled_program,
                                 fetch_list=[x, out, x_grad, out_grad],
                                 return_merged=False)

                # [[x-place1, x-place2, ...], [], [], ...]
                self.assertEqual(np.array(result[0]).shape, (device_num, ))
                self.assertEqual(np.array(result[1]).shape, (device_num, ))
                self.assertEqual(np.array(result[3]).shape, (device_num, 1))

        paddle.disable_static()


if __name__ == "__main__":
    unittest.main()