未验证 提交 8757fc5b 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix dtype for arg_max, test=develop (#36457)

上级 3845afff
......@@ -17,30 +17,49 @@ limitations under the Licnse. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext;
template <typename DeviceContext, typename T>
class ArgMaxNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
int64_t axis = ctx.Attr<int64_t>("axis");
auto dtype = ctx.Attr<int>("dtype");
template <typename T>
struct VisitDataArgNPUMaxFunctor {
const framework::ExecutionContext& ctx;
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<int32_t>(ctx.GetPlace());
explicit VisitDataArgNPUMaxFunctor(const framework::ExecutionContext& ctx)
: ctx(ctx) {}
template <typename Tout>
void apply() const {
auto& x = *(ctx.Input<framework::Tensor>("X"));
auto& out = *(ctx.Output<framework::Tensor>("Out"));
out.template mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto dtype = ctx.Attr<int>("dtype");
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
NpuOpRunner runner;
runner.SetType("ArgMaxV2")
.AddInput(*x)
.AddInput(x)
.AddInput(std::vector<int64_t>{axis})
.AddOutput(*out)
.AddAttr("dtype", dtype);
.AddOutput(out)
.AddAttrDataType("dtype", dtype)
.Run(stream);
}
};
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
template <typename T>
class ArgMaxNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataTypeTiny(static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataArgNPUMaxFunctor<T>(ctx));
return;
}
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgNPUMaxFunctor<T>(ctx));
}
};
......@@ -48,7 +67,5 @@ class ArgMaxNPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
arg_max, ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(arg_max, ops::ArgMaxNPUKernel<float>,
ops::ArgMaxNPUKernel<paddle::platform::float16>);
......@@ -188,6 +188,21 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
return *this;
}
NpuOpRunner &NpuOpRunner::AddAttrDataType(const std::string &name,
const NPUAttribute &attr) {
PADDLE_ENFORCE_EQ(
(attr.type() == typeid(int)), true,
platform::errors::InvalidArgument(
"Attr type is NOT equal to framework::proto::VarType::Type."));
if (!attr_) {
attr_ = aclopCreateAttr();
}
auto dtype = ConvertToNpuDtype(
static_cast<framework::proto::VarType::Type>(BOOST_GET_CONST(int, attr)));
PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrDataType(attr_, name.c_str(), dtype));
return *this;
}
NpuOpRunner &NpuOpRunner::AddAttrs(const NPUAttributeMap &attrs) {
for (const auto &pair : attrs) {
AddAttr(pair.first, pair.second);
......
......@@ -58,6 +58,12 @@ class NpuOpRunner {
NpuOpRunner &AddAttr(const std::string &name, const NPUAttribute &attr);
// NOTE(qili93): need to add indivisual api for aclopSetAttrDataType
// as typeid(aclDataType) and typeid(framework::proto::VarType::Type)
// always go to attr.type() == typeid(int) to call aclopSetAttrInt
NpuOpRunner &AddAttrDataType(const std::string &name,
const NPUAttribute &attr);
NpuOpRunner &AddAttrs(const NPUAttributeMap &attrs);
NpuOpRunner &AddInput(const Tensor &tensor);
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -20,30 +20,31 @@ import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
paddle.enable_static()
class BaseTestCase(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 1
self.axis = 0
def setUp(self):
self.set_npu()
self.initTestCase()
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
np.random.seed(2021)
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
if self.op_type == "arg_min":
self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
else:
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
def test_check_output(self):
self.check_output_with_place(self.place)
......@@ -211,6 +212,64 @@ class TestArgMaxFloat32Case10(BaseTestCase):
self.axis = 0
class BaseTestComplex1_1(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (4, 5, 6)
self.dtype = 'float32'
self.axis = 2
def setUp(self):
self.set_npu()
self.initTestCase()
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {
'axis': self.axis,
'dtype': int(core.VarDesc.VarType.INT32)
}
self.outputs = {
'Out': np.argmax(
self.x, axis=self.axis).astype("int32")
}
def test_check_output(self):
self.check_output_with_place(self.place)
class BaseTestComplex1_2(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (4, 5, 6)
self.dtype = 'float16'
self.axis = 2
def setUp(self):
self.set_npu()
self.initTestCase()
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {
'axis': self.axis,
'dtype': int(core.VarDesc.VarType.INT32)
}
self.outputs = {
'Out': np.argmax(
self.x, axis=self.axis).astype("int32")
}
def test_check_output(self):
self.check_output_with_place(self.place)
class TestArgMaxAPI(unittest.TestCase):
def initTestCase(self):
self.dims = (3, 4, 5)
......
......@@ -1675,11 +1675,16 @@ def cross_entropy(input,
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
_, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
if core.is_compiled_with_npu():
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
else:
_, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
if weight is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册