Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
525c32e3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
525c32e3
编写于
3月 29, 2021
作者:
L
liym27
提交者:
GitHub
3月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix bug of set_value op:Decerease axes to do right broadcast (#31875)
上级
123949eb
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
95 addition
and
13 deletion
+95
-13
paddle/fluid/operators/set_value_op.cc
paddle/fluid/operators/set_value_op.cc
+10
-1
paddle/fluid/operators/set_value_op.h
paddle/fluid/operators/set_value_op.h
+61
-11
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+10
-1
python/paddle/fluid/tests/unittests/test_set_value_op.py
python/paddle/fluid/tests/unittests/test_set_value_op.py
+14
-0
未找到文件。
paddle/fluid/operators/set_value_op.cc
浏览文件 @
525c32e3
...
@@ -124,6 +124,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -124,6 +124,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
vector
<
int64_t
>>
(
AddAttr
<
std
::
vector
<
int64_t
>>
(
"steps"
,
"(list<int64_t>) Stride step from the start to the end."
)
"steps"
,
"(list<int64_t>) Stride step from the start to the end."
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
std
::
vector
<
int64_t
>>
(
"decrease_axes"
,
"(list<int>) The axes to decrease."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
int
>>
(
"bool_values"
,
"Store the bool values."
)
AddAttr
<
std
::
vector
<
int
>>
(
"bool_values"
,
"Store the bool values."
)
.
SetDefault
({});
.
SetDefault
({});
...
@@ -185,4 +188,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi
...
@@ -185,4 +188,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi
"Ending indices of corresponding axis in `axes`."
,
"Ending indices of corresponding axis in `axes`."
,
std
::
vector
<
int64_t
>
{})
std
::
vector
<
int64_t
>
{})
.
NewAttr
(
"steps"
,
"Stride step from the start to the end."
,
.
NewAttr
(
"steps"
,
"Stride step from the start to the end."
,
std
::
vector
<
int64_t
>
{}));
std
::
vector
<
int64_t
>
{}))
.
AddCheckpoint
(
R"ROC(
Upgrade set_value, add 1 attribute [decrease_axes].
)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
().
NewAttr
(
"decrease_axes"
,
"The axes to decrease."
,
std
::
vector
<
int64_t
>
{}));
paddle/fluid/operators/set_value_op.h
浏览文件 @
525c32e3
...
@@ -106,10 +106,10 @@ inline void CheckAndUpdateSlice(const framework::DDim in_dims,
...
@@ -106,10 +106,10 @@ inline void CheckAndUpdateSlice(const framework::DDim in_dims,
}
}
inline
framework
::
DDim
GetSliceDims
(
const
framework
::
DDim
in_dims
,
inline
framework
::
DDim
GetSliceDims
(
const
framework
::
DDim
in_dims
,
const
std
::
vector
<
int64_t
>
axes
,
const
std
::
vector
<
int64_t
>
&
axes
,
const
std
::
vector
<
int64_t
>
starts
,
const
std
::
vector
<
int64_t
>
&
starts
,
const
std
::
vector
<
int64_t
>
ends
,
const
std
::
vector
<
int64_t
>
&
ends
,
const
std
::
vector
<
int64_t
>
steps
)
{
const
std
::
vector
<
int64_t
>
&
steps
)
{
framework
::
DDim
slice_dims
(
in_dims
);
framework
::
DDim
slice_dims
(
in_dims
);
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
...
@@ -127,6 +127,38 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
...
@@ -127,6 +127,38 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
return
slice_dims
;
return
slice_dims
;
}
}
inline
framework
::
DDim
GetDecreasedDims
(
const
framework
::
DDim
slice_dims
,
const
std
::
vector
<
int64_t
>&
decrease_axes
)
{
// Get dims after decreasing axes.
framework
::
DDim
decreased_dims
(
slice_dims
);
if
(
decrease_axes
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axes
.
size
();
++
i
)
{
int64_t
axis
=
decrease_axes
[
i
];
PADDLE_ENFORCE_EQ
(
decreased_dims
[
axis
],
1
,
platform
::
errors
::
InvalidArgument
(
"decrease dim should be 1"
));
decreased_dims
[
axis
]
=
0
;
}
std
::
vector
<
int64_t
>
new_shape
;
for
(
int
i
=
0
;
i
<
decreased_dims
.
size
();
++
i
)
{
if
(
decreased_dims
[
i
]
!=
0
)
{
new_shape
.
push_back
(
decreased_dims
[
i
]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if
(
new_shape
.
size
()
==
0
)
{
new_shape
.
push_back
(
1
);
}
decreased_dims
=
framework
::
make_ddim
(
new_shape
);
}
return
decreased_dims
;
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
SetValueKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SetValueKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -179,6 +211,7 @@ class SetValueKernel : public framework::OpKernel<T> {
...
@@ -179,6 +211,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto
ends
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"ends"
);
auto
ends
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"ends"
);
auto
steps
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"steps"
);
auto
steps
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"steps"
);
auto
shape
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
auto
shape
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
auto
decrease_axes
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"decrease_axes"
);
auto
dtype
=
in
->
type
();
auto
dtype
=
in
->
type
();
if
(
!
starts_tensor_list
.
empty
())
{
if
(
!
starts_tensor_list
.
empty
())
{
...
@@ -194,6 +227,7 @@ class SetValueKernel : public framework::OpKernel<T> {
...
@@ -194,6 +227,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto
in_dims
=
in
->
dims
();
auto
in_dims
=
in
->
dims
();
CheckAndUpdateSlice
(
in_dims
,
axes
,
&
starts
,
&
ends
,
&
steps
);
CheckAndUpdateSlice
(
in_dims
,
axes
,
&
starts
,
&
ends
,
&
steps
);
auto
slice_dims
=
GetSliceDims
(
in_dims
,
axes
,
starts
,
ends
,
steps
);
auto
slice_dims
=
GetSliceDims
(
in_dims
,
axes
,
starts
,
ends
,
steps
);
auto
decrease_slice_dims
=
GetDecreasedDims
(
slice_dims
,
decrease_axes
);
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
auto
&
eigen_place
=
auto
&
eigen_place
=
...
@@ -212,13 +246,13 @@ class SetValueKernel : public framework::OpKernel<T> {
...
@@ -212,13 +246,13 @@ class SetValueKernel : public framework::OpKernel<T> {
// set_value is what we want.
// set_value is what we want.
TensorCopy
(
*
in
,
place
,
out
);
TensorCopy
(
*
in
,
place
,
out
);
Tensor
slice_t
(
dtype
),
pad_t
(
dtype
);
Tensor
slice_t
ensor
(
dtype
),
pad_tensor
(
dtype
);
slice_t
.
mutable_data
<
T
>
(
slice_dims
,
place
);
slice_t
ensor
.
mutable_data
<
T
>
(
slice_dims
,
place
);
pad_t
.
mutable_data
<
T
>
(
in_dims
,
place
);
pad_t
ensor
.
mutable_data
<
T
>
(
in_dims
,
place
);
auto
pad_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
pad_t
,
in_dims
);
auto
pad_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
pad_t
ensor
,
in_dims
);
auto
out_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
*
out
);
auto
out_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
*
out
);
auto
slice_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
slice_t
,
slice_dims
);
auto
slice_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
slice_t
ensor
,
slice_dims
);
// Step 1: Set the value of out at `_index` to zero
// Step 1: Set the value of out at `_index` to zero
slice_e
.
device
(
eigen_place
)
=
slice_e
.
constant
(
T
(
0
));
slice_e
.
device
(
eigen_place
)
=
slice_e
.
constant
(
T
(
0
));
...
@@ -244,11 +278,26 @@ class SetValueKernel : public framework::OpKernel<T> {
...
@@ -244,11 +278,26 @@ class SetValueKernel : public framework::OpKernel<T> {
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When do broadcasting on slice_tensor and value_tensor, the shape of
// slice_tensor should be decreased dims.
// e.g.
// x[:,0] = value_tensor
// x's shape = [3, 4], value_tensor's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
// shape is [3, 3], which cross the border;
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor
.
Resize
(
decrease_slice_dims
);
if
(
value_tensor
!=
nullptr
)
{
if
(
value_tensor
!=
nullptr
)
{
// ElementwiseComputeEx can do broadcasting
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
slice_t
,
value_tensor
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_t
);
ctx
,
&
slice_t
ensor
,
value_tensor
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_tensor
);
}
else
{
}
else
{
Tensor
value_t
(
dtype
);
Tensor
value_t
(
dtype
);
auto
value_dims
=
framework
::
make_ddim
(
shape
);
auto
value_dims
=
framework
::
make_ddim
(
shape
);
...
@@ -257,8 +306,9 @@ class SetValueKernel : public framework::OpKernel<T> {
...
@@ -257,8 +306,9 @@ class SetValueKernel : public framework::OpKernel<T> {
CopyVecotorToTensor
<
T
>
(
value_name
.
c_str
(),
&
value_t
,
ctx
);
CopyVecotorToTensor
<
T
>
(
value_name
.
c_str
(),
&
value_t
,
ctx
);
value_t
.
Resize
(
value_dims
);
value_t
.
Resize
(
value_dims
);
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
slice_t
,
&
value_t
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_t
);
ctx
,
&
slice_t
ensor
,
&
value_t
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_tensor
);
}
}
slice_tensor
.
Resize
(
slice_dims
);
// - Step 2.2 Pad slice tensor with 0
// - Step 2.2 Pad slice tensor with 0
pad_e
.
device
(
eigen_place
)
=
pad_e
.
constant
(
T
(
0
));
pad_e
.
device
(
eigen_place
)
=
pad_e
.
constant
(
T
(
0
));
...
...
python/paddle/fluid/framework.py
浏览文件 @
525c32e3
...
@@ -1863,6 +1863,7 @@ class Variable(object):
...
@@ -1863,6 +1863,7 @@ class Variable(object):
if
not
isinstance
(
item
,
tuple
):
if
not
isinstance
(
item
,
tuple
):
item
=
[
item
]
item
=
[
item
]
decrease_axes
=
[]
axes
=
[]
axes
=
[]
starts
=
[]
starts
=
[]
ends
=
[]
ends
=
[]
...
@@ -1933,15 +1934,23 @@ class Variable(object):
...
@@ -1933,15 +1934,23 @@ class Variable(object):
if
end
is
None
:
if
end
is
None
:
end
=
max_integer
if
step
>
0
else
(
0
-
max_integer
)
end
=
max_integer
if
step
>
0
else
(
0
-
max_integer
)
else
:
else
:
decrease_axes
.
append
(
dim
)
start
=
slice_item
start
=
slice_item
end
=
slice_item
+
1
if
slice_item
!=
-
1
else
max_integer
end
=
slice_item
+
1
if
slice_item
!=
-
1
else
max_integer
step
=
1
step
=
1
axes
.
append
(
dim
)
axes
.
append
(
dim
)
starts
.
append
(
start
)
starts
.
append
(
start
)
ends
.
append
(
end
)
ends
.
append
(
end
)
steps
.
append
(
step
)
steps
.
append
(
step
)
attrs
=
{
'axes'
:
axes
,
'starts'
:
starts
,
'ends'
:
ends
,
'steps'
:
steps
}
attrs
=
{
'axes'
:
axes
,
'starts'
:
starts
,
'ends'
:
ends
,
'steps'
:
steps
,
'decrease_axes'
:
decrease_axes
}
from
.layers
import
utils
from
.layers
import
utils
if
utils
.
_contain_var
(
starts
):
if
utils
.
_contain_var
(
starts
):
...
...
python/paddle/fluid/tests/unittests/test_set_value_op.py
浏览文件 @
525c32e3
...
@@ -671,6 +671,20 @@ class TestSetValueValueShape4(TestSetValueApi):
...
@@ -671,6 +671,20 @@ class TestSetValueValueShape4(TestSetValueApi):
self
.
data
[
0
]
=
self
.
value
self
.
data
[
0
]
=
self
.
value
class
TestSetValueValueShape5
(
TestSetValueApi
):
def
set_value
(
self
):
self
.
value
=
np
.
array
([
3
,
3
,
3
]).
astype
(
self
.
dtype
)
def
set_shape
(
self
):
self
.
shape
=
[
3
,
4
]
def
_call_setitem
(
self
,
x
):
x
[:,
0
]
=
paddle
.
assign
(
self
.
value
)
# x is Paddle.Tensor
def
_get_answer
(
self
):
self
.
data
[:,
0
]
=
self
.
value
# 4. Test error
# 4. Test error
class
TestError
(
TestSetValueBase
):
class
TestError
(
TestSetValueBase
):
def
_value_type_error
(
self
):
def
_value_type_error
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录