diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 86de9e999ec9aa70af4ac047e50730a684bbc1d2..77bc446243ad031d85a16f869b8654f5fee64504 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -4175,15 +4175,31 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() { /* static */ void MLUCnnl::Where(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, const void* x, - const uint32_t* strides, - const uint32_t* index, + const cnnlTensorDescriptor_t num_true_desc, + const void* num_true, + const bool as_tuple, const cnnlTensorDescriptor_t y_desc, - int* y, - const bool as_tuple) { + void* y) { cnnlHandle_t handle = GetHandleFromCTX(ctx); - + size_t workspace_size; PADDLE_ENFORCE_MLU_SUCCESS( - cnnlWhere(handle, x_desc, x, strides, index, y_desc, y, as_tuple)); + cnnlGetWhereWorkspaceSize(handle, num_true_desc, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>( + {static_cast<int64_t>(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlWhere_v2(handle, + x_desc, + x, + num_true_desc, + num_true, + as_tuple, + workspace_ptr, + workspace_size, + y_desc, + y)); } /* static */ void MLUCnnl::InTopK(const ExecutionContext& ctx, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index bbe7f221568e5798db247b0b794fa3f94b0a41d7..6882fc17f08b8a21e122ec0a7ac0a9892e8f5a8b 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1607,12 +1607,11 @@ class MLUCnnl { static void Where(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, const void* x, - const uint32_t* strides, - const uint32_t* index, + const cnnlTensorDescriptor_t num_true_desc, + const void* num_true, + const bool as_tuple, const cnnlTensorDescriptor_t y_desc, - int* y, - const bool as_tuple); - + void*
y); static void Conv2D(const ExecutionContext& ctx, const cnnlConvolutionDescriptor_t conv_desc, const cnnlDataType_t tensor_dtype, diff --git a/paddle/fluid/operators/where_index_op_mlu.cc b/paddle/fluid/operators/where_index_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0699521aa46e41a120d1f2347516d93322285f4 --- /dev/null +++ b/paddle/fluid/operators/where_index_op_mlu.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include <vector> + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template <typename T> +class MLUWhereIndexKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* condition = context.Input<framework::Tensor>("Condition"); + auto* out = context.Output<framework::Tensor>("Out"); + auto dims = condition->dims(); + const int rank = dims.size(); + std::vector<int> true_num = {0}; + std::vector<T> vec_condition; + paddle::framework::TensorToVector( + *condition, context.device_context(), &vec_condition); + int vec_con_size = vec_condition.size(); + for (int i = 0; i < vec_con_size; ++i) { + if (vec_condition[i] > 0) true_num[0]++; + } + + out->Resize(phi::make_ddim({true_num[0], rank})); + out->mutable_data<int64_t>(context.GetPlace()); + auto& dev_ctx = context.template device_context<MLUDeviceContext>(); + framework::Tensor out_int32 = + context.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(), + dev_ctx); + Tensor num_true; + paddle::framework::TensorFromVector( + true_num, context.device_context(), &num_true); + num_true.mutable_data<int>(context.GetPlace()); + bool as_tuple = false; + MLUCnnlTensorDesc con_desc(*condition); + MLUCnnlTensorDesc num_true_desc(num_true); + MLUCnnlTensorDesc out_int32_desc(out_int32); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::Where(context, + con_desc.get(), + GetBasePtr(condition), + num_true_desc.get(), + GetBasePtr(&num_true), + as_tuple, + out_int32_desc.get(), + GetBasePtr(&out_int32)); + cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64); + MLUCnnl::Cast(context, + cast_type, + out_int32_desc.get(), + GetBasePtr(&out_int32), + out_desc.get(), + GetBasePtr(out)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(where_index, + ops::MLUWhereIndexKernel<int>, + ops::MLUWhereIndexKernel<bool>, + ops::MLUWhereIndexKernel<float>); diff --git
a/python/paddle/fluid/tests/unittests/mlu/test_where_index_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_where_index_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..3b25887e3571268c9893f2cf30514676c6d36beb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_where_index_op_mlu.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys + +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + +paddle.enable_static() + + +class TestWhereIndexOp(OpTest): + + def setUp(self): + self.op_type = "where_index" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.init_config() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def init_config(self): + self.inputs = { + 'Condition': np.array([True, False, True]), + } + + self.outputs = {'Out': np.array([[0], [2]], dtype='int64')} + + +class TestAllFalse(unittest.TestCase): + + def setUp(self): + self.op_type = "where_index" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.init_config() + + def check_with_place(self, place): + scope = core.Scope() + condition = 
scope.var('Condition').get_tensor() + condition.set(self.cond_data, place) + + out = scope.var("Out").get_tensor() + out.set(np.full(self.shape, 0).astype('int64'), place) + + op = Operator("where_index", Condition="Condition", Out="Out") + op.run(scope, place) + + out_array = np.array(out) + self.assertTrue((out_array == self.out_data).all()) + + def init_config(self): + self.cond_data = np.array([False, False, False]) + self.shape = (3, 1) + self.out_data = np.array([], dtype='int64') + + def test_all_false(self): + self.check_with_place(self.place) + + +class TestRank2(TestWhereIndexOp): + + def init_config(self): + self.inputs = { + 'Condition': np.array([[True, False], [False, True]]), + } + + self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')} + + +class TestRank3(TestWhereIndexOp): + + def init_config(self): + self.inputs = { + 'Condition': + np.array([[[True, False], [False, True]], + [[False, True], [True, False]], + [[False, False], [False, True]]]), + } + + self.outputs = { + 'Out': + np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]], + dtype='int64') + } + + +class TestWhereOpError(unittest.TestCase): + + def test_api(self): + with program_guard(Program(), Program()): + cond = fluid.layers.data(name='cond', shape=[4], dtype='bool') + result = fluid.layers.where(cond) + + exe = fluid.Executor(paddle.device.MLUPlace(0)) + exe.run(fluid.default_startup_program()) + cond_i = np.array([True, False, False, False]).astype("bool") + out = exe.run(fluid.default_main_program(), feed={'cond': cond_i}) + + +class TestWhereRaiseError(unittest.TestCase): + + def test_errors(self): + + def test_type(): + fluid.layers.where([10]) + + self.assertRaises(TypeError, test_type) + + +if __name__ == "__main__": + unittest.main()