Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
eaccdc71
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eaccdc71
编写于
1月 13, 2022
作者:
F
furnace
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NPU] fix tril_triu (#38864)
[NPU] fix tril_triu
上级
7a5af630
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
51 addition
and
6 deletion
+51
-6
paddle/fluid/operators/tril_triu_op_npu.cc
paddle/fluid/operators/tril_triu_op_npu.cc
+36
-5
python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
...paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
+15
-1
未找到文件。
paddle/fluid/operators/tril_triu_op_npu.cc
浏览文件 @
eaccdc71
...
@@ -33,12 +33,41 @@ class TrilTriuNPUKernel : public framework::OpKernel<T> {
...
@@ -33,12 +33,41 @@ class TrilTriuNPUKernel : public framework::OpKernel<T> {
framework
::
NPUAttributeMap
attr_input
=
{{
"diagonal"
,
diagonal
}};
framework
::
NPUAttributeMap
attr_input
=
{{
"diagonal"
,
diagonal
}};
auto
stream
=
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
paddle
::
platform
::
NPUDeviceContext
>()
ctx
.
template
device_context
<
paddle
::
platform
::
NPUDeviceContext
>();
.
stream
();
const
auto
&
runner
=
NpuOpRunner
(
op_type
,
{
*
x
},
{
*
out
},
attr_input
);
auto
op_func_tril
=
[](
const
std
::
vector
<
Tensor
>&
inputs
,
runner
.
Run
(
stream
);
const
std
::
vector
<
Tensor
>&
outputs
,
const
NPUAttributeMap
&
attrs
,
const
platform
::
NPUDeviceContext
&
dev_ctx
)
{
const
auto
&
runner
=
NpuOpRunner
(
"Tril"
,
inputs
,
outputs
,
attrs
);
runner
.
Run
(
dev_ctx
.
stream
());
};
auto
op_func_triu
=
[](
const
std
::
vector
<
Tensor
>&
inputs
,
const
std
::
vector
<
Tensor
>&
outputs
,
const
NPUAttributeMap
&
attrs
,
const
platform
::
NPUDeviceContext
&
dev_ctx
)
{
const
auto
&
runner
=
NpuOpRunner
(
"Triu"
,
inputs
,
outputs
,
attrs
);
runner
.
Run
(
dev_ctx
.
stream
());
};
if
(
x
->
type
()
==
framework
::
proto
::
VarType
::
BOOL
)
{
if
(
lower
)
{
NpuOpRunner
::
TypeAdapter
({
*
x
},
{
*
out
},
attr_input
,
dev_ctx
,
op_func_tril
,
{
framework
::
proto
::
VarType
::
UINT8
},
{
framework
::
proto
::
VarType
::
UINT8
});
}
else
{
NpuOpRunner
::
TypeAdapter
({
*
x
},
{
*
out
},
attr_input
,
dev_ctx
,
op_func_triu
,
{
framework
::
proto
::
VarType
::
UINT8
},
{
framework
::
proto
::
VarType
::
UINT8
});
}
}
else
{
const
auto
&
runner
=
NpuOpRunner
(
op_type
,
{
*
x
},
{
*
out
},
attr_input
);
runner
.
Run
(
dev_ctx
.
stream
());
}
}
}
};
};
...
@@ -49,4 +78,6 @@ namespace ops = paddle::operators;
...
@@ -49,4 +78,6 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
REGISTER_OP_NPU_KERNEL
(
tril_triu
,
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
float
>
,
tril_triu
,
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
float
>
,
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
int
>
,
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
bool
>
,
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
plat
::
float16
>
);
ops
::
TrilTriuNPUKernel
<
plat
::
NPUDeviceContext
,
plat
::
float16
>
);
python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
浏览文件 @
eaccdc71
...
@@ -15,7 +15,7 @@ from __future__ import print_function
...
@@ -15,7 +15,7 @@ from __future__ import print_function
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
from
paddle.fluid.tests.unittests.op_test
import
OpTest
from
paddle.fluid.tests.unittests.op_test
import
OpTest
,
skip_check_grad_ci
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.tensor
as
tensor
import
paddle.tensor
as
tensor
...
@@ -187,5 +187,19 @@ class TestTrilTriuOpAPI(unittest.TestCase):
...
@@ -187,5 +187,19 @@ class TestTrilTriuOpAPI(unittest.TestCase):
fetch_list
=
[
triu_out
])
fetch_list
=
[
triu_out
])
# @skip_check_grad_ci(reason="[NPU does not support grad right now.")
class
TestNPUTrilTriu_bool
(
TestNPUTrilTriu
):
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
init_dtype
(
self
):
self
.
dtype
=
np
.
bool
def
initTestCase
(
self
):
self
.
real_op_type
=
np
.
random
.
choice
([
'triu'
,
'tril'
])
self
.
diagonal
=
None
self
.
X
=
np
.
random
.
choice
([
False
,
True
],
size
=
(
100
)).
reshape
([
10
,
-
1
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录