Unverified commit b2912939, authored by cifar10, committed by GitHub

add mlu arg_max kernel (#43624)

Parent be05f84b
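Not part of the diff itself, but as a quick illustration of what the new kernel enables, here is a minimal dygraph usage sketch. It mirrors the device setup used by the unit test below and assumes an MLU build of Paddle with device 0 available; the tensor shape and axis are arbitrary.

    import numpy as np
    import paddle

    # Select the MLU device the same way the unit tests in this commit do.
    place = paddle.MLUPlace(0)
    paddle.disable_static(place)

    x = paddle.to_tensor(np.random.random((3, 4, 5)).astype('float32'))
    idx64 = paddle.argmax(x, axis=1)                   # default int64 indices
    idx32 = paddle.argmax(x, axis=1, dtype='int32')    # exercises the kernel's int32 branch
    print(idx64.shape, idx32.dtype)
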
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class ArgMaxMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
auto axis = static_cast<int>(ctx.Attr<int64_t>("axis"));
auto dtype = ctx.Attr<int>("dtype");
const bool& flatten = ctx.Attr<bool>("flatten");
if (x->numel() == 0) return;
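// The dtype attribute carries framework::proto::VarType values: 2 == INT32, 3 == INT64.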
PADDLE_ENFORCE_EQ(
(dtype == 2 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in argmax op must be [%s] or [%s], "
"but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
if (axis < 0) {
framework::DDim x_dims;
x_dims = x->dims();
axis += x_dims.size();
}
framework::Tensor flatten_x(x->type());
flatten_x.ShareDataWith(*x);
if (flatten) {
flatten_x.Resize(phi::make_ddim({x->numel()}));
// when flatten is true, the reduction axis is simply 0
axis = 0;
}
std::vector<int> reduce_dims;
reduce_dims.push_back(axis);
auto out_dims = out->dims();
int out_count = out_dims[0];
for (int i = 1; i < out_dims.size(); i++) {
out_count = out_count * out_dims[i];
}
size_t indices_size_inbytes = out_count * sizeof(int32_t);
auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
framework::Tensor value_out =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(out->dims(), dev_ctx);
MLUCnnlTensorDesc value_out_desc(value_out);
MLUCnnlTensorDesc input_desc(flatten_x, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(flatten_x.dtype()));
MLUCnnlReduceDesc reduction_desc(
reduce_dims, CNNL_REDUCE_MAX_LAST_INDEX, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_ONLY_INDICES, CNNL_32BIT_INDICES);
if (dtype == 2) {
out->template mutable_data<int32_t>(ctx.GetPlace());
MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
nullptr, input_desc.get(), GetBasePtr(&flatten_x),
indices_size_inbytes /*indices_size*/, GetBasePtr(out),
nullptr, value_out_desc.get(), GetBasePtr(&value_out));
} else {
out->template mutable_data<int64_t>(ctx.GetPlace());
framework::Tensor out_int32 =
ctx.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(),
dev_ctx);
MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
nullptr, input_desc.get(), GetBasePtr(&flatten_x),
indices_size_inbytes /*indices_size*/,
GetBasePtr(&out_int32), nullptr, value_out_desc.get(),
GetBasePtr(&value_out));
// cast indices type to int64
MLUCnnlTensorDesc out_int32_desc(out_int32);
MLUCnnlTensorDesc cast_output_desc(*out);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
MLUCnnl::Cast(ctx, cast_type, out_int32_desc.get(),
GetBasePtr(&out_int32), cast_output_desc.get(),
GetBasePtr(out));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(arg_max, ops::ArgMaxMLUKernel<int>,
ops::ArgMaxMLUKernel<float>,
ops::ArgMaxMLUKernel<paddle::platform::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
paddle.enable_static()
class BaseTestCase(OpTest):
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
def test_check_output(self):
self.check_output_with_place(self.place)
# test argmax, dtype: float16
class TestArgMaxFloat16Case1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float16'
self.axis = -1
class TestArgMaxFloat16Case2(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float16'
self.axis = 0
class TestArgMaxFloat16Case3(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float16'
self.axis = 1
class TestArgMaxFloat16Case4(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float16'
self.axis = 2
class TestArgMaxFloat16Case5(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float16'
self.axis = -1
class TestArgMaxFloat16Case6(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float16'
self.axis = 0
class TestArgMaxFloat16Case7(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float16'
self.axis = 1
class TestArgMaxFloat16Case8(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (1, )
self.dtype = 'float16'
self.axis = 0
class TestArgMaxFloat16Case9(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (2, )
self.dtype = 'float16'
self.axis = 0
class TestArgMaxFloat16Case10(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, )
self.dtype = 'float16'
self.axis = 0
# test argmax, dtype: float32
class TestArgMaxFloat32Case1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = -1
class TestArgMaxFloat32Case2(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case3(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 1
class TestArgMaxFloat32Case4(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 2
class TestArgMaxFloat32Case5(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = -1
class TestArgMaxFloat32Case6(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case7(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = 1
class TestArgMaxFloat32Case8(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (1, )
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case9(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (2, )
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case10(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, )
self.dtype = 'float32'
self.axis = 0
class BaseTestComplex1_1(OpTest):
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (4, 5, 6)
self.dtype = 'float32'
self.axis = 2
def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {
'axis': self.axis,
'dtype': int(core.VarDesc.VarType.INT32)
}
self.outputs = {
'Out': np.argmax(self.x, axis=self.axis).astype("int32")
}
def test_check_output(self):
self.check_output_with_place(self.place)
class BaseTestComplex1_2(OpTest):
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (4, 5, 6)
self.dtype = 'float16'
self.axis = 2
def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {
'axis': self.axis,
'dtype': int(core.VarDesc.VarType.INT32)
}
self.outputs = {
'Out': np.argmax(self.x, axis=self.axis).astype("int32")
}
def test_check_output(self):
self.check_output_with_place(self.place)
class TestArgMaxAPI(unittest.TestCase):
def initTestCase(self):
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
def setUp(self):
self.initTestCase()
self.__class__.use_mlu = True
self.place = [paddle.MLUPlace(0)]
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2022)
numpy_input = (np.random.random(self.dims)).astype(self.dtype)
tensor_input = paddle.to_tensor(numpy_input)
numpy_output = np.argmax(numpy_input, axis=self.axis)
paddle_output = paddle.argmax(tensor_input, axis=self.axis)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()),
True)
paddle.enable_static()
for place in self.place:
run(place)
class TestArgMaxAPI_2(unittest.TestCase):
def initTestCase(self):
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
self.keep_dims = True
def setUp(self):
self.initTestCase()
self.__class__.use_mlu = True
self.place = [paddle.MLUPlace(0)]
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2022)
numpy_input = (np.random.random(self.dims)).astype(self.dtype)
tensor_input = paddle.to_tensor(numpy_input)
numpy_output = np.argmax(numpy_input,
axis=self.axis).reshape(1, 4, 5)
paddle_output = paddle.argmax(tensor_input,
axis=self.axis,
keepdim=self.keep_dims)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()),
True)
self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
paddle.enable_static()
for place in self.place:
run(place)
class TestArgMaxAPI_3(unittest.TestCase):
def initTestCase(self):
self.dims = (1, 9)
self.dtype = 'float32'
def setUp(self):
self.initTestCase()
self.__class__.use_mlu = True
self.place = [paddle.MLUPlace(0)]
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2022)
numpy_input = (np.random.random(self.dims)).astype(self.dtype)
tensor_input = paddle.to_tensor(numpy_input)
numpy_output = np.argmax(numpy_input).reshape([1])
paddle_output = paddle.argmax(tensor_input)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()),
True)
self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
paddle.enable_static()
for place in self.place:
run(place)
if __name__ == '__main__':
unittest.main()