Unverified · Commit 8757fc5b · Authored by: Qi Li · Committed by: GitHub

[NPU] fix dtype for arg_max, test=develop (#36457)

Parent: 3845afff
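Before this change, the NPU arg_max kernel always allocated its output as int32 and ignored the op's `dtype` attribute; the commit makes the kernel honor that attribute and fall back to int64 when it is unset. A hedged usage sketch of the user-visible effect (assuming an NPU build of Paddle from this era, where `paddle.argmax`'s `dtype` argument maps onto the kernel's `dtype` attribute):

import paddle

paddle.disable_static()  # the unit tests further down run in static mode instead
x = paddle.rand([3, 4, 5])
print(paddle.argmax(x, axis=0).dtype)                 # int64, the default
print(paddle.argmax(x, axis=0, dtype='int32').dtype)  # int32, now honored on NPU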
...@@ -17,30 +17,49 @@ limitations under the Licnse. */
 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+using NPUDeviceContext = platform::NPUDeviceContext;

-template <typename DeviceContext, typename T>
-class ArgMaxNPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::Tensor>("X");
-    int64_t axis = ctx.Attr<int64_t>("axis");
-    auto dtype = ctx.Attr<int>("dtype");
+template <typename T>
+struct VisitDataArgNPUMaxFunctor {
+  const framework::ExecutionContext& ctx;

-    auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<int32_t>(ctx.GetPlace());
+  explicit VisitDataArgNPUMaxFunctor(const framework::ExecutionContext& ctx)
+      : ctx(ctx) {}
+
+  template <typename Tout>
+  void apply() const {
+    auto& x = *(ctx.Input<framework::Tensor>("X"));
+    auto& out = *(ctx.Output<framework::Tensor>("Out"));
+    out.template mutable_data<Tout>(ctx.GetPlace());
+    auto axis = ctx.Attr<int64_t>("axis");
+    auto dtype = ctx.Attr<int>("dtype");
+
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
     NpuOpRunner runner;
     runner.SetType("ArgMaxV2")
-        .AddInput(*x)
+        .AddInput(x)
         .AddInput(std::vector<int64_t>{axis})
-        .AddOutput(*out)
-        .AddAttr("dtype", dtype);
+        .AddOutput(out)
+        .AddAttrDataType("dtype", dtype)
+        .Run(stream);
+  }
+};

-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-    runner.Run(stream);
+template <typename T>
+class ArgMaxNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dtype = ctx.Attr<int>("dtype");
+    if (dtype < 0) {
+      framework::VisitDataTypeTiny(
+          static_cast<framework::proto::VarType::Type>(
+              framework::proto::VarType::INT64),
+          VisitDataArgNPUMaxFunctor<T>(ctx));
+      return;
+    }
+    framework::VisitDataTypeTiny(
+        static_cast<framework::proto::VarType::Type>(dtype),
+        VisitDataArgNPUMaxFunctor<T>(ctx));
   }
 };
...@@ -48,7 +67,5 @@ class ArgMaxNPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_NPU_KERNEL(
-    arg_max, ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext,
-                         paddle::platform::float16>);
+REGISTER_OP_NPU_KERNEL(arg_max, ops::ArgMaxNPUKernel<float>,
+                       ops::ArgMaxNPUKernel<paddle::platform::float16>);
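The control flow of the rewritten kernel above is worth spelling out: a negative `dtype` attribute means the attribute was never set, so the kernel dispatches `VisitDataArgNPUMaxFunctor` with INT64; otherwise it dispatches with the requested type, and the functor allocates the output buffer as that type (`Tout`). A minimal Python model of this dispatch (the enum values are assumed stand-ins mirroring `framework::proto::VarType`, not Paddle API):

import numpy as np

# Assumed stand-ins for framework::proto::VarType::Type enum values.
VARTYPE_INT32, VARTYPE_INT64 = 2, 3
NUMPY_DTYPES = {VARTYPE_INT32: np.int32, VARTYPE_INT64: np.int64}

def arg_max_dispatch(x, axis, dtype_attr=-1):
    """Model of ArgMaxNPUKernel::Compute: dtype_attr < 0 means 'unset',
    which falls back to int64; otherwise the requested type is used."""
    out_type = VARTYPE_INT64 if dtype_attr < 0 else dtype_attr
    return np.argmax(x, axis=axis).astype(NUMPY_DTYPES[out_type])

x = np.random.random((3, 4, 5)).astype('float32')
assert arg_max_dispatch(x, 0).dtype == np.int64                 # default path
assert arg_max_dispatch(x, 0, VARTYPE_INT32).dtype == np.int32  # explicit dtype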
...@@ -188,6 +188,21 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
   return *this;
 }

+NpuOpRunner &NpuOpRunner::AddAttrDataType(const std::string &name,
+                                          const NPUAttribute &attr) {
+  PADDLE_ENFORCE_EQ(
+      (attr.type() == typeid(int)), true,
+      platform::errors::InvalidArgument(
+          "Attr type is NOT equal to framework::proto::VarType::Type."));
+  if (!attr_) {
+    attr_ = aclopCreateAttr();
+  }
+  auto dtype = ConvertToNpuDtype(
+      static_cast<framework::proto::VarType::Type>(BOOST_GET_CONST(int, attr)));
+  PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrDataType(attr_, name.c_str(), dtype));
+  return *this;
+}
+
 NpuOpRunner &NpuOpRunner::AddAttrs(const NPUAttributeMap &attrs) {
   for (const auto &pair : attrs) {
     AddAttr(pair.first, pair.second);
......
...@@ -58,6 +58,12 @@ class NpuOpRunner {
   NpuOpRunner &AddAttr(const std::string &name, const NPUAttribute &attr);

+  // NOTE(qili93): need to add individual api for aclopSetAttrDataType
+  // as typeid(aclDataType) and typeid(framework::proto::VarType::Type)
+  // always go to attr.type() == typeid(int) to call aclopSetAttrInt
+  NpuOpRunner &AddAttrDataType(const std::string &name,
+                               const NPUAttribute &attr);
+
   NpuOpRunner &AddAttrs(const NPUAttributeMap &attrs);

   NpuOpRunner &AddInput(const Tensor &tensor);
......
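The NOTE above is the crux of the helper: `NPUAttribute` is a variant, and a `framework::proto::VarType::Type` stored in it is indistinguishable at runtime from a plain `int`, so the generic `AddAttr` would route it to `aclopSetAttrInt` rather than `aclopSetAttrDataType`. A toy Python model of why type-based dispatch cannot tell the two apart (all names here are illustrative, not Paddle or ACL API):

# Illustrative only: models a variant attribute store that dispatches on the
# runtime type of the value, as AddAttr does on attr.type().
def add_attr(attrs, name, value):
    if isinstance(value, bool):
        attrs[name] = ('bool', value)
    elif isinstance(value, int):
        # A dtype enum carried as an int lands here too -- the integer setter.
        attrs[name] = ('int', value)
    return attrs

def add_attr_data_type(attrs, name, vartype_int):
    # Dedicated entry point: convert the framework enum before storing,
    # mirroring ConvertToNpuDtype + aclopSetAttrDataType in the diff.
    attrs[name] = ('dtype', {2: 'int32', 3: 'int64'}[vartype_int])
    return attrs

attrs = {}
add_attr(attrs, 'dtype', 3)            # silently treated as a plain int
add_attr_data_type(attrs, 'dtype', 3)  # explicitly treated as a data type
print(attrs['dtype'])                  # ('dtype', 'int64')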
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -20,30 +20,31 @@ import sys
 sys.path.append("..")
 from op_test import OpTest
 import paddle
-import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.fluid import Program, program_guard

 paddle.enable_static()


 class BaseTestCase(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
     def initTestCase(self):
         self.op_type = 'arg_max'
-        self.dims = (3, 4)
+        self.dims = (3, 4, 5)
         self.dtype = 'float32'
-        self.axis = 1
+        self.axis = 0

     def setUp(self):
+        self.set_npu()
         self.initTestCase()
-        self.__class__.use_npu = True
-        self.place = paddle.NPUPlace(0)
-        np.random.seed(2021)
-        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
         self.inputs = {'X': self.x}
         self.attrs = {'axis': self.axis}
-        if self.op_type == "arg_min":
-            self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
-        else:
-            self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
+        self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}

     def test_check_output(self):
         self.check_output_with_place(self.place)
...@@ -211,6 +212,64 @@ class TestArgMaxFloat32Case10(BaseTestCase):
         self.axis = 0


+class BaseTestComplex1_1(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (4, 5, 6)
+        self.dtype = 'float32'
+        self.axis = 2
+
+    def setUp(self):
+        self.set_npu()
+        self.initTestCase()
+        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.inputs = {'X': self.x}
+        self.attrs = {
+            'axis': self.axis,
+            'dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {
+            'Out': np.argmax(
+                self.x, axis=self.axis).astype("int32")
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class BaseTestComplex1_2(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (4, 5, 6)
+        self.dtype = 'float16'
+        self.axis = 2
+
+    def setUp(self):
+        self.set_npu()
+        self.initTestCase()
+        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.inputs = {'X': self.x}
+        self.attrs = {
+            'axis': self.axis,
+            'dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {
+            'Out': np.argmax(
+                self.x, axis=self.axis).astype("int32")
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
 class TestArgMaxAPI(unittest.TestCase):
     def initTestCase(self):
         self.dims = (3, 4, 5)
......
...@@ -1675,11 +1675,16 @@ def cross_entropy(input,
                 raise ValueError(
                     "Target({}) is out of class_dimension's upper bound({})".
                     format(invalid_label[0], input.shape[axis] - 1))

-        _, out = _C_ops.softmax_with_cross_entropy(
-            input, label, 'soft_label', soft_label, 'ignore_index',
-            ignore_index, 'numeric_stable_mode', True, 'axis', axis,
-            'use_softmax', use_softmax)
+        if core.is_compiled_with_npu():
+            _, _, out = _C_ops.softmax_with_cross_entropy(
+                input, label, 'soft_label', soft_label, 'ignore_index',
+                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
+                'use_softmax', use_softmax)
+        else:
+            _, out = _C_ops.softmax_with_cross_entropy(
+                input, label, 'soft_label', soft_label, 'ignore_index',
+                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
+                'use_softmax', use_softmax)

         if weight is not None:
......
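The loss.py hunk exists because the NPU kernel of `softmax_with_cross_entropy` returns one more tensor than the generic kernel (an extra buffer, assumed here to be the backprop workspace), so the imperative call must unpack three values instead of two. A hypothetical sketch of the arity difference, plus an arity-agnostic alternative:

# Stand-in for the op call; the tuples model the assumed output layouts.
def softmax_with_cross_entropy_model(on_npu):
    return ('softmax', 'backprop', 'loss') if on_npu else ('softmax', 'loss')

for on_npu in (False, True):
    *_, out = softmax_with_cross_entropy_model(on_npu)  # loss is last either way
    assert out == 'loss'

The explicit branch on `core.is_compiled_with_npu()` in the diff pins the arity instead, which fails loudly if a kernel's output count ever changes.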