未验证 提交 eaccdc71 编写于 作者: F furnace 提交者: GitHub

[NPU] fix tril_triu (#38864)

[NPU] fix tril_triu
上级 7a5af630
......@@ -33,12 +33,41 @@ class TrilTriuNPUKernel : public framework::OpKernel<T> {
framework::NPUAttributeMap attr_input = {{"diagonal", diagonal}};
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
const auto& runner = NpuOpRunner(op_type, {*x}, {*out}, attr_input);
runner.Run(stream);
auto op_func_tril = [](const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs,
const NPUAttributeMap& attrs,
const platform::NPUDeviceContext& dev_ctx) {
const auto& runner = NpuOpRunner("Tril", inputs, outputs, attrs);
runner.Run(dev_ctx.stream());
};
auto op_func_triu = [](const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs,
const NPUAttributeMap& attrs,
const platform::NPUDeviceContext& dev_ctx) {
const auto& runner = NpuOpRunner("Triu", inputs, outputs, attrs);
runner.Run(dev_ctx.stream());
};
if (x->type() == framework::proto::VarType::BOOL) {
if (lower) {
NpuOpRunner::TypeAdapter({*x}, {*out}, attr_input, dev_ctx,
op_func_tril,
{framework::proto::VarType::UINT8},
{framework::proto::VarType::UINT8});
} else {
NpuOpRunner::TypeAdapter({*x}, {*out}, attr_input, dev_ctx,
op_func_triu,
{framework::proto::VarType::UINT8},
{framework::proto::VarType::UINT8});
}
} else {
const auto& runner = NpuOpRunner(op_type, {*x}, {*out}, attr_input);
runner.Run(dev_ctx.stream());
}
}
};
......@@ -49,4 +78,6 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
tril_triu, ops::TrilTriuNPUKernel<plat::NPUDeviceContext, float>,
ops::TrilTriuNPUKernel<plat::NPUDeviceContext, int>,
ops::TrilTriuNPUKernel<plat::NPUDeviceContext, bool>,
ops::TrilTriuNPUKernel<plat::NPUDeviceContext, plat::float16>);
......@@ -15,7 +15,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
import paddle.tensor as tensor
......@@ -187,5 +187,19 @@ class TestTrilTriuOpAPI(unittest.TestCase):
fetch_list=[triu_out])
# @skip_check_grad_ci(reason="[NPU does not support grad right now.")
class TestNPUTrilTriu_bool(TestNPUTrilTriu):
def test_check_output(self):
self.check_output_with_place(self.place)
def init_dtype(self):
self.dtype = np.bool
def initTestCase(self):
self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None
self.X = np.random.choice([False, True], size=(100)).reshape([10, -1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册