Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
62e41150
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
62e41150
编写于
10月 09, 2021
作者:
Z
zhiboniu
提交者:
GitHub
10月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fill_diagonal op fix border cross caused by offset (#36212)
上级
c8a01010
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
56 addition
and
8 deletion
+56
-8
paddle/fluid/operators/fill_diagonal_op.cc
paddle/fluid/operators/fill_diagonal_op.cc
+14
-4
paddle/fluid/operators/fill_diagonal_op.cu
paddle/fluid/operators/fill_diagonal_op.cu
+12
-4
python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
...addle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
+30
-0
未找到文件。
paddle/fluid/operators/fill_diagonal_op.cc
浏览文件 @
62e41150
...
@@ -108,8 +108,15 @@ class FillIDiagonalKernel : public framework::OpKernel<T> {
...
@@ -108,8 +108,15 @@ class FillIDiagonalKernel : public framework::OpKernel<T> {
size
=
std
::
min
(
size
,
out_dims
[
1
]
*
out_dims
[
1
]);
size
=
std
::
min
(
size
,
out_dims
[
1
]
*
out_dims
[
1
]);
}
}
for
(
int64_t
i
=
offset
;
i
<
size
;
i
+=
strides
)
{
for
(
int64_t
i
=
0
;
i
<
size
;
i
+=
strides
)
{
out_data
[
i
]
=
temp_var
;
// to check if the new position with offset is still in the same line;
// this modify should not affect across lines.
// out_dims[1] is also work for tensor with dim>2, for which the dims must
// be the same number
if
(
i
%
out_dims
[
1
]
+
offset
>=
0
&&
i
%
out_dims
[
1
]
+
offset
<
out_dims
[
1
])
{
out_data
[
i
+
offset
]
=
temp_var
;
}
}
}
}
}
};
};
...
@@ -176,8 +183,11 @@ class FillIDiagonalGradKernel : public framework::OpKernel<T> {
...
@@ -176,8 +183,11 @@ class FillIDiagonalGradKernel : public framework::OpKernel<T> {
wrapsize
=
size
;
wrapsize
=
size
;
}
}
for
(
int64_t
i
=
offset
;
i
<
wrapsize
;
i
+=
strides
)
{
for
(
int64_t
i
=
0
;
i
<
wrapsize
;
i
+=
strides
)
{
data
[
i
]
=
T
(
0
);
if
(
i
%
dx_dims
[
1
]
+
offset
>=
0
&&
i
%
dx_dims
[
1
]
+
offset
<
dx_dims
[
1
])
{
data
[
i
+
offset
]
=
T
(
0
);
}
}
}
}
}
}
}
...
...
paddle/fluid/operators/fill_diagonal_op.cu
浏览文件 @
62e41150
...
@@ -22,12 +22,20 @@ using CUDADeviceContext = paddle::platform::CUDADeviceContext;
...
@@ -22,12 +22,20 @@ using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
fill_constant_kernel
(
const
int64_t
featuresize
,
T
*
in_data
,
__global__
void
fill_constant_kernel
(
const
int64_t
featuresize
,
T
*
in_data
,
int64_t
strides
,
int
offset
,
T
fillvar
)
{
int64_t
strides
,
int
offset
,
T
fillvar
,
int
dims
)
{
for
(
int64_t
idx
=
blockIdx
.
x
*
featuresize
+
threadIdx
.
x
;
for
(
int64_t
idx
=
blockIdx
.
x
*
featuresize
+
threadIdx
.
x
;
idx
*
strides
+
offset
<
(
blockIdx
.
x
+
1
)
*
featuresize
;
idx
*
strides
+
offset
<
(
blockIdx
.
x
+
1
)
*
featuresize
;
idx
+=
blockDim
.
x
)
{
idx
+=
blockDim
.
x
)
{
// to check if the new position with offset is still in the same line;
// this modify should not affect across lines.
// out_dims[1] is also work for tensor with dim>2, for which the dims must
// be the same number
if
((
idx
*
strides
)
%
dims
+
offset
<
dims
&&
(
idx
*
strides
)
%
dims
+
offset
>=
0
)
{
in_data
[
idx
*
strides
+
offset
]
=
fillvar
;
in_data
[
idx
*
strides
+
offset
]
=
fillvar
;
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -62,7 +70,7 @@ class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
...
@@ -62,7 +70,7 @@ class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
int64_t
kBlockDim
=
std
::
min
(
int64_t
(
size
/
strides
),
kMaxBlockDim
);
int64_t
kBlockDim
=
std
::
min
(
int64_t
(
size
/
strides
),
kMaxBlockDim
);
fill_constant_kernel
<
T
><<<
1
,
kBlockDim
,
0
>>>
(
size
,
out_data
,
strides
,
fill_constant_kernel
<
T
><<<
1
,
kBlockDim
,
0
>>>
(
size
,
out_data
,
strides
,
offset
,
temp_var
);
offset
,
temp_var
,
out_dims
[
1
]
);
}
}
};
};
...
@@ -96,7 +104,7 @@ class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -96,7 +104,7 @@ class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
int64_t
kBlockDim
=
std
::
min
(
int64_t
(
size
),
kMaxBlockDim
);
int64_t
kBlockDim
=
std
::
min
(
int64_t
(
size
),
kMaxBlockDim
);
fill_constant_kernel
<
T
><<<
1
,
kBlockDim
,
0
>>>
(
wrapsize
,
in_data
,
strides
,
fill_constant_kernel
<
T
><<<
1
,
kBlockDim
,
0
>>>
(
wrapsize
,
in_data
,
strides
,
offset
,
T
(
0
));
offset
,
T
(
0
)
,
out_dims
[
1
]
);
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
浏览文件 @
62e41150
...
@@ -50,6 +50,36 @@ class TensorFillDiagonal_Test(unittest.TestCase):
...
@@ -50,6 +50,36 @@ class TensorFillDiagonal_Test(unittest.TestCase):
(
y
.
grad
.
numpy
().
astype
(
'float32'
)
==
expected_grad
).
all
(),
(
y
.
grad
.
numpy
().
astype
(
'float32'
)
==
expected_grad
).
all
(),
True
)
True
)
def
test_offset
(
self
):
expected_np
=
np
.
array
(
[[
2
,
2
,
1
],
[
2
,
2
,
2
],
[
2
,
2
,
2
]]).
astype
(
'float32'
)
expected_grad
=
np
.
array
(
[[
1
,
1
,
0
],
[
1
,
1
,
1
],
[
1
,
1
,
1
]]).
astype
(
'float32'
)
typelist
=
[
'float32'
,
'float64'
,
'int32'
,
'int64'
]
places
=
[
fluid
.
CPUPlace
()]
if
fluid
.
core
.
is_compiled_with_cuda
():
places
.
append
(
fluid
.
CUDAPlace
(
0
))
for
idx
,
p
in
enumerate
(
places
):
if
idx
==
0
:
paddle
.
set_device
(
'cpu'
)
else
:
paddle
.
set_device
(
'gpu'
)
for
dtype
in
typelist
:
x
=
paddle
.
ones
((
3
,
3
),
dtype
=
dtype
)
x
.
stop_gradient
=
False
y
=
x
*
2
y
.
fill_diagonal_
(
1
,
offset
=
2
,
wrap
=
True
)
loss
=
y
.
sum
()
loss
.
backward
()
self
.
assertEqual
(
(
y
.
numpy
().
astype
(
'float32'
)
==
expected_np
).
all
(),
True
)
self
.
assertEqual
(
(
y
.
grad
.
numpy
().
astype
(
'float32'
)
==
expected_grad
).
all
(),
True
)
def
test_bool
(
self
):
def
test_bool
(
self
):
expected_np
=
np
.
array
(
expected_np
=
np
.
array
(
[[
False
,
True
,
True
],
[
True
,
False
,
True
],
[
True
,
True
,
False
]])
[[
False
,
True
,
True
],
[
True
,
False
,
True
],
[
True
,
True
,
False
]])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录