From eaccdc71dd04b1f42ceac170c82754dd0a953867 Mon Sep 17 00:00:00 2001
From: furnace <34057289+windstamp@users.noreply.github.com>
Date: Thu, 13 Jan 2022 16:34:17 +0800
Subject: [PATCH] [NPU] fix tril_triu (#38864)

[NPU] fix tril_triu
---
 paddle/fluid/operators/tril_triu_op_npu.cc    | 41 ++++++++++++++++---
 .../unittests/npu/test_tril_triu_op_npu.py    | 16 +++++++-
 2 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/tril_triu_op_npu.cc b/paddle/fluid/operators/tril_triu_op_npu.cc
index ab7a9035fb..02af711567 100644
--- a/paddle/fluid/operators/tril_triu_op_npu.cc
+++ b/paddle/fluid/operators/tril_triu_op_npu.cc
@@ -33,12 +33,41 @@ class TrilTriuNPUKernel : public framework::OpKernel<T> {
 
     framework::NPUAttributeMap attr_input = {{"diagonal", diagonal}};
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
 
-    const auto& runner = NpuOpRunner(op_type, {*x}, {*out}, attr_input);
-    runner.Run(stream);
+    auto op_func_tril = [](const std::vector<Tensor>& inputs,
+                           const std::vector<Tensor>& outputs,
+                           const NPUAttributeMap& attrs,
+                           const platform::NPUDeviceContext& dev_ctx) {
+      const auto& runner = NpuOpRunner("Tril", inputs, outputs, attrs);
+      runner.Run(dev_ctx.stream());
+    };
+
+    auto op_func_triu = [](const std::vector<Tensor>& inputs,
+                           const std::vector<Tensor>& outputs,
+                           const NPUAttributeMap& attrs,
+                           const platform::NPUDeviceContext& dev_ctx) {
+      const auto& runner = NpuOpRunner("Triu", inputs, outputs, attrs);
+      runner.Run(dev_ctx.stream());
+    };
+
+    if (x->type() == framework::proto::VarType::BOOL) {
+      if (lower) {
+        NpuOpRunner::TypeAdapter({*x}, {*out}, attr_input, dev_ctx,
+                                 op_func_tril,
+                                 {framework::proto::VarType::UINT8},
+                                 {framework::proto::VarType::UINT8});
+      } else {
+        NpuOpRunner::TypeAdapter({*x}, {*out}, attr_input, dev_ctx,
+                                 op_func_triu,
+                                 {framework::proto::VarType::UINT8},
+                                 {framework::proto::VarType::UINT8});
+      }
+    } else {
+      const auto& runner = NpuOpRunner(op_type, {*x}, {*out}, attr_input);
+      runner.Run(dev_ctx.stream());
+    }
   }
 };
 
@@ -49,4 +78,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(
     tril_triu, ops::TrilTriuNPUKernel<plat::NPUDeviceContext, float>,
+    ops::TrilTriuNPUKernel<plat::NPUDeviceContext, int>,
+    ops::TrilTriuNPUKernel<plat::NPUDeviceContext, bool>,
     ops::TrilTriuNPUKernel<plat::NPUDeviceContext, plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
index 13adc25a38..8239dd4f3f 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py
@@ -15,7 +15,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
 import paddle
 import paddle.fluid as fluid
 import paddle.tensor as tensor
@@ -187,5 +187,19 @@ class TestTrilTriuOpAPI(unittest.TestCase):
                 fetch_list=[triu_out])
 
 
+# @skip_check_grad_ci(reason="[NPU] does not support grad right now.")
+class TestNPUTrilTriu_bool(TestNPUTrilTriu):
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def init_dtype(self):
+        self.dtype = np.bool
+
+    def initTestCase(self):
+        self.real_op_type = np.random.choice(['triu', 'tril'])
+        self.diagonal = None
+        self.X = np.random.choice([False, True], size=(100)).reshape([10, -1])
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
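
For context: before this patch, the NPU kernel for tril_triu registered only float and float16 types, so a bool mask (a common input for attention masking) could not be lowered to the NPU at all. A minimal repro sketch of the failure mode, assuming an Ascend NPU-enabled Paddle build in which the 'npu' device string is available (not verified against any specific Paddle release):

    import numpy as np
    import paddle

    paddle.set_device('npu')  # assumption: NPU-enabled build exposing this device string

    mask = paddle.to_tensor(np.random.choice([False, True], size=(10, 10)))
    lower = paddle.tril(mask)  # before this patch: no bool NPU kernel registered
    upper = paddle.triu(mask)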
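
The BOOL branch in the kernel suggests the underlying CANN Tril/Triu operators do not accept bool tensors directly; the fix routes bool inputs through NpuOpRunner::TypeAdapter, which casts them to UINT8, runs the operator, and casts the result back. A NumPy sketch of that same cast-run-cast-back pattern, with np.tril standing in for the NPU operator (the helper name is ours, purely illustrative):

    import numpy as np

    def tril_bool_via_uint8(x, diagonal=0):
        # Illustrative only: emulates the kernel's bool workaround in NumPy.
        assert x.dtype == np.bool_
        y = np.tril(x.astype(np.uint8), k=diagonal)  # run the op on uint8
        return y.astype(np.bool_)                    # cast the result back to bool

    mask = np.random.choice([False, True], size=(4, 4))
    print(tril_bool_via_uint8(mask))

The round trip is lossless here because tril/triu only zero out elements; uint8 0/1 maps back onto False/True exactly, which is presumably why UINT8 was chosen as the intermediate type.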