未验证 提交 73d706ce 编写于 作者: R RuohengMa 提交者: GitHub

[PHI] bind nll_loss xpu kernel (#54043)

上级 626ea800
......@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
set(XPU_RT_LIB_NAME "libxpurt.so")
set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_BASE_DATE "20230519")
set(XPU_BASE_DATE "20230523")
set(XPU_XCCL_BASE_VERSION "1.0.49.2")
set(XPU_XFT_BASE_VERSION "latest")
......
......@@ -525,6 +525,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16,
phi::DataType::INT64})},
{"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"nll_loss", XPUKernelSet({phi::DataType::FLOAT32})},
{"nll_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"not_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& weight,
const DenseTensor& total_weight,
const DenseTensor& d_out,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* d_x) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& label_type = label.dtype();
bool label_type_match =
label_type == phi::DataType::INT32 || label_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(label_type_match,
true,
phi::errors::InvalidArgument(
"Input(Label) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
label_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto d_out_data = d_out.data<XPUType>();
auto d_x_data = dev_ctx.template Alloc<XPUType>(d_x);
auto d_x_dims = d_x->dims();
std::vector<int64_t> d_x_shape = phi::vectorize<int64_t>(d_x_dims);
auto weight_data =
weight.get_ptr() ? weight.get_ptr()->data<float>() : nullptr;
int64_t reduction_id = 0;
if (reduction == "none") {
reduction_id = 0;
} else if (reduction == "mean") {
reduction_id = 1;
} else if (reduction == "sum") {
reduction_id = 2;
}
auto total_weight_data = total_weight.data<XPUType>();
int r;
if (label_type == phi::DataType::INT32) {
const int* label_data = label.data<int>();
r = xpu::nll_loss_grad(dev_ctx.x_context(),
d_out_data,
d_x_data,
d_x_shape,
label_data,
weight_data,
reduction_id,
ignore_index,
total_weight_data);
} else if (label_type == phi::DataType::INT64) {
const int64_t* label_data = label.data<int64_t>();
r = xpu::nll_loss_grad(dev_ctx.x_context(),
d_out_data,
d_x_data,
d_x_shape,
label_data,
weight_data,
reduction_id,
ignore_index,
total_weight_data);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nll_loss_grad");
}
} // namespace phi
// TODO(xiongkun): add the non-raw kernel register here.
PD_REGISTER_KERNEL(
nll_loss_grad, XPU, ALL_LAYOUT, phi::NllLossGradKernel, float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void NllLossRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& weight,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* out,
DenseTensor* total_weight) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& label_type = label.dtype();
bool label_type_match =
label_type == phi::DataType::INT32 || label_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(label_type_match,
true,
phi::errors::InvalidArgument(
"Input(Label) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
label_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto x_data = x.data<XPUType>();
auto out_data = dev_ctx.template Alloc<XPUType>(out);
auto weight_data =
weight.get_ptr() ? weight.get_ptr()->data<XPUType>() : nullptr;
auto total_weight_data = dev_ctx.template Alloc<XPUType>(total_weight);
auto x_dims = x.dims();
std::vector<int64_t> x_shape = phi::vectorize<int64_t>(x_dims);
int64_t reduction_id = 0;
if (reduction == "none") {
reduction_id = 0;
} else if (reduction == "mean") {
reduction_id = 1;
} else if (reduction == "sum") {
reduction_id = 2;
}
int r;
if (label_type == phi::DataType::INT32) {
const int* label_data = label.data<int>();
r = xpu::nll_loss(dev_ctx.x_context(),
x_data,
out_data,
total_weight_data,
x_shape,
label_data,
weight_data,
reduction_id,
ignore_index);
} else if (label_type == phi::DataType::INT64) {
const int64_t* label_data = label.data<int64_t>();
r = xpu::nll_loss(dev_ctx.x_context(),
x_data,
out_data,
total_weight_data,
x_shape,
label_data,
weight_data,
reduction_id,
ignore_index);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nll_loss");
}
} // namespace phi
// TODO(xiongkun): add the non-raw kernel register here.
PD_REGISTER_KERNEL(nll_loss, XPU, ALL_LAYOUT, phi::NllLossRawKernel, float) {}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
def nll_loss_1d(
logs, dtype, targets, weight=None, reduction='mean', ignore_index=-100
):
input_shape = logs.shape
N = input_shape[0]
C = input_shape[1]
out = np.zeros_like(targets).astype(dtype)
total_weight = 0
for i in range(N):
cur_target = targets[i]
if cur_target == ignore_index:
out[i] = 0
continue
cur_weight = weight[cur_target] if weight is not None else 1
total_weight += cur_weight
out[i] = -logs[i][cur_target] * cur_weight
if reduction == 'sum':
out = np.sum(out)
total_weight = np.array([total_weight]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
elif reduction == 'mean':
out = np.sum(out)
if total_weight != 0:
out /= total_weight
total_weight = np.array([total_weight]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
elif reduction == 'none':
total_weight = np.array([0]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
def nll_loss_2d(
logs, dtype, targets, weight=None, reduction='mean', ignore_index=-100
):
input_shape = logs.shape
N = input_shape[0]
H = input_shape[2]
W = input_shape[3]
out = np.zeros_like(targets).astype(dtype)
total_weight = 0
for i in range(N):
for h in range(H):
for w in range(W):
cur_target = targets[i][h][w]
if cur_target == ignore_index:
out[i][h][w] = 0
continue
cur_weight = weight[cur_target] if weight is not None else 1
total_weight += cur_weight
out[i][h][w] = -logs[i][cur_target][h][w] * cur_weight
if reduction == 'sum':
out = np.sum(out)
total_weight = np.array([total_weight]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
elif reduction == 'mean':
out = np.sum(out)
if total_weight != 0:
out /= total_weight
total_weight = np.array([total_weight]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
elif reduction == 'none':
total_weight = np.array([0]).astype(dtype)
return {'Out': out, 'Total_weight': total_weight}
class XPUTestNLLLossOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'nll_loss'
self.use_dynamic_create_class = False
class TestNLLLossOpBase1D(XPUOpTest):
op_type = 'nll_loss'
def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.set_attrs()
self.set_inputs()
self.inputs = {
'X': self.x,
'Label': self.label,
}
if self.weight is not None:
self.inputs['Weight'] = self.weight
self.outputs = nll_loss_1d(
self.x,
self.dtype,
self.label,
self.weight,
self.attrs['reduction'],
)
def set_attrs(self):
self.attrs = {'reduction': 'none'}
def set_inputs(self):
self.class_num = 3
x_shape = [5, self.class_num]
label_shape = [5]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = np.random.random(self.class_num).astype(self.dtype)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestNLLLossOpWithWeightMean1D(TestNLLLossOpBase1D):
def set_attrs(self):
self.attrs = {'reduction': 'mean'}
class TestNLLLossOpWithWeightSum1D(TestNLLLossOpBase1D):
def set_attrs(self):
self.attrs = {'reduction': 'sum'}
class TestNLLLossOpWithoutWeightNone1D(TestNLLLossOpBase1D):
def set_inputs(self):
self.class_num = 3
x_shape = [5, self.class_num]
label_shape = [5]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'none'}
class TestNLLLossOpWithoutWeightMean1D(TestNLLLossOpBase1D):
def set_inputs(self):
self.class_num = 3
x_shape = [5, self.class_num]
label_shape = [5]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'mean'}
class TestNLLLossOpWithoutWeightSum1D(TestNLLLossOpBase1D):
def set_inputs(self):
self.class_num = 3
x_shape = [5, self.class_num]
label_shape = [5]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'sum'}
class TestNLLLossOpBase2D(XPUOpTest):
op_type = 'nll_loss'
def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.set_attrs()
self.set_inputs()
self.inputs = {'X': self.x, 'Label': self.label}
if self.weight is not None:
self.inputs['Weight'] = self.weight
self.outputs = nll_loss_2d(
self.x,
self.dtype,
self.label,
self.weight,
self.attrs['reduction'],
)
def set_attrs(self):
self.attrs = {'reduction': 'none'}
def set_inputs(self):
self.class_num = 3
x_shape = [5, self.class_num, 7, 11]
label_shape = [5, 7, 11]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = np.random.random(self.class_num).astype(self.dtype)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestNLLLossOpWithWeightMean2D(TestNLLLossOpBase2D):
def set_attrs(self):
self.attrs = {'reduction': 'mean'}
class TestNLLLossOpWithWeightSum2D(TestNLLLossOpBase2D):
def set_attrs(self):
self.attrs = {'reduction': 'sum'}
class TestNLLLossOpWithoutWeightNone2D(TestNLLLossOpBase2D):
def set_inputs(self):
self.dtype = self.in_type
self.class_num = 3
x_shape = [5, self.class_num, 7, 11]
label_shape = [5, 7, 11]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'none'}
class TestNLLLossOpWithoutWeightMean2D(TestNLLLossOpBase2D):
def set_inputs(self):
self.dtype = self.in_type
self.class_num = 3
x_shape = [5, self.class_num, 7, 11]
label_shape = [5, 7, 11]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'mean'}
class TestNLLLossOpWithoutWeightSum2D(TestNLLLossOpBase2D):
def set_inputs(self):
self.dtype = self.in_type
self.class_num = 3
x_shape = [5, self.class_num, 7, 11]
label_shape = [5, 7, 11]
self.x = np.random.random(x_shape).astype(self.dtype)
self.label = np.random.randint(
low=0, high=self.class_num, size=label_shape
).astype(np.int64)
self.weight = None
def set_attrs(self):
self.attrs = {'reduction': 'sum'}
support_types = get_xpu_op_support_types('nll_loss')
for stype in support_types:
create_test_class(globals(), XPUTestNLLLossOP, stype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册