未验证 提交 ec5f8cfd 编写于 作者: F fuyou765 提交者: GitHub

[MLU]add mlu kernel for where_index op (#43720)

上级 5369378b
......@@ -4175,15 +4175,31 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() {
/* static */ void MLUCnnl::Where(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const uint32_t* strides,
const uint32_t* index,
const cnnlTensorDescriptor_t num_true_desc,
const void* num_true,
const bool as_tuple,
const cnnlTensorDescriptor_t y_desc,
int* y,
const bool as_tuple) {
void* y) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlWhere(handle, x_desc, x, strides, index, y_desc, y, as_tuple));
cnnlGetWhereWorkspaceSize(handle, num_true_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlWhere_v2(handle,
x_desc,
x,
num_true_desc,
num_true,
as_tuple,
workspace_ptr,
workspace_size,
y_desc,
y));
}
/* static */ void MLUCnnl::InTopK(const ExecutionContext& ctx,
......
......@@ -1607,12 +1607,11 @@ class MLUCnnl {
static void Where(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const uint32_t* strides,
const uint32_t* index,
const cnnlTensorDescriptor_t num_true_desc,
const void* num_true,
const bool as_tuple,
const cnnlTensorDescriptor_t y_desc,
int* y,
const bool as_tuple);
void* y);
static void Conv2D(const ExecutionContext& ctx,
const cnnlConvolutionDescriptor_t conv_desc,
const cnnlDataType_t tensor_dtype,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class MLUWhereIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<Tensor>("Condition");
auto* out = context.Output<Tensor>("Out");
auto dims = condition->dims();
const int rank = dims.size();
std::vector<int> true_num = {0};
std::vector<T> vec_condition;
paddle::framework::TensorToVector(
*condition, context.device_context(), &vec_condition);
int vec_con_size = vec_condition.size();
for (int i = 0; i < vec_con_size; ++i) {
if (vec_condition[i] > 0) true_num[0]++;
}
out->Resize(phi::make_ddim({true_num[0], rank}));
out->mutable_data<int64_t>(context.GetPlace());
auto& dev_ctx = context.template device_context<MLUDeviceContext>();
framework::Tensor out_int32 =
context.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(),
dev_ctx);
Tensor num_true;
paddle::framework::TensorFromVector(
true_num, context.device_context(), &num_true);
num_true.mutable_data<int>(context.GetPlace());
bool as_tuple = false;
MLUCnnlTensorDesc con_desc(*condition);
MLUCnnlTensorDesc num_true_desc(num_true);
MLUCnnlTensorDesc out_int32_desc(out_int32);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::Where(context,
con_desc.get(),
GetBasePtr(condition),
num_true_desc.get(),
GetBasePtr(&num_true),
as_tuple,
out_int32_desc.get(),
GetBasePtr(&out_int32));
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
MLUCnnl::Cast(context,
cast_type,
out_int32_desc.get(),
GetBasePtr(&out_int32),
out_desc.get(),
GetBasePtr(out));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(where_index,
ops::MLUWhereIndexKernel<int>,
ops::MLUWhereIndexKernel<bool>,
ops::MLUWhereIndexKernel<float>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
paddle.enable_static()
class TestWhereIndexOp(OpTest):
def setUp(self):
self.op_type = "where_index"
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.init_config()
def test_check_output(self):
self.check_output_with_place(self.place)
def init_config(self):
self.inputs = {
'Condition': np.array([True, False, True]),
}
self.outputs = {'Out': np.array([[0], [2]], dtype='int64')}
class TestAllFalse(unittest.TestCase):
def setUp(self):
self.op_type = "where_index"
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.init_config()
def check_with_place(self, place):
scope = core.Scope()
condition = scope.var('Condition').get_tensor()
condition.set(self.cond_data, place)
out = scope.var("Out").get_tensor()
out.set(np.full(self.shape, 0).astype('int64'), place)
op = Operator("where_index", Condition="Condition", Out="Out")
op.run(scope, place)
out_array = np.array(out)
self.assertTrue((out_array == self.out_data).all())
def init_config(self):
self.cond_data = np.array([False, False, False])
self.shape = (3, 1)
self.out_data = np.array([], dtype='int64')
def test_all_false(self):
self.check_with_place(self.place)
class TestRank2(TestWhereIndexOp):
def init_config(self):
self.inputs = {
'Condition': np.array([[True, False], [False, True]]),
}
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
class TestRank3(TestWhereIndexOp):
def init_config(self):
self.inputs = {
'Condition':
np.array([[[True, False], [False, True]],
[[False, True], [True, False]],
[[False, False], [False, True]]]),
}
self.outputs = {
'Out':
np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]],
dtype='int64')
}
class TestWhereOpError(unittest.TestCase):
def test_api(self):
with program_guard(Program(), Program()):
cond = fluid.layers.data(name='cond', shape=[4], dtype='bool')
result = fluid.layers.where(cond)
exe = fluid.Executor(paddle.device.MLUPlace(0))
exe.run(fluid.default_startup_program())
cond_i = np.array([True, False, False, False]).astype("bool")
out = exe.run(fluid.default_main_program(), feed={'cond': cond_i})
class TestWhereRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.where([10])
self.assertRaises(TypeError, test_type)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册