未验证 提交 e21b3c27 编写于 作者: L lijianshe02 提交者: GitHub

add nll_loss op test=develop (#23758)

* add nll_loss op test=develop
上级 40aa14ec
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nll_loss_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
class NLLLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput("Total_weight"), "Output", "Total_weight",
"NLLLoss");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"The tensor rank of Input(X) must be 2 or 4."));
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(label_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(
x_dims[0], label_dims[0],
platform::errors::InvalidArgument(
"ShapeError: Expected input batch_size to match label batch_size,"
"But received: the Input(x) batch_size is [%s], the Input(label) "
" batch_size is [%s].",
x_dims[0], label_dims[0]));
if (ctx->HasInput("Weight")) {
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 1,
platform::errors::InvalidArgument(
"Input(Weight) should be a 1D tensor."));
PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0],
platform::errors::InvalidArgument(
"Input(Weight) Tensor's size should match"
"to the class numer."));
}
}
if (x_dims.size() == 2) {
if (reduction == "none") {
ctx->SetOutputDim("Out", {x_dims[0]});
} else {
ctx->SetOutputDim("Out", {1});
}
} else if (x_dims.size() == 4) {
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
platform::errors::InvalidArgument(
"The tensor rank of Input(Label) must be 3."));
auto input0 = x_dims[0];
auto input2 = x_dims[2];
auto input3 = x_dims[3];
auto label0 = label_dims[0];
auto label1 = label_dims[1];
auto label2 = label_dims[2];
PADDLE_ENFORCE_EQ(
input0 == label0 && input2 == label1 && input3 == label2, true,
platform::errors::InvalidArgument("Input(X) tensor shape should "
"match to Input(Label) tensor "
"shape."));
if (reduction == "none") {
ctx->SetOutputDim("Out", {x_dims[0], x_dims[2], x_dims[3]});
} else {
ctx->SetOutputDim("Out", {1});
}
}
ctx->SetOutputDim("Total_weight", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class NLLLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor whose last dimension "
"size is equal to the number of classes. It is expected to "
"contain log-probabilities of each class. "
"The X tensor's shape has to be either [batch_size, C] or"
"[batch_size, C, dim1, ..., dimK] in with K >= 1 in the case "
" K-dimensional loss.");
AddInput("Label",
"(Tensor, default Tensor<int64_t>) A tensor which represents the "
"the ground truth. It contains the class index in the range "
"[0, C-1] where C = number of classes. The Lable tensor's "
"shape has to be (batch_size), or "
"(batch_size, dim1, ..., dimK) "
"with K >= 1 in the case K-dimensional loss.");
AddInput("Weight",
"(Tensor, optional) A tensor should be a 1D tensor assigning "
"weight to each of the classes. It's shape must be [C], where "
"C is the class number.")
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>) A tensor that represents the "
"NLL loss.");
AddOutput("Total_weight",
"(Tensor, default Tensor<float>) A tensor saves the total"
"weight value in the forward process.");
AddAttr<int64_t>("ignore_index",
"(int64_t, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient.")
.SetDefault(-100);
AddAttr<std::string>(
"reduction",
"(string, default mean), Specifies the reduction to apply"
"to the output. The options include \"none\", \"mean\","
"\"sum\".")
.SetDefault("mean");
AddComment(R"DOC(
NLL(Negative Log Likelihood) Loss Operator.
This operator computes the NLL loss according to the inputs.
The loss can be described as:
$Out[i] = -X[Label[i]]*Weight[Label[i]]$
It can also be used for higher dimension inputs, such as 2D images, by
providing an input of shape (batch_size, C, d1, d2, ..., dK), with
K >= 1, where K is the number of dimensions, and a Label of
appropriate shape. In the case of images, it computes NLL loss
per-pixel.
)DOC");
}
};
class NLLLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "NLLLoss");
auto reduction = ctx->Attrs().Get<std::string>("reduction");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(dout_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
auto batch_size = x_dims[0];
if (x_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dout_dims.size(), 1,
platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 1"));
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims[0], batch_size,
platform::errors::InvalidArgument(
"The unreduced size ofInput(Out@Grad) must be the "
"same as batch_size."));
} else {
PADDLE_ENFORCE_EQ(
dout_dims[0], 1,
platform::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
} else if (x_dims.size() == 4) {
if (reduction == "none") {
PADDLE_ENFORCE_EQ(
dout_dims.size(), 3,
platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 3,But got [%s].",
dout_dims.size()));
PADDLE_ENFORCE_EQ(
dout_dims[0] == label_dims[0] && dout_dims[1] == label_dims[1] &&
dout_dims[2] == label_dims[2],
true, platform::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be match "
"to Input(Label) dimensions."));
} else {
PADDLE_ENFORCE_EQ(
dout_dims[0], 1,
platform::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
}
}
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
template <typename T>
class NLLLossGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("nll_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
op->SetInput("Total_weight", this->Output("Total_weight"));
if (this->HasInput("Weight")) {
op->SetInput("Weight", this->Input("Weight"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker,
ops::NLLLossGradMaker<paddle::framework::OpDesc>,
ops::NLLLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp);
REGISTER_OP_CPU_KERNEL(
nll_loss, ops::NLLLossOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::NLLLossOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
nll_loss_grad,
ops::NLLLossGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::NLLLossGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/nll_loss_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static const int NTHREADS = 32;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data,
const int64_t* label_data,
const T* weight_data,
const int64_t batch_size,
const int64_t n_classes,
const int64_t ignore_index) {
CUDA_1D_KERNEL_LOOP(i, batch_size) {
const int64_t cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight;
}
}
template <typename T>
__global__ void GPUNLLLossForward1D_with_reduce(
T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data, const int64_t batch_size,
const int64_t n_classes, const int64_t size_average,
const int64_t ignore_index) {
__shared__ T sharedInputs[NTHREADS], sharedWeights[NTHREADS];
sharedInputs[threadIdx.x] = 0;
sharedWeights[threadIdx.x] = 0;
int i;
for (i = threadIdx.x; i < batch_size; i += NTHREADS) {
const auto cur_label = label_data[i];
if (cur_label != ignore_index) {
const auto cur_weight = weight_data ? weight_data[cur_label] : (T)1;
sharedInputs[threadIdx.x] -=
x_data[i * n_classes + cur_label] * cur_weight;
sharedWeights[threadIdx.x] += cur_weight;
}
}
__syncthreads();
if (threadIdx.x == 0) {
*out_data = *total_weight_data = 0;
T output_val = 0;
T total_weight_val = 0;
for (i = 0; i < NTHREADS; ++i) {
output_val += sharedInputs[i];
total_weight_val += sharedWeights[i];
}
*total_weight_data = total_weight_val;
*out_data = output_val;
if (size_average && *total_weight_data != 0) {
*out_data = output_val / total_weight_val;
}
}
}
// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0
// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20)
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template <typename T, typename ReduceOp, int N>
__device__ void reduceNValuesInBlock(T* smem, T threadVals[N],
const unsigned int numVals,
ReduceOp reduceOp, T init) {
if (numVals == 0) {
#pragma unroll
for (int i = 0; i < N; ++i) {
threadVals[i] = init;
}
return;
}
// We store each of the N values contiguously, so if N = 2, all values for
// the first threadVal for each thread in the block are stored followed by
// all of the values for the second threadVal for each thread in the block
if (threadIdx.x < numVals) {
#pragma unroll
for (int i = 0; i < N; ++i) {
smem[i * numVals + threadIdx.x] = threadVals[i];
}
}
__syncthreads();
// Number of lanes in the final reduction --> this is used to determine
// where to put the outputs of each of the n things we are reducing. If
// nLP = 32, then we have the 32 outputs for the first threadVal,
// followed by the 32 outputs for the second threadVal, etc.
const unsigned int numLanesParticipating = min(numVals, warpSize);
if (numVals > warpSize && ((threadIdx.x / warpSize) == 0)) {
#pragma unroll
for (int i = 0; i < N; ++i) {
threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init;
}
for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
#pragma unroll
for (int j = 0; j < N; ++j) {
threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
smem[i * numLanesParticipating + threadIdx.x] = threadVals[i];
}
}
__syncthreads();
if (threadIdx.x == 0) {
if (numLanesParticipating == 32) {
#pragma unroll
for (int i = 0; i < N; ++i) {
#pragma unroll
for (int j = 1; j < 32; ++j) {
threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]);
}
}
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
for (int j = 1; j < numLanesParticipating; ++j) {
threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]);
}
}
}
}
}
// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
// return the reduced value
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template <typename T, typename ReduceOp>
__device__ T reduceBlock(T* smem, const unsigned int numVals, T threadVal,
ReduceOp reduceOp, T init) {
reduceNValuesInBlock<T, ReduceOp, 1>(smem, &threadVal, numVals, reduceOp,
init);
return threadVal;
}
template <typename T>
__global__ void GPUNLLLossForward2D_no_reduce(
T* out_data, const T* x_data, const int64_t* label_data,
const T* weight_data, const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) {
const int64_t map_size = in_dim2 * in_dim3;
const int64_t sample_size = n_classes * map_size;
const int64_t out_numel = batch_size * map_size;
CUDA_1D_KERNEL_LOOP(i, out_numel) {
const int64_t b = i % batch_size;
const int64_t h = (i / batch_size) % in_dim2;
const int64_t w = (i / (batch_size * in_dim2)) % in_dim3;
const int64_t index = b * map_size + h * in_dim3 + w;
const int64_t cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
out_data[index] =
-x_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] *
cur_weight;
}
}
template <typename T>
__global__ void GPUNLLLossForward2D_with_reduce(
T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data, const int64_t batch_size,
const int64_t n_classes, const int64_t map_nelem,
const int64_t blocks_per_sample, const int64_t ignore_index) {
__shared__ T partial_sums[kNumCUDAThreads];
int64_t i;
T input_sum = 0;
T acc_weight = 0;
*out_data = 0;
*total_weight_data = 0;
int64_t sample = blockIdx.x / blocks_per_sample;
int64_t toffset = sample * map_nelem;
int64_t ioffset = sample * map_nelem * n_classes;
int64_t step = blockDim.x * blocks_per_sample;
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem; i += step) {
const int64_t cur_label = label_data[toffset + i];
if (cur_label != ignore_index) {
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
input_sum -= x_data[ioffset + i + map_nelem * cur_label] * cur_weight;
acc_weight += cur_weight;
}
}
input_sum =
reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus<T>(), (T)0);
__syncthreads();
acc_weight = reduceBlock(partial_sums, blockDim.x, acc_weight,
thrust::plus<T>(), (T)0);
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(total_weight_data, acc_weight);
paddle::platform::CudaAtomicAdd(out_data, input_sum);
}
}
template <typename T>
__global__ void GPUNLLLossForward2D_size_average(T* out_data,
T* total_weight_data) {
if (*total_weight_data != 0) {
*out_data /= *total_weight_data;
}
}
template <typename T>
__global__ void GPUNLLLossBackward1D_no_reduce(
T* dx_data, const int64_t* label_data, const T* weight_data,
const T* dout_data, const int64_t batch_size, const int64_t n_classes,
const int64_t ignore_index) {
CUDA_1D_KERNEL_LOOP(i, batch_size) {
const int64_t cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight;
}
}
template <typename T>
__global__ void GPUNLLLossBackward1D_with_reduce(
T* dx_data, const T* total_weight_data, const int64_t* label_data,
const T* weight_data, const T* dout_data, const int64_t batch_size,
const int64_t n_classes, const int64_t size_average,
const int64_t ignore_index) {
if (*total_weight_data <= 0) {
return;
}
int i;
const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1;
for (i = threadIdx.x; i < batch_size; i += NTHREADS) {
const int64_t cur_label = label_data[i];
if (cur_label != ignore_index) {
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
dx_data[i * n_classes + cur_label] = -cur_weight * dout_data[0] * norm;
}
}
}
template <typename T>
__global__ void GPUNLLLossBackward2D_no_reduce(
T* dx_data, const int64_t* label_data, const T* weight_data,
const T* dout_data, const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) {
const int64_t map_size = in_dim2 * in_dim3;
const int64_t sample_size = n_classes * map_size;
const int64_t out_numel = batch_size * map_size;
CUDA_1D_KERNEL_LOOP(i, out_numel) {
const int64_t b = i % batch_size;
const int64_t h = (i / batch_size) % in_dim2;
const int64_t w = (i / (batch_size * in_dim2)) % in_dim3;
const int64_t index = b * map_size + h * in_dim3 + w;
const int64_t cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
dx_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] =
-dout_data[index] * cur_weight;
}
}
template <typename T>
__global__ void GPUNLLLossBackward2D_with_reduce(
T* dx_data, const T* total_weight_data, const int64_t* label_data,
const T* weight_data, const T* dout_data, const int64_t batch_size,
const int64_t n_classes, const int64_t map_nelem,
const int64_t blocks_per_sample, const int64_t size_average,
const int64_t ignore_index) {
if (*total_weight_data <= 0) {
return;
}
int64_t i;
const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1;
int sample = blockIdx.x / blocks_per_sample;
int step = blockDim.x * blocks_per_sample;
int toffset = sample * map_nelem;
int ioffset = sample * map_nelem * n_classes;
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem; i += step) {
const int64_t cur_label = label_data[toffset + i];
if (cur_label != ignore_index) {
dx_data[ioffset + i + map_nelem * cur_label] =
-(weight_data ? weight_data[cur_label] : (T)1) * norm * dout_data[0];
}
}
}
template <typename DeviceContext, typename T>
class NLLLossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* out = ctx.Output<Tensor>("Out");
auto* total_weight = ctx.Output<Tensor>("Total_weight");
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
auto x_data = x->data<T>();
auto out_data = out->mutable_data<T>(ctx.GetPlace());
auto total_weight_data = total_weight->mutable_data<T>(ctx.GetPlace());
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
cudaMemset(total_weight_data, 0, sizeof(T));
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
int64_t size_average = (int64_t)(reduction == "mean");
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossForward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_data, x_data, label_data, weight_data, batch_size, n_classes,
ignore_index);
} else {
GPUNLLLossForward1D_with_reduce<
T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
out_data, total_weight_data, x_data, label_data, weight_data,
batch_size, n_classes, size_average, ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossForward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_data, x_data, label_data, weight_data, batch_size, n_classes,
in_dim2, in_dim3, ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossForward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(
out_data, total_weight_data, x_data, label_data, weight_data,
batch_size, n_classes, map_size, blocks_per_sample, ignore_index);
if (size_average) {
GPUNLLLossForward2D_size_average<T><<<1, 1, 0, dev_ctx.stream()>>>(
out_data, total_weight_data);
}
}
}
}
};
template <typename DeviceContext, typename T>
class NLLLossGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* total_weight = ctx.Input<Tensor>("Total_weight");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto total_weight_data = total_weight->data<T>();
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
cudaMemset(dx_data, 0, dx->numel() * sizeof(T));
int64_t size_average = (int64_t)(reduction == "mean");
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
int blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossBackward1D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, label_data, weight_data, dout_data, batch_size, n_classes,
ignore_index);
} else {
GPUNLLLossBackward1D_with_reduce<
T><<<1, NTHREADS, 0, dev_ctx.stream()>>>(
dx_data, total_weight_data, label_data, weight_data, dout_data,
batch_size, n_classes, size_average, ignore_index);
}
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
int blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
auto& dev_ctx = ctx.cuda_device_context();
if (reduction == "none") {
GPUNLLLossBackward2D_no_reduce<
T><<<blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, label_data, weight_data, dout_data, batch_size, n_classes,
in_dim2, in_dim3, ignore_index);
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
int total_blocks = blocks_per_sample * batch_size;
GPUNLLLossBackward2D_with_reduce<
T><<<total_blocks, threads, 0, dev_ctx.stream()>>>(
dx_data, total_weight_data, label_data, weight_data, dout_data,
batch_size, n_classes, map_size, blocks_per_sample, size_average,
ignore_index);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
nll_loss,
ops::NLLLossCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::NLLLossCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
nll_loss_grad,
ops::NLLLossGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::NLLLossGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data,
const int64_t batch_size, const int64_t n_classes,
const std::string reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int64_t i = 0; i < batch_size; ++i) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight;
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int64_t i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -= x_data[i * n_classes + cur_label] * cur_weight;
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename T>
static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data,
const int64_t* label_data, const T* weight_data,
const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3,
const std::string reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should nor be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
out_data[index] = -x_data[i * sample_size + cur_label * map_size +
h * in_dim3 + w] *
cur_weight;
}
}
}
return;
}
T output_val = 0;
T total_weight_val = 0;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should nor be out of bounds."));
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
total_weight_val += cur_weight;
output_val -=
x_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] *
cur_weight;
}
}
}
if (reduction == "mean" && total_weight_val != 0) {
output_val /= total_weight_val;
}
*out_data = output_val;
*total_weight_data = total_weight_val;
}
template <typename DeviceContext, typename T>
class NLLLossOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* out = ctx.Output<Tensor>("Out");
auto* total_weight = ctx.Output<Tensor>("Total_weight");
auto reduction = ctx.Attr<std::string>("reduction");
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto x_data = x->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto out_data = out->mutable_data<T>(ctx.GetPlace());
auto total_weight_data = total_weight->mutable_data<T>(ctx.GetPlace());
*total_weight_data = 0;
auto x_dims = x->dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_1D<T>(out_data, total_weight_data, x_data, label_data,
weight_data, batch_size, n_classes, reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_2D<T>(out_data, total_weight_data, x_data, label_data,
weight_data, batch_size, n_classes, in_dim2, in_dim3,
reduction, ignore_index);
}
}
};
template <typename T>
static void nll_loss_grad_1D(T* dx_data, const T* dout_data,
const int64_t* label_data, const T* weight_data,
const T* total_weight_data,
const int64_t batch_size, const int64_t n_classes,
const std::string reduction,
const int64_t ignore_index) {
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight;
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
const auto cur_label = label_data[i];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * n_classes + cur_label] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[i * n_classes + cur_label] /= total_weight_val;
}
}
}
template <typename T>
static void nll_loss_grad_2D(T* dx_data, const T* dout_data,
const int64_t* label_data, const T* weight_data,
const T* total_weight_data,
const int64_t batch_size, const int64_t n_classes,
const int64_t in_dim2, const int64_t in_dim3,
const std::string reduction,
const int64_t ignore_index) {
const auto map_size = in_dim2 * in_dim3;
const auto sample_size = n_classes * map_size;
if (reduction == "none") {
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] =
-cur_weight * dout_data[index];
}
}
}
return;
}
const T dout_val = *dout_data;
const T total_weight_val = *total_weight_data;
for (int i = 0; i < batch_size; i++) {
for (int h = 0; h < in_dim2; h++) {
for (int w = 0; w < in_dim3; w++) {
const auto index = i * map_size + h * in_dim3 + w;
const auto cur_label = label_data[index];
if (cur_label == ignore_index) {
continue;
}
const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);
const auto dx_index =
i * sample_size + cur_label * map_size + h * in_dim3 + w;
dx_data[dx_index] = -dout_val * cur_weight;
if (reduction == "mean") {
dx_data[dx_index] /= total_weight_val;
}
}
}
}
}
template <typename DeviceContext, typename T>
class NLLLossGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* weight = ctx.Input<Tensor>("Weight");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* total_weight = ctx.Input<Tensor>("Total_weight");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto ignore_index = ctx.Attr<int64_t>("ignore_index");
auto reduction = ctx.Attr<std::string>("reduction");
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto label_data = labels->data<int64_t>();
auto weight_data = weight ? weight->data<T>() : nullptr;
auto total_weight_data = total_weight->data<T>();
memset(dx_data, 0, dx->numel() * sizeof(T));
const auto x_dims = x->dims();
const auto batch_size = x_dims[0];
const auto n_classes = x_dims[1];
if (x_dims.size() == 2) {
nll_loss_grad_1D(dx_data, dout_data, label_data, weight_data,
total_weight_data, batch_size, n_classes, reduction,
ignore_index);
} else if (x_dims.size() == 4) {
const auto in_dim2 = x_dims[2];
const auto in_dim3 = x_dims[3];
nll_loss_grad_2D(dx_data, dout_data, label_data, weight_data,
total_weight_data, batch_size, n_classes, in_dim2,
in_dim3, reduction, ignore_index);
}
}
};
} // namespace operators
} // namespace paddle
此差异已折叠。
......@@ -65,7 +65,7 @@ from .layer.loss import L1Loss #DEFINE_ALIAS
from .layer import loss #DEFINE_ALIAS
from .layer import conv #DEFINE_ALIAS
from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFINE_ALIAS
# from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
# from .layer.learning_rate import CosineDecay #DEFINE_ALIAS
# from .layer.learning_rate import ExponentialDecay #DEFINE_ALIAS
......
......@@ -19,7 +19,7 @@ __all__ = [
'CrossEntropyLoss',
# 'MSELoss',
'L1Loss',
# 'NLLLoss',
'NLLLoss',
'BCELoss'
]
......@@ -329,3 +329,145 @@ class BCELoss(fluid.dygraph.Layer):
return fluid.layers.reduce_mean(out)
else:
return out
class NLLLoss(fluid.dygraph.Layer):
"""
This op accepts input and target label and returns negative log likelihood
cross error. It is useful to train a classification problem with C classes.
The input for the loss is epected to contain log-probabilities of
each classes. It hs to be a Tensor of size either (batch_size, C) or
(batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
The label for the loss should be a class index in the range [0, C-1]
where C is the number of classes. If ignore_index is specified, the
specified target value does not contribute to the input gradient.
If the optional argument `weight` is provided, it should be a 1D Tensor
assigning weight to each of the classed. This is particularly useful
when you have an unbalanced training set.
The loss is calculated as follows.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\\top, \quad
l_n = - w_{y_n} x_{n,y_n}, \quad
w_{c} = \\text{weight}[c] \cdot \mathbb{1}\{c \\not= \\text{ignore\\_index}\},
where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then
.. math::
\ell(x, y) = \\begin{cases}
\\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n}} l_n, &
\\text{if reduction} = \\text{'mean';}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}
Parameters:
input (Variable): Input tensor, the data type is float32, float64.
label (Variable): Label tensor, the data type is int64_t.
weight (Variable, optional): Weight tensor, a manual rescaling weight given
to each class. If given, it has to be a Tensor of size `C`. Otherwise,
it treated as if having all ones. the data type is
float32, float64, Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient.
Returns:
The tensor variable storing the nll_loss.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
input_np = np.random.random(size=(10, 10)).astype(np.float32)
label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[10, 10], dtype='float32')
label = fluid.data(name='label', shape=[10], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
static_result = exe.run(
prog,
feed={"input": input_np,
"label": label_np},
fetch_list=[res])
print(static_result)
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_np)
label = dg.to_variable(label_np)
output = nll_loss(input, label)
print(output.numpy())
"""
def __init__(self, weight=None, reduction='mean', ignore_index=-100):
super(NLLLoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, input, label):
dtype = self._helper.input_dtype(input)
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'nll_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'nll_loss')
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % self.reduction)
x_shape = list(input.shape)
n = x_shape[0]
c = x_shape[1]
x_dims = len(x_shape)
if x_dims < 2:
raise ValueError('Expected 2 or more dimensions (got {})'.format(
x_dims))
if x_dims != 2 and x_dims != 4:
input = fluid.layers.reshape(input, shape=[n, c, 1, -1])
label = fluid.layers.reshape(label, shape=[n, 1, -1])
out_shape = [n] + x_shape[2:]
inputs = {'X': input, 'Label': label}
attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index}
if self.weight is not None:
if isinstance(self.weight, fluid.framework.Variable):
inputs['Weight'] = self.weight
out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
total_weight = self._helper.create_variable_for_type_inference(
dtype=input.dtype)
outputs = {'Out': out, 'Total_weight': total_weight}
self._helper.append_op(
type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
if x_dims != 2 and x_dims != 4 and self.reduction == 'none':
out = fluid.layers.reshape(out, shape=out_shape)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册