未验证 提交 55bfc6cb 编写于 作者: X xiongkun 提交者: GitHub

[phi] transfer the nll_loss kernel to phi and pass the test (#39936)

* transfer the nll_loss_op and pass the CI

* push

* fix by self-review

* fix by cr

* add nll_loss

* fix code
上级 7f43055d
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nll_loss_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -264,10 +264,3 @@ REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker,
ops::NLLLossGradMaker<paddle::framework::OpDesc>,
ops::NLLLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp);
REGISTER_OP_CPU_KERNEL(
nll_loss, ops::NLLLossOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::NLLLossOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
nll_loss_grad,
ops::NLLLossGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::NLLLossGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data,
const int64_t batch_size, const int64_t n_classes,
const std::string reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int64_t i = 0; i < batch_size; ++i) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"Label value is out of range. "
"Expected label value in range of [0, %d), but "
"received value is %d.",
n_classes, cur_label));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight;
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int64_t i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -= x_data[i * n_classes + cur_label] * cur_weight;
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename T>
static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data,
const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3,
const std::string reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[index] = -x_data[i * sample_size + cur_label * map_size +
h * in_dim3 + w] *
cur_weight;
}
}
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -=
x_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] *
cur_weight;
}
}
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename DeviceContext, typename T>
class NLLLossOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* out = ctx.Output<Tensor>("Out");
auto* total_weight = ctx.Output<Tensor>("Total_weight");
auto reduction = ctx.Attr<std::string>("reduction");
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto x_data = x->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto out_data = out->mutable_data<T>(ctx.GetPlace());
auto total_weight_data = total_weight->mutable_data<T>(ctx.GetPlace());
*total_weight_data = 0;
auto x_dims = x->dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_1D<T>(out_data, total_weight_data, x_data, label_data,
weight_data, batch_size, n_classes, reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_2D<T>(out_data, total_weight_data, x_data, label_data,
weight_data, batch_size, n_classes, in_dim2, in_dim3,
reduction, ignore_index);
}
}
};
template <typename T>
static void nll_loss_grad_1D(T* dx_data, const T* dout_data,
const int64_t* label_data, const T* weight_data,
const T* total_weight_data,
const int64_t batch_size, const int64_t n_classes,
const std::string reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight;
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[i * n_classes + cur_label] /= total_weight_val;
}
}
}
template <typename T>
static void nll_loss_grad_2D(T* dx_data, const T* dout_data,
const int64_t* label_data, const T* weight_data,
const T* total_weight_data,
const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3,
const std::string reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] =
-cur_weight * dout_data[index];
}
}
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
const auto dx_index =
i * sample_size + cur_label * map_size + h * in_dim3 + w;
dx_data[dx_index] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[dx_index] /= total_weight_val;
}
}
}
}
}
template <typename DeviceContext, typename T>
class NLLLossGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* total_weight = ctx.Input<Tensor>("Total_weight");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto total_weight_data = total_weight->data<T>();
memset(dx_data, 0, dx->numel() * sizeof(T));
const auto x_dims = x->dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_grad_1D(dx_data, dout_data, label_data, weight_data,
total_weight_data, batch_size, n_classes, reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_grad_2D(dx_data, dout_data, label_data, weight_data,
total_weight_data, batch_size, n_classes, in_dim2,
in_dim3, reduction, ignore_index);
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_grad_kernel.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/math.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
static void nll_loss_grad_1D(T* dx_data,
const T* dout_data,
const int64_t* label_data,
const T* weight_data,
const T* total_weight_data,
const int64_t batch_size,
const int64_t n_classes,
const std::string reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight;
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[i * n_classes + cur_label] /= total_weight_val;
}
}
}
template <typename T>
static void nll_loss_grad_2D(T* dx_data,
const T* dout_data,
const int64_t* label_data,
const T* weight_data,
const T* total_weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t in_dim2,
const int64_t in_dim3,
const std::string& reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] =
-cur_weight * dout_data[index];
}
}
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
const auto dx_index =
i * sample_size + cur_label * map_size + h * in_dim3 + w;
dx_data[dx_index] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[dx_index] /= total_weight_val;
}
}
}
}
}
template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& labels,
const DenseTensor& total_weight,
paddle::optional<const DenseTensor&> weight,
const DenseTensor& d_out,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* dx) {
auto dx_data = dev_ctx.template Alloc<T>(dx);
auto dout_data = d_out.data<T>();
auto label_data = labels.data<int64_t>();
auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;
auto total_weight_data = total_weight.data<T>();
memset(dx_data, 0, dx->numel() * sizeof(T));
const auto x_dims = x.dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_grad_1D(dx_data,
dout_data,
label_data,
weight_data,
total_weight_data,
batch_size,
n_classes,
reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_grad_2D(dx_data,
dout_data,
label_data,
weight_data,
total_weight_data,
batch_size,
n_classes,
in_dim2,
in_dim3,
reduction,
ignore_index);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
nll_loss_grad, CPU, ALL_LAYOUT, phi::NllLossGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
static void nll_loss_1D(T* out_data,
T* total_weight_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const std::string& reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int64_t i = 0; i < batch_size; ++i) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes,
true,
phi::errors::InvalidArgument(
"Label value is out of range. "
"Expected label value in range of [0, %d), but "
"received value is %d.",
n_classes,
cur_label));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight;
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int64_t i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(
cur_label >= 0 && cur_label < n_classes,
true,
phi::errors::InvalidArgument("label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -= x_data[i * n_classes + cur_label] * cur_weight;
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename T>
static void nll_loss_2D(T* out_data,
T* total_weight_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t in_dim2,
const int64_t in_dim3,
const std::string& reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes,
true,
phi::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[index] = -x_data[i * sample_size + cur_label * map_size +
h * in_dim3 + w] *
cur_weight;
}
}
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(
cur_label >= 0 && cur_label < n_classes,
true,
phi::errors::InvalidArgument("label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -=
x_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] *
cur_weight;
}
}
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename T, typename Context>
void NllLossRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& labels,
paddle::optional<const DenseTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* out,
DenseTensor* total_weight) {
auto x_data = x.data<T>();
auto label_data = labels.data<int64_t>();
auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;
auto out_data = dev_ctx.template Alloc<T>(out);
auto total_weight_data = dev_ctx.template Alloc<T>(total_weight);
*total_weight_data = 0;
auto x_dims = x.dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_1D<T>(out_data,
total_weight_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_2D<T>(out_data,
total_weight_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
in_dim2,
in_dim3,
reduction,
ignore_index);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
nll_loss, CPU, ALL_LAYOUT, phi::NllLossRawKernel, float, double) {}
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <thrust/functional.h>
#include <algorithm>
#include <functional>
#include <string>
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/nll_loss_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static const int NTHREADS = 32;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data,
__global__ void GPUNLLLossForward1D_no_reduce(T* out_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
......@@ -51,11 +53,15 @@ __global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data,
}
template <typename T>
__global__ void GPUNLLLossForward1D_with_reduce(
T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data, const int64_t batch_size,
const int64_t n_classes, const int64_t size_average,
const int64_t ignore_index) {
__global__ void GPUNLLLossForward1D_with_reduce(T* out_data,
T* total_weight_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t size_average,
const int64_t ignore_index) {
__shared__ T sharedInputs[NTHREADS], sharedWeights[NTHREADS];
sharedInputs[threadIdx.x] = 0;
sharedWeights[threadIdx.x] = 0;
......@@ -99,9 +105,11 @@ __global__ void GPUNLLLossForward1D_with_reduce(
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template <typename T, typename ReduceOp, int N>
__device__ void reduceNValuesInBlock(T* smem, T threadVals[N],
__device__ void reduceNValuesInBlock(T* smem,
T threadVals[N],
const unsigned int numVals,
ReduceOp reduceOp, T init) {
ReduceOp reduceOp,
T init) {
if (numVals == 0) {
#pragma unroll
for (int i = 0; i < N; ++i) {
......@@ -175,18 +183,26 @@ __device__ void reduceNValuesInBlock(T* smem, T threadVals[N],
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template <typename T, typename ReduceOp>
__device__ T reduceBlock(T* smem, const unsigned int numVals, T threadVal,
ReduceOp reduceOp, T init) {
reduceNValuesInBlock<T, ReduceOp, 1>(smem, &threadVal, numVals, reduceOp,
init);
__device__ T reduceBlock(T* smem,
const unsigned int numVals,
T threadVal,
ReduceOp reduceOp,
T init) {
reduceNValuesInBlock<T, ReduceOp, 1>(
smem, &threadVal, numVals, reduceOp, init);
return threadVal;
}
template <typename T>
__global__ void GPUNLLLossForward2D_no_reduce(
T* out_data, const T* x_data, const int64_t* label_data,
const T* weight_data, const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) {
__global__ void GPUNLLLossForward2D_no_reduce(T* out_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t in_dim2,
const int64_t in_dim3,
const int64_t ignore_index) {
const int64_t map_size = in_dim2 * in_dim3;
const int64_t sample_size = n_classes * map_size;
const int64_t out_numel = batch_size * map_size;
......@@ -211,11 +227,16 @@ __global__ void GPUNLLLossForward2D_no_reduce(
}
template <typename T>
__global__ void GPUNLLLossForward2D_with_reduce(
T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data, const int64_t batch_size,
const int64_t n_classes, const int64_t map_nelem,
const int64_t blocks_per_sample, const int64_t ignore_index) {
__global__ void GPUNLLLossForward2D_with_reduce(T* out_data,
T* total_weight_data,
const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t map_nelem,
const int64_t blocks_per_sample,
const int64_t ignore_index) {
__shared__ T partial_sums[kNumCUDAThreads];
int64_t i;
T input_sum = 0;
......@@ -228,7 +249,8 @@ __global__ void GPUNLLLossForward2D_with_reduce(
int64_t ioffset = sample * map_nelem * n_classes;
int64_t step = blockDim.x * blocks_per_sample;
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem; i += step) {
i < map_nelem;
i += step) {
const int64_t cur_label = label_data[toffset + i];
if (cur_label != ignore_index) {
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
......@@ -242,8 +264,8 @@ __global__ void GPUNLLLossForward2D_with_reduce(
input_sum =
reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus<T>(), (T)0);
__syncthreads();
acc_weight = reduceBlock(partial_sums, blockDim.x, acc_weight,
thrust::plus<T>(), (T)0);
acc_weight = reduceBlock(
partial_sums, blockDim.x, acc_weight, thrust::plus<T>(), (T)0);
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(total_weight_data, acc_weight);
......@@ -258,12 +280,14 @@ __global__ void GPUNLLLossForward2D_size_average(T* out_data,
*out_data /= *total_weight_data;
}
}
template <typename T>
__global__ void GPUNLLLossBackward1D_no_reduce(
T* dx_data, const int64_t* label_data, const T* weight_data,
const T* dout_data, const int64_t batch_size, const int64_t n_classes,
const int64_t ignore_index) {
__global__ void GPUNLLLossBackward1D_no_reduce(T* dx_data,
const int64_t* label_data,
const T* weight_data,
const T* dout_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t ignore_index) {
CUDA_KERNEL_LOOP(i, batch_size) {
const int64_t cur_label = label_data[i];
if (cur_label == ignore_index) {
......@@ -275,11 +299,15 @@ __global__ void GPUNLLLossBackward1D_no_reduce(
}
template <typename T>
__global__ void GPUNLLLossBackward1D_with_reduce(
T* dx_data, const T* total_weight_data, const int64_t* label_data,
const T* weight_data, const T* dout_data, const int64_t batch_size,
const int64_t n_classes, const int64_t size_average,
const int64_t ignore_index) {
__global__ void GPUNLLLossBackward1D_with_reduce(T* dx_data,
const T* total_weight_data,
const int64_t* label_data,
const T* weight_data,
const T* dout_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t size_average,
const int64_t ignore_index) {
if (*total_weight_data <= 0) {
return;
}
......@@ -295,10 +323,15 @@ __global__ void GPUNLLLossBackward1D_with_reduce(
}
template <typename T>
__global__ void GPUNLLLossBackward2D_no_reduce(
T* dx_data, const int64_t* label_data, const T* weight_data,
const T* dout_data, const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) {
__global__ void GPUNLLLossBackward2D_no_reduce(T* dx_data,
const int64_t* label_data,
const T* weight_data,
const T* dout_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t in_dim2,
const int64_t in_dim3,
const int64_t ignore_index) {
const int64_t map_size = in_dim2 * in_dim3;
const int64_t sample_size = n_classes * map_size;
const int64_t out_numel = batch_size * map_size;
......@@ -319,10 +352,16 @@ __global__ void GPUNLLLossBackward2D_no_reduce(
template <typename T>
__global__ void GPUNLLLossBackward2D_with_reduce(
T* dx_data, const T* total_weight_data, const int64_t* label_data,
const T* weight_data, const T* dout_data, const int64_t batch_size,
const int64_t n_classes, const int64_t map_nelem,
const int64_t blocks_per_sample, const int64_t size_average,
T* dx_data,
const T* total_weight_data,
const int64_t* label_data,
const T* weight_data,
const T* dout_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t map_nelem,
const int64_t blocks_per_sample,
const int64_t size_average,
const int64_t ignore_index) {
if (*total_weight_data <= 0) {
return;
......@@ -334,7 +373,8 @@ __global__ void GPUNLLLossBackward2D_with_reduce(
int toffset = sample * map_nelem;
int ioffset = sample * map_nelem * n_classes;
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem; i += step) {
i < map_nelem;
i += step) {
const int64_t cur_label = label_data[toffset + i];
if (cur_label != ignore_index) {
dx_data[ioffset + i + map_nelem * cur_label] =
......@@ -343,158 +383,4 @@ __global__ void GPUNLLLossBackward2D_with_reduce(
}
}
template <typename DeviceContext, typename T>
class NLLLossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* out = ctx.Output<Tensor>("Out");
auto* total_weight = ctx.Output<Tensor>("Total_weight");
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
auto x_data = x->data<T>();
auto out_data = out->mutable_data<T>(ctx.GetPlace());
auto total_weight_data = total_weight->mutable_data<T>(ctx.GetPlace());
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
#ifdef PADDLE_WITH_HIP
hipMemset(total_weight_data, 0, sizeof(T));
#else
cudaMemset(total_weight_data, 0, sizeof(T));
#endif
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
int64_t size_average = (int64_t)(reduction == "mean");
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossForward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_data, x_data, label_data, weight_data, batch_size, n_classes,
ignore_index);
} else {
GPUNLLLossForward1D_with_reduce<
T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
out_data, total_weight_data, x_data, label_data, weight_data,
batch_size, n_classes, size_average, ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossForward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_data, x_data, label_data, weight_data, batch_size, n_classes,
in_dim2, in_dim3, ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossForward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(
out_data, total_weight_data, x_data, label_data, weight_data,
batch_size, n_classes, map_size, blocks_per_sample, ignore_index);
if (size_average) {
GPUNLLLossForward2D_size_average<T><<<1, 1, 0, dev_ctx.stream()>>>(
out_data, total_weight_data);
}
}
}
}
};
template <typename DeviceContext, typename T>
class NLLLossGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* total_weight = ctx.Input<Tensor>("Total_weight");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto total_weight_data = total_weight->data<T>();
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
#ifdef PADDLE_WITH_HIP
hipMemset(dx_data, 0, dx->numel() * sizeof(T));
#else
cudaMemset(dx_data, 0, dx->numel() * sizeof(T));
#endif
int64_t size_average = (int64_t)(reduction == "mean");
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossBackward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, label_data, weight_data, dout_data, batch_size, n_classes,
ignore_index);
} else {
GPUNLLLossBackward1D_with_reduce<
T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
dx_data, total_weight_data, label_data, weight_data, dout_data,
batch_size, n_classes, size_average, ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossBackward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, label_data, weight_data, dout_data, batch_size, n_classes,
in_dim2, in_dim3, ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossBackward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, total_weight_data, label_data, weight_data, dout_data,
batch_size, n_classes, map_size, blocks_per_sample, size_average,
ignore_index);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
nll_loss,
ops::NLLLossCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::NLLLossCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
nll_loss_grad,
ops::NLLLossGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::NLLLossGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/nll_loss.h"
namespace phi {
template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& labels,
const DenseTensor& total_weight,
paddle::optional<const DenseTensor&> weight,
const DenseTensor& dout,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* dx) {
auto dx_data = dev_ctx.template Alloc<T>(dx);
auto dout_data = dout.data<T>();
auto label_data = labels.data<int64_t>();
auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;
auto total_weight_data = total_weight.data<T>();
#ifdef PADDLE_WITH_HIP
hipMemset(dx_data, 0, dx->numel() * sizeof(T));
#else
cudaMemset(dx_data, 0, dx->numel() * sizeof(T));
#endif
int64_t size_average = (int64_t)(reduction == "mean");
auto x_dims = x.dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossBackward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data,
label_data,
weight_data,
dout_data,
batch_size,
n_classes,
ignore_index);
} else {
GPUNLLLossBackward1D_with_reduce<T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
dx_data,
total_weight_data,
label_data,
weight_data,
dout_data,
batch_size,
n_classes,
size_average,
ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossBackward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data,
label_data,
weight_data,
dout_data,
batch_size,
n_classes,
in_dim2,
in_dim3,
ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossBackward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(dx_data,
total_weight_data,
label_data,
weight_data,
dout_data,
batch_size,
n_classes,
map_size,
blocks_per_sample,
size_average,
ignore_index);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
nll_loss_grad, GPU, ALL_LAYOUT, phi::NllLossGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/nll_loss.h"
namespace phi {
template <typename T, typename Context>
void NllLossRawKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
paddle::optional<const DenseTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* out,
DenseTensor* total_weight) {
auto* x = &input;
auto x_data = x->data<T>();
auto out_data = dev_ctx.template Alloc<T>(out);
auto total_weight_data = dev_ctx.template Alloc<T>(total_weight);
auto label_data = label.data<int64_t>();
auto weight_data = weight.get_ptr() ? weight.get_ptr()->data<T>() : nullptr;
#ifdef PADDLE_WITH_HIP
hipMemset(total_weight_data, 0, sizeof(T));
#else
cudaMemset(total_weight_data, 0, sizeof(T));
#endif
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
int64_t size_average = (int64_t)(reduction == "mean");
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossForward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
ignore_index);
} else {
GPUNLLLossForward1D_with_reduce<T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
out_data,
total_weight_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
size_average,
ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossForward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
in_dim2,
in_dim3,
ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossForward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(out_data,
total_weight_data,
x_data,
label_data,
weight_data,
batch_size,
n_classes,
map_size,
blocks_per_sample,
ignore_index);
if (size_average) {
GPUNLLLossForward2D_size_average<T><<<1, 1, 0, dev_ctx.stream()>>>(
out_data, total_weight_data);
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
nll_loss, GPU, ALL_LAYOUT, phi::NllLossRawKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const DenseTensor& total_weight,
paddle::optional<const DenseTensor&> weight,
const DenseTensor& d_out,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* d_x);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nll_loss_kernel.h"
namespace phi {
template <typename T, typename Context>
void NllLossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
paddle::optional<const DenseTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* out) {
DenseTensor total_weight;
total_weight.set_meta(
DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(), {1}));
dev_ctx.template Alloc<T>(total_weight);
NllLossRawKernel(dev_ctx,
input,
label,
weight,
ignore_index,
reduction,
out,
&total_weight);
}
} // namespace phi
// TODO(xiongkun): add the non-raw kernel register here.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void NllLossRawKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
paddle::optional<const DenseTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
DenseTensor* out,
DenseTensor* total_weight);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature NllLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
// TODO(xiongkun): can't remove the forward mapping, because the Weight is
// optional
return KernelSignature("nll_loss",
{"X", "Label", "Weight"},
{"ignore_index", "reduction"},
{"Out", "Total_weight"});
}
KernelSignature NllLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"nll_loss_grad",
{"X", "Label", "Total_weight", "Weight", GradVarName("Out")},
{"ignore_index", "reduction"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(nll_loss_grad, phi::NllLossGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nll_loss, phi::NllLossOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册