From e21b3c273ec7606355bfa6285ceb4198913a1249 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Fri, 17 Apr 2020 13:32:27 +0800 Subject: [PATCH] add nll_loss op test=develop (#23758) * add nll_loss op test=develop --- paddle/fluid/operators/nll_loss_op.cc | 268 ++++++ paddle/fluid/operators/nll_loss_op.cu | 488 ++++++++++ paddle/fluid/operators/nll_loss_op.h | 303 ++++++ .../fluid/tests/unittests/test_nll_loss.py | 883 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 +- python/paddle/nn/layer/loss.py | 144 ++- 6 files changed, 2086 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/nll_loss_op.cc create mode 100644 paddle/fluid/operators/nll_loss_op.cu create mode 100644 paddle/fluid/operators/nll_loss_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_nll_loss.py diff --git a/paddle/fluid/operators/nll_loss_op.cc b/paddle/fluid/operators/nll_loss_op.cc new file mode 100644 index 00000000000..e99ccd31714 --- /dev/null +++ b/paddle/fluid/operators/nll_loss_op.cc @@ -0,0 +1,268 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/nll_loss_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class NLLLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NLLLoss"); + OP_INOUT_CHECK(ctx->HasOutput("Total_weight"), "Output", "Total_weight", + "NLLLoss"); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto reduction = ctx->Attrs().Get("reduction"); + + PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4, true, + platform::errors::InvalidArgument( + "The tensor rank of Input(X) must be 2 or 4.")); + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(label_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_EQ( + x_dims[0], label_dims[0], + platform::errors::InvalidArgument( + "ShapeError: Expected input batch_size to match label batch_size," + "But received: the Input(x) batch_size is [%s], the Input(label) " + " batch_size is [%s].", + x_dims[0], label_dims[0])); + if (ctx->HasInput("Weight")) { + auto w_dims = ctx->GetInputDim("Weight"); + PADDLE_ENFORCE_EQ(w_dims.size(), 1, + platform::errors::InvalidArgument( + "Input(Weight) should be a 1D tensor.")); + PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0], + platform::errors::InvalidArgument( + "Input(Weight) Tensor's size should match" + "to the class numer.")); + } + } + if (x_dims.size() == 2) { + if (reduction == "none") { + ctx->SetOutputDim("Out", {x_dims[0]}); + } else { + ctx->SetOutputDim("Out", {1}); + } + } else if (x_dims.size() == 4) { + PADDLE_ENFORCE_EQ(label_dims.size(), 3, + platform::errors::InvalidArgument( + "The tensor rank of Input(Label) must be 3.")); + auto input0 = x_dims[0]; + auto input2 = x_dims[2]; + auto input3 = x_dims[3]; + auto label0 = label_dims[0]; + auto label1 = label_dims[1]; + auto label2 = label_dims[2]; + PADDLE_ENFORCE_EQ( + input0 == label0 && input2 == label1 && input3 == label2, true, + platform::errors::InvalidArgument("Input(X) tensor shape should " + "match to Input(Label) tensor " + "shape.")); + if (reduction == "none") { + ctx->SetOutputDim("Out", {x_dims[0], x_dims[2], x_dims[3]}); + } else { + ctx->SetOutputDim("Out", {1}); + } + } + ctx->SetOutputDim("Total_weight", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class NLLLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor) A tensor whose last dimension " + "size is equal to the number of classes. It is expected to " + "contain log-probabilities of each class. " + "The X tensor's shape has to be either [batch_size, C] or" + "[batch_size, C, dim1, ..., dimK] in with K >= 1 in the case " + " K-dimensional loss."); + AddInput("Label", + "(Tensor, default Tensor) A tensor which represents the " + "the ground truth. It contains the class index in the range " + "[0, C-1] where C = number of classes. The Lable tensor's " + "shape has to be (batch_size), or " + "(batch_size, dim1, ..., dimK) " + "with K >= 1 in the case K-dimensional loss."); + AddInput("Weight", + "(Tensor, optional) A tensor should be a 1D tensor assigning " + "weight to each of the classes. It's shape must be [C], where " + "C is the class number.") + .AsDispensable(); + AddOutput("Out", + "(Tensor, default Tensor) A tensor that represents the " + "NLL loss."); + AddOutput("Total_weight", + "(Tensor, default Tensor) A tensor saves the total" + "weight value in the forward process."); + AddAttr("ignore_index", + "(int64_t, default -100), Specifies a target value that is" + "ignored and does not contribute to the input gradient.") + .SetDefault(-100); + AddAttr( + "reduction", + "(string, default mean), Specifies the reduction to apply" + "to the output. The options include \"none\", \"mean\"," + "\"sum\".") + .SetDefault("mean"); + AddComment(R"DOC( +NLL(Negative Log Likelihood) Loss Operator. + +This operator computes the NLL loss according to the inputs. +The loss can be described as: + +$Out[i] = -X[Label[i]]*Weight[Label[i]]$ + +It can also be used for higher dimension inputs, such as 2D images, by +providing an input of shape (batch_size, C, d1, d2, ..., dK), with +K >= 1, where K is the number of dimensions, and a Label of +appropriate shape. In the case of images, it computes NLL loss +per-pixel. + +)DOC"); + } +}; + +class NLLLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "NLLLoss"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "NLLLoss"); + + auto reduction = ctx->Attrs().Get("reduction"); + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(dout_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; + + if (check) { + auto batch_size = x_dims[0]; + if (x_dims.size() == 2) { + PADDLE_ENFORCE_EQ(dout_dims.size(), 1, + platform::errors::InvalidArgument( + "The dimensions of Input(Out@Grad) must be 1")); + if (reduction == "none") { + PADDLE_ENFORCE_EQ( + dout_dims[0], batch_size, + platform::errors::InvalidArgument( + "The unreduced size ofInput(Out@Grad) must be the " + "same as batch_size.")); + } else { + PADDLE_ENFORCE_EQ( + dout_dims[0], 1, + platform::errors::InvalidArgument( + "The reduced size of Input(Out@Grad) must be 1")); + } + } else if (x_dims.size() == 4) { + if (reduction == "none") { + PADDLE_ENFORCE_EQ( + dout_dims.size(), 3, + platform::errors::InvalidArgument( + "The dimensions of Input(Out@Grad) must be 3,But got [%s].", + dout_dims.size())); + PADDLE_ENFORCE_EQ( + dout_dims[0] == label_dims[0] && dout_dims[1] == label_dims[1] && + dout_dims[2] == label_dims[2], + true, platform::errors::InvalidArgument( + "The dimensions of Input(Out@Grad) must be match " + "to Input(Label) dimensions.")); + } else { + PADDLE_ENFORCE_EQ( + dout_dims[0], 1, + platform::errors::InvalidArgument( + "The reduced size of Input(Out@Grad) must be 1")); + } + } + } + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +template +class NLLLossGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("nll_loss_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Label", this->Input("Label")); + op->SetInput("Total_weight", this->Output("Total_weight")); + + if (this->HasInput("Weight")) { + op->SetInput("Weight", this->Input("Weight")); + } + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker, + ops::NLLLossGradMaker, + ops::NLLLossGradMaker); +REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp); +REGISTER_OP_CPU_KERNEL( + nll_loss, ops::NLLLossOpKernel, + ops::NLLLossOpKernel); +REGISTER_OP_CPU_KERNEL( + nll_loss_grad, + ops::NLLLossGradOpKernel, + ops::NLLLossGradOpKernel); diff --git a/paddle/fluid/operators/nll_loss_op.cu b/paddle/fluid/operators/nll_loss_op.cu new file mode 100644 index 00000000000..ff7ac17a238 --- /dev/null +++ b/paddle/fluid/operators/nll_loss_op.cu @@ -0,0 +1,488 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include +#include "cub/cub.cuh" +#include "paddle/fluid/operators/math.h" +#include "paddle/fluid/operators/nll_loss_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; +static const int NTHREADS = 32; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data, + const int64_t* label_data, + const T* weight_data, + const int64_t batch_size, + const int64_t n_classes, + const int64_t ignore_index) { + CUDA_1D_KERNEL_LOOP(i, batch_size) { + const int64_t cur_label = label_data[i]; + if (cur_label == ignore_index) { + out_data[i] = 0; + continue; + } + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight; + } +} + +template +__global__ void GPUNLLLossForward1D_with_reduce( + T* out_data, T* total_weight_data, const T* x_data, + const int64_t* label_data, const T* weight_data, const int64_t batch_size, + const int64_t n_classes, const int64_t size_average, + const int64_t ignore_index) { + __shared__ T sharedInputs[NTHREADS], sharedWeights[NTHREADS]; + sharedInputs[threadIdx.x] = 0; + sharedWeights[threadIdx.x] = 0; + int i; + for (i = threadIdx.x; i < batch_size; i += NTHREADS) { + const auto cur_label = label_data[i]; + if (cur_label != ignore_index) { + const auto cur_weight = weight_data ? weight_data[cur_label] : (T)1; + sharedInputs[threadIdx.x] -= + x_data[i * n_classes + cur_label] * cur_weight; + sharedWeights[threadIdx.x] += cur_weight; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + *out_data = *total_weight_data = 0; + T output_val = 0; + T total_weight_val = 0; + for (i = 0; i < NTHREADS; ++i) { + output_val += sharedInputs[i]; + total_weight_val += sharedWeights[i]; + } + *total_weight_data = total_weight_val; + *out_data = output_val; + + if (size_average && *total_weight_data != 0) { + *out_data = output_val / total_weight_val; + } + } +} + +// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads: +// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0 +// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20) +// +// If smem is not used again, there is no need to __syncthreads before this +// call. However, if smem will be used, e.g., this function is called in a loop, +// then __syncthreads is needed either before or afterwards to prevent non-0 +// threads overriding smem in the next loop before num-0 thread reads from it. +template +__device__ void reduceNValuesInBlock(T* smem, T threadVals[N], + const unsigned int numVals, + ReduceOp reduceOp, T init) { + if (numVals == 0) { +#pragma unroll + for (int i = 0; i < N; ++i) { + threadVals[i] = init; + } + return; + } + + // We store each of the N values contiguously, so if N = 2, all values for + // the first threadVal for each thread in the block are stored followed by + // all of the values for the second threadVal for each thread in the block + if (threadIdx.x < numVals) { +#pragma unroll + for (int i = 0; i < N; ++i) { + smem[i * numVals + threadIdx.x] = threadVals[i]; + } + } + __syncthreads(); + + // Number of lanes in the final reduction --> this is used to determine + // where to put the outputs of each of the n things we are reducing. If + // nLP = 32, then we have the 32 outputs for the first threadVal, + // followed by the 32 outputs for the second threadVal, etc. + const unsigned int numLanesParticipating = min(numVals, warpSize); + + if (numVals > warpSize && ((threadIdx.x / warpSize) == 0)) { +#pragma unroll + for (int i = 0; i < N; ++i) { + threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init; + } + + for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) { +#pragma unroll + for (int j = 0; j < N; ++j) { + threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]); + } + } + +#pragma unroll + for (int i = 0; i < N; ++i) { + smem[i * numLanesParticipating + threadIdx.x] = threadVals[i]; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + if (numLanesParticipating == 32) { +#pragma unroll + for (int i = 0; i < N; ++i) { +#pragma unroll + for (int j = 1; j < 32; ++j) { + threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]); + } + } + } else { +#pragma unroll + for (int i = 0; i < N; ++i) { + for (int j = 1; j < numLanesParticipating; ++j) { + threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]); + } + } + } + } +} + +// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will +// return the reduced value +// +// If smem is not used again, there is no need to __syncthreads before this +// call. However, if smem will be used, e.g., this function is called in a loop, +// then __syncthreads is needed either before or afterwards to prevent non-0 +// threads overriding smem in the next loop before num-0 thread reads from it. +template +__device__ T reduceBlock(T* smem, const unsigned int numVals, T threadVal, + ReduceOp reduceOp, T init) { + reduceNValuesInBlock(smem, &threadVal, numVals, reduceOp, + init); + return threadVal; +} + +template +__global__ void GPUNLLLossForward2D_no_reduce( + T* out_data, const T* x_data, const int64_t* label_data, + const T* weight_data, const int64_t batch_size, const int64_t n_classes, + const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) { + const int64_t map_size = in_dim2 * in_dim3; + const int64_t sample_size = n_classes * map_size; + const int64_t out_numel = batch_size * map_size; + CUDA_1D_KERNEL_LOOP(i, out_numel) { + const int64_t b = i % batch_size; + const int64_t h = (i / batch_size) % in_dim2; + const int64_t w = (i / (batch_size * in_dim2)) % in_dim3; + + const int64_t index = b * map_size + h * in_dim3 + w; + const int64_t cur_label = label_data[index]; + if (cur_label == ignore_index) { + out_data[index] = 0; + continue; + } + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + out_data[index] = + -x_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] * + cur_weight; + } +} + +template +__global__ void GPUNLLLossForward2D_with_reduce( + T* out_data, T* total_weight_data, const T* x_data, + const int64_t* label_data, const T* weight_data, const int64_t batch_size, + const int64_t n_classes, const int64_t map_nelem, + const int64_t blocks_per_sample, const int64_t ignore_index) { + __shared__ T partial_sums[kNumCUDAThreads]; + int64_t i; + T input_sum = 0; + T acc_weight = 0; + *out_data = 0; + *total_weight_data = 0; + + int64_t sample = blockIdx.x / blocks_per_sample; + int64_t toffset = sample * map_nelem; + int64_t ioffset = sample * map_nelem * n_classes; + int64_t step = blockDim.x * blocks_per_sample; + for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; + i < map_nelem; i += step) { + const int64_t cur_label = label_data[toffset + i]; + if (cur_label != ignore_index) { + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + input_sum -= x_data[ioffset + i + map_nelem * cur_label] * cur_weight; + acc_weight += cur_weight; + } + } + + input_sum = + reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus(), (T)0); + __syncthreads(); + acc_weight = reduceBlock(partial_sums, blockDim.x, acc_weight, + thrust::plus(), (T)0); + + if (threadIdx.x == 0) { + paddle::platform::CudaAtomicAdd(total_weight_data, acc_weight); + paddle::platform::CudaAtomicAdd(out_data, input_sum); + } +} + +template +__global__ void GPUNLLLossForward2D_size_average(T* out_data, + T* total_weight_data) { + if (*total_weight_data != 0) { + *out_data /= *total_weight_data; + } +} + +template +__global__ void GPUNLLLossBackward1D_no_reduce( + T* dx_data, const int64_t* label_data, const T* weight_data, + const T* dout_data, const int64_t batch_size, const int64_t n_classes, + const int64_t ignore_index) { + CUDA_1D_KERNEL_LOOP(i, batch_size) { + const int64_t cur_label = label_data[i]; + if (cur_label == ignore_index) { + continue; + } + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight; + } +} + +template +__global__ void GPUNLLLossBackward1D_with_reduce( + T* dx_data, const T* total_weight_data, const int64_t* label_data, + const T* weight_data, const T* dout_data, const int64_t batch_size, + const int64_t n_classes, const int64_t size_average, + const int64_t ignore_index) { + if (*total_weight_data <= 0) { + return; + } + int i; + const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1; + for (i = threadIdx.x; i < batch_size; i += NTHREADS) { + const int64_t cur_label = label_data[i]; + if (cur_label != ignore_index) { + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + dx_data[i * n_classes + cur_label] = -cur_weight * dout_data[0] * norm; + } + } +} + +template +__global__ void GPUNLLLossBackward2D_no_reduce( + T* dx_data, const int64_t* label_data, const T* weight_data, + const T* dout_data, const int64_t batch_size, const int64_t n_classes, + const int64_t in_dim2, const int64_t in_dim3, const int64_t ignore_index) { + const int64_t map_size = in_dim2 * in_dim3; + const int64_t sample_size = n_classes * map_size; + const int64_t out_numel = batch_size * map_size; + CUDA_1D_KERNEL_LOOP(i, out_numel) { + const int64_t b = i % batch_size; + const int64_t h = (i / batch_size) % in_dim2; + const int64_t w = (i / (batch_size * in_dim2)) % in_dim3; + const int64_t index = b * map_size + h * in_dim3 + w; + const int64_t cur_label = label_data[index]; + if (cur_label == ignore_index) { + continue; + } + const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; + dx_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] = + -dout_data[index] * cur_weight; + } +} + +template +__global__ void GPUNLLLossBackward2D_with_reduce( + T* dx_data, const T* total_weight_data, const int64_t* label_data, + const T* weight_data, const T* dout_data, const int64_t batch_size, + const int64_t n_classes, const int64_t map_nelem, + const int64_t blocks_per_sample, const int64_t size_average, + const int64_t ignore_index) { + if (*total_weight_data <= 0) { + return; + } + int64_t i; + const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1; + int sample = blockIdx.x / blocks_per_sample; + int step = blockDim.x * blocks_per_sample; + int toffset = sample * map_nelem; + int ioffset = sample * map_nelem * n_classes; + for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x; + i < map_nelem; i += step) { + const int64_t cur_label = label_data[toffset + i]; + if (cur_label != ignore_index) { + dx_data[ioffset + i + map_nelem * cur_label] = + -(weight_data ? weight_data[cur_label] : (T)1) * norm * dout_data[0]; + } + } +} + +template +class NLLLossCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* weight = ctx.Input("Weight"); + auto* out = ctx.Output("Out"); + auto* total_weight = ctx.Output("Total_weight"); + auto ignore_index = ctx.Attr("ignore_index"); + auto reduction = ctx.Attr("reduction"); + + auto x_data = x->data(); + auto out_data = out->mutable_data(ctx.GetPlace()); + auto total_weight_data = total_weight->mutable_data(ctx.GetPlace()); + auto label_data = labels->data(); + auto weight_data = weight ? weight->data() : nullptr; + cudaMemset(total_weight_data, 0, sizeof(T)); + auto x_dims = x->dims(); + auto batch_size = x_dims[0]; + auto n_classes = x_dims[1]; + int64_t size_average = (int64_t)(reduction == "mean"); + + if (x_dims.size() == 2) { + int blocks = NumBlocks(batch_size); + int threads = kNumCUDAThreads; + auto& dev_ctx = ctx.cuda_device_context(); + if (reduction == "none") { + GPUNLLLossForward1D_no_reduce< + T><<>>( + out_data, x_data, label_data, weight_data, batch_size, n_classes, + ignore_index); + } else { + GPUNLLLossForward1D_with_reduce< + T><<<1, NTHREADS, 0, dev_ctx.stream()>>>( + out_data, total_weight_data, x_data, label_data, weight_data, + batch_size, n_classes, size_average, ignore_index); + } + } else if (x_dims.size() == 4) { + const auto in_dim2 = x_dims[2]; + const auto in_dim3 = x_dims[3]; + const auto map_size = in_dim2 * in_dim3; + const auto out_numel = batch_size * in_dim2 * in_dim3; + int blocks = NumBlocks(out_numel); + int threads = kNumCUDAThreads; + auto& dev_ctx = ctx.cuda_device_context(); + if (reduction == "none") { + GPUNLLLossForward2D_no_reduce< + T><<>>( + out_data, x_data, label_data, weight_data, batch_size, n_classes, + in_dim2, in_dim3, ignore_index); + } else { + int blocks_per_sample = NumBlocks(map_size) / 128; + blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; + int total_blocks = blocks_per_sample * batch_size; + GPUNLLLossForward2D_with_reduce< + T><<>>( + out_data, total_weight_data, x_data, label_data, weight_data, + batch_size, n_classes, map_size, blocks_per_sample, ignore_index); + if (size_average) { + GPUNLLLossForward2D_size_average<<<1, 1, 0, dev_ctx.stream()>>>( + out_data, total_weight_data); + } + } + } + } +}; + +template +class NLLLossGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* weight = ctx.Input("Weight"); + auto* total_weight = ctx.Input("Total_weight"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto dx_data = dx->mutable_data(ctx.GetPlace()); + auto dout_data = dout->data(); + auto label_data = labels->data(); + auto weight_data = weight ? weight->data() : nullptr; + auto total_weight_data = total_weight->data(); + auto ignore_index = ctx.Attr("ignore_index"); + auto reduction = ctx.Attr("reduction"); + cudaMemset(dx_data, 0, dx->numel() * sizeof(T)); + + int64_t size_average = (int64_t)(reduction == "mean"); + auto x_dims = x->dims(); + auto batch_size = x_dims[0]; + auto n_classes = x_dims[1]; + + if (x_dims.size() == 2) { + int blocks = NumBlocks(batch_size); + int threads = kNumCUDAThreads; + auto& dev_ctx = ctx.cuda_device_context(); + if (reduction == "none") { + GPUNLLLossBackward1D_no_reduce< + T><<>>( + dx_data, label_data, weight_data, dout_data, batch_size, n_classes, + ignore_index); + } else { + GPUNLLLossBackward1D_with_reduce< + T><<<1, NTHREADS, 0, dev_ctx.stream()>>>( + dx_data, total_weight_data, label_data, weight_data, dout_data, + batch_size, n_classes, size_average, ignore_index); + } + } else if (x_dims.size() == 4) { + const auto in_dim2 = x_dims[2]; + const auto in_dim3 = x_dims[3]; + const auto map_size = in_dim2 * in_dim3; + const auto out_numel = batch_size * in_dim2 * in_dim3; + + int blocks = NumBlocks(out_numel); + int threads = kNumCUDAThreads; + auto& dev_ctx = ctx.cuda_device_context(); + if (reduction == "none") { + GPUNLLLossBackward2D_no_reduce< + T><<>>( + dx_data, label_data, weight_data, dout_data, batch_size, n_classes, + in_dim2, in_dim3, ignore_index); + } else { + int blocks_per_sample = NumBlocks(map_size) / 128; + blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample; + int total_blocks = blocks_per_sample * batch_size; + GPUNLLLossBackward2D_with_reduce< + T><<>>( + dx_data, total_weight_data, label_data, weight_data, dout_data, + batch_size, n_classes, map_size, blocks_per_sample, size_average, + ignore_index); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + nll_loss, + ops::NLLLossCUDAKernel, + ops::NLLLossCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + nll_loss_grad, + ops::NLLLossGradCUDAKernel, + ops::NLLLossGradCUDAKernel); diff --git a/paddle/fluid/operators/nll_loss_op.h b/paddle/fluid/operators/nll_loss_op.h new file mode 100644 index 00000000000..92f3d169f3f --- /dev/null +++ b/paddle/fluid/operators/nll_loss_op.h @@ -0,0 +1,303 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data, + const int64_t* label_data, const T* weight_data, + const int64_t batch_size, const int64_t n_classes, + const std::string reduction, + const int64_t ignore_index) { + if (reduction == "none") { + for (int64_t i = 0; i < batch_size; ++i) { + const auto cur_label = label_data[i]; + if (cur_label == ignore_index) { + out_data[i] = 0; + continue; + } + PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, + platform::errors::InvalidArgument( + "label should not be out of bounds.")); + + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight; + } + return; + } + + T output_val = 0; + T total_weight_val = 0; + + for (int64_t i = 0; i < batch_size; i++) { + const auto cur_label = label_data[i]; + if (cur_label == ignore_index) { + out_data[i] = 0; + continue; + } + PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, + platform::errors::InvalidArgument( + "label should not be out of bounds.")); + + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + total_weight_val += cur_weight; + output_val -= x_data[i * n_classes + cur_label] * cur_weight; + } + if (reduction == "mean" && total_weight_val != 0) { + output_val /= total_weight_val; + } + *out_data = output_val; + *total_weight_data = total_weight_val; +} + +template +static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data, + const int64_t* label_data, const T* weight_data, + const int64_t batch_size, const int64_t n_classes, + const int64_t in_dim2, const int64_t in_dim3, + const std::string reduction, + const int64_t ignore_index) { + const auto map_size = in_dim2 * in_dim3; + const auto sample_size = n_classes * map_size; + if (reduction == "none") { + for (int i = 0; i < batch_size; i++) { + for (int h = 0; h < in_dim2; h++) { + for (int w = 0; w < in_dim3; w++) { + const auto index = i * map_size + h * in_dim3 + w; + const auto cur_label = label_data[index]; + if (cur_label == ignore_index) { + out_data[index] = 0; + continue; + } + PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, + platform::errors::InvalidArgument( + "label should nor be out of bounds.")); + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + out_data[index] = -x_data[i * sample_size + cur_label * map_size + + h * in_dim3 + w] * + cur_weight; + } + } + } + return; + } + + T output_val = 0; + T total_weight_val = 0; + + for (int i = 0; i < batch_size; i++) { + for (int h = 0; h < in_dim2; h++) { + for (int w = 0; w < in_dim3; w++) { + const auto index = i * map_size + h * in_dim3 + w; + const auto cur_label = label_data[index]; + if (cur_label == ignore_index) { + out_data[index] = 0; + continue; + } + PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, + platform::errors::InvalidArgument( + "label should nor be out of bounds.")); + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + total_weight_val += cur_weight; + output_val -= + x_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] * + cur_weight; + } + } + } + + if (reduction == "mean" && total_weight_val != 0) { + output_val /= total_weight_val; + } + *out_data = output_val; + *total_weight_data = total_weight_val; +} + +template +class NLLLossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* weight = ctx.Input("Weight"); + auto* out = ctx.Output("Out"); + auto* total_weight = ctx.Output("Total_weight"); + auto reduction = ctx.Attr("reduction"); + auto ignore_index = ctx.Attr("ignore_index"); + + auto x_data = x->data(); + auto label_data = labels->data(); + auto weight_data = weight ? weight->data() : nullptr; + auto out_data = out->mutable_data(ctx.GetPlace()); + auto total_weight_data = total_weight->mutable_data(ctx.GetPlace()); + *total_weight_data = 0; + + auto x_dims = x->dims(); + const auto batch_size = x_dims[0]; + const auto n_classes = x_dims[1]; + + if (x_dims.size() == 2) { + nll_loss_1D(out_data, total_weight_data, x_data, label_data, + weight_data, batch_size, n_classes, reduction, + ignore_index); + } else if (x_dims.size() == 4) { + const auto in_dim2 = x_dims[2]; + const auto in_dim3 = x_dims[3]; + nll_loss_2D(out_data, total_weight_data, x_data, label_data, + weight_data, batch_size, n_classes, in_dim2, in_dim3, + reduction, ignore_index); + } + } +}; + +template +static void nll_loss_grad_1D(T* dx_data, const T* dout_data, + const int64_t* label_data, const T* weight_data, + const T* total_weight_data, + const int64_t batch_size, const int64_t n_classes, + const std::string reduction, + const int64_t ignore_index) { + if (reduction == "none") { + for (int i = 0; i < batch_size; i++) { + const auto cur_label = label_data[i]; + if (cur_label == ignore_index) { + continue; + } + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight; + } + return; + } + + const T dout_val = *dout_data; + const T total_weight_val = *total_weight_data; + for (int i = 0; i < batch_size; i++) { + const auto cur_label = label_data[i]; + if (cur_label == ignore_index) { + continue; + } + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + dx_data[i * n_classes + cur_label] = -dout_val * cur_weight; + if (reduction == "mean") { + dx_data[i * n_classes + cur_label] /= total_weight_val; + } + } +} + +template +static void nll_loss_grad_2D(T* dx_data, const T* dout_data, + const int64_t* label_data, const T* weight_data, + const T* total_weight_data, + const int64_t batch_size, const int64_t n_classes, + const int64_t in_dim2, const int64_t in_dim3, + const std::string reduction, + const int64_t ignore_index) { + const auto map_size = in_dim2 * in_dim3; + const auto sample_size = n_classes * map_size; + + if (reduction == "none") { + for (int i = 0; i < batch_size; i++) { + for (int h = 0; h < in_dim2; h++) { + for (int w = 0; w < in_dim3; w++) { + const auto index = i * map_size + h * in_dim3 + w; + const auto cur_label = label_data[index]; + if (cur_label == ignore_index) { + continue; + } + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] = + -cur_weight * dout_data[index]; + } + } + } + return; + } + + const T dout_val = *dout_data; + const T total_weight_val = *total_weight_data; + for (int i = 0; i < batch_size; i++) { + for (int h = 0; h < in_dim2; h++) { + for (int w = 0; w < in_dim3; w++) { + const auto index = i * map_size + h * in_dim3 + w; + const auto cur_label = label_data[index]; + if (cur_label == ignore_index) { + continue; + } + const auto cur_weight = + weight_data ? weight_data[cur_label] : static_cast(1); + const auto dx_index = + i * sample_size + cur_label * map_size + h * in_dim3 + w; + dx_data[dx_index] = -dout_val * cur_weight; + if (reduction == "mean") { + dx_data[dx_index] /= total_weight_val; + } + } + } + } +} + +template +class NLLLossGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* weight = ctx.Input("Weight"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* total_weight = ctx.Input("Total_weight"); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto ignore_index = ctx.Attr("ignore_index"); + auto reduction = ctx.Attr("reduction"); + + auto dx_data = dx->mutable_data(ctx.GetPlace()); + auto dout_data = dout->data(); + auto label_data = labels->data(); + auto weight_data = weight ? weight->data() : nullptr; + auto total_weight_data = total_weight->data(); + memset(dx_data, 0, dx->numel() * sizeof(T)); + + const auto x_dims = x->dims(); + const auto batch_size = x_dims[0]; + const auto n_classes = x_dims[1]; + + if (x_dims.size() == 2) { + nll_loss_grad_1D(dx_data, dout_data, label_data, weight_data, + total_weight_data, batch_size, n_classes, reduction, + ignore_index); + } else if (x_dims.size() == 4) { + const auto in_dim2 = x_dims[2]; + const auto in_dim3 = x_dims[3]; + nll_loss_grad_2D(dx_data, dout_data, label_data, weight_data, + total_weight_data, batch_size, n_classes, in_dim2, + in_dim3, reduction, ignore_index); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_nll_loss.py b/python/paddle/fluid/tests/unittests/test_nll_loss.py new file mode 100644 index 00000000000..b14e3a15d97 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nll_loss.py @@ -0,0 +1,883 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest +from op_test import OpTest + + +def nll_loss_1d(logs, targets, weight=None, reduction='mean', + ignore_index=-100): + input_shape = logs.shape + N = input_shape[0] + C = input_shape[1] + out = np.zeros_like(targets).astype(np.float64) + total_weight = 0 + for i in range(N): + cur_target = targets[i] + if cur_target == ignore_index: + out[i] = 0 + continue + cur_weight = weight[cur_target] if weight is not None else 1 + total_weight += cur_weight + out[i] = -logs[i][cur_target] * cur_weight + if reduction == 'sum': + return np.sum(out), np.array([total_weight]).astype('float64') + elif reduction == 'mean': + return out.sum() / total_weight, np.array( + [total_weight]).astype('float64') + elif reduction == 'none': + return out + + +def nll_loss_2d(logs, targets, weight=None, reduction='mean', + ignore_index=-100): + input_shape = logs.shape + N = input_shape[0] + H = input_shape[2] + W = input_shape[3] + out = np.zeros_like(targets).astype(np.float64) + total_weight = 0 + for i in range(N): + for h in range(H): + for w in range(W): + cur_target = targets[i][h][w] + if cur_target == ignore_index: + out[i][h][w] = 0 + continue + cur_weight = weight[cur_target] if weight is not None else 1 + total_weight += cur_weight + out[i][h][w] = -logs[i][cur_target][h][w] * cur_weight + if reduction == 'sum': + return np.sum(out), np.array([total_weight]).astype('float64') + elif reduction == 'mean': + return out.sum() / total_weight, np.array( + [total_weight]).astype('float64') + elif reduction == 'none': + return out + + +class TestNLLLoss(unittest.TestCase): + def test_NLLLoss_1D_mean(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss() + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss() + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_1d(input_np, label_np)[0] + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_1D_sum(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_1d(input_np, label_np, reduction='sum')[0] + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_1D_with_weight_mean(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + weight_np = np.random.random(size=(10, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + # place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + weight = fluid.data(name='weight', shape=[10], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_1D_with_weight_sum(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + weight_np = np.random.random(size=(10, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + # place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + weight = fluid.data(name='weight', shape=[10], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='sum') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + expected = nll_loss_1d( + input_np, label_np, weight=weight_np, reduction='sum')[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_1D_with_weight_mean_cpu(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + weight_np = np.random.random(size=(10, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + weight = fluid.data(name='weight', shape=[10], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + expected = nll_loss_1d(input_np, label_np, weight=weight_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_1D_with_weight_no_reduce_cpu(self): + input_np = np.random.random(size=(10, 10)).astype(np.float64) + label_np = np.random.randint(0, 10, size=(10, )).astype(np.int64) + weight_np = np.random.random(size=(10, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float64') + label = fluid.data(name='label', shape=[10], dtype='int64') + weight = fluid.data(name='weight', shape=[10], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='none') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + expected = nll_loss_1d( + input_np, label_np, weight=weight_np, reduction='none') + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_2D_mean(self): + input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss() + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss() + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_2d(input_np, label_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_2D_sum(self): + input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_2d(input_np, label_np, reduction='sum')[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_2D_with_weight_mean(self): + input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + + nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_2D_with_weight_mean_cpu(self): + input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + + nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_2d(input_np, label_np, weight=weight_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_2D_with_weight_sum(self): + input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='sum') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = nll_loss_2d( + input_np, label_np, weight=weight_np, reduction='sum')[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_in_dims_not_2or4_mean(self): + input_np = np.random.random(size=(5, 3, 5, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5, 5)).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss() + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss() + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + input_shape = input_np.shape + label_shape = label_np.shape + input_np_reshape = np.reshape(input_np, + (input_shape[0], input_shape[1], 1, -1)) + label_np_reshape = np.reshape(label_np, (label_shape[0], 1, -1)) + expected = nll_loss_2d(input_np_reshape, label_np_reshape)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_in_dims_not_2or4_with_weight_mean(self): + input_np = np.random.random(size=(5, 3, 5, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + input_shape = input_np.shape + label_shape = label_np.shape + input_np_reshape = np.reshape(input_np, + (input_shape[0], input_shape[1], 1, -1)) + label_np_reshape = np.reshape(label_np, (label_shape[0], 1, -1)) + expected = nll_loss_2d( + input_np_reshape, label_np_reshape, weight=weight_np)[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_in_dims_not_2or4_with_weight_sum(self): + input_np = np.random.random(size=(5, 3, 5, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='sum') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + input_shape = input_np.shape + label_shape = label_np.shape + input_np_reshape = np.reshape(input_np, + (input_shape[0], input_shape[1], 1, -1)) + label_np_reshape = np.reshape(label_np, (label_shape[0], 1, -1)) + expected = nll_loss_2d( + input_np_reshape, + label_np_reshape, + weight=weight_np, + reduction='sum')[0] + + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_in_dims_not_2or4_with_weight_no_reduce(self): + input_np = np.random.random(size=(5, 3, 5, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + #place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='none') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + input_shape = input_np.shape + label_shape = label_np.shape + out_shape = (input_shape[0], ) + input_shape[2:] + input_np_reshape = np.reshape(input_np, + (input_shape[0], input_shape[1], 1, -1)) + label_np_reshape = np.reshape(label_np, (label_shape[0], 1, -1)) + expected = nll_loss_2d( + input_np_reshape, + label_np_reshape, + weight=weight_np, + reduction='none') + expected = np.reshape(expected, out_shape) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_NLLLoss_in_dims_not_2or4_with_weight_no_reduce_cpu(self): + input_np = np.random.random(size=(5, 3, 5, 5, 5)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, 5, 5, 5)).astype(np.int64) + weight_np = np.random.random(size=(3, )).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[5, 3, 5, 5, 5], dtype='float64') + label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype='float64') + nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + nll_loss = paddle.nn.loss.NLLLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='none') + dy_res = nll_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + input_shape = input_np.shape + label_shape = label_np.shape + out_shape = (input_shape[0], ) + input_shape[2:] + input_np_reshape = np.reshape(input_np, + (input_shape[0], input_shape[1], 1, -1)) + label_np_reshape = np.reshape(label_np, (label_shape[0], 1, -1)) + expected = nll_loss_2d( + input_np_reshape, + label_np_reshape, + weight=weight_np, + reduction='none') + expected = np.reshape(expected, out_shape) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + +class TestNLLLossOp1DWithReduce(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "nll_loss" + self.with_weight = False + input_np = np.random.uniform(0.1, 0.8, + self.input_shape).astype("float64") + label_np = np.random.randint(0, self.input_shape[1], + self.label_shape).astype("int64") + output_np, total_weight_np = nll_loss_1d(input_np, label_np) + self.inputs = {'X': input_np, 'Label': label_np} + if self.with_weight: + weight_np = np.random.uniform(0.1, 0.8, + self.input_shape[1]).astype("float64") + output_np, total_weight_np = nll_loss_1d( + input_np, label_np, weight=weight_np) + self.inputs['Weight'] = weight_np + + self.outputs = {'Out': output_np, 'Total_weight': total_weight_np} + self.attrs = {'reduction': 'mean', 'ignore_index': -100} + + def test_check_output(self): + self.check_output() + + def test_check_output_with_weight(self): + self.with_weight = True + self.check_output() + + def test_check_grad(self): + self.with_weight = True + place = fluid.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.input_shape = [10, 10] + self.label_shape = [10] + + +class TestNLLLossOp1DNoReduce(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "nll_loss" + self.with_weight = False + input_np = np.random.uniform(0.1, 0.8, + self.input_shape).astype("float64") + label_np = np.random.randint(0, self.input_shape[1], + self.label_shape).astype("int64") + output_np = nll_loss_1d(input_np, label_np, reduction='none') + total_weight_np = np.array([0]).astype('float64') + self.inputs = {'X': input_np, 'Label': label_np} + if self.with_weight: + weight_np = np.random.uniform(0.1, 0.8, + self.input_shape[1]).astype("float64") + output_np, total_weight_np = nll_loss_1d( + input_np, label_np, weight=weight_np, reduction='none') + self.inputs['Weight'] = weight_np + + self.outputs = {'Out': output_np, 'Total_weight': total_weight_np} + self.attrs = {'reduction': 'none', 'ignore_index': -100} + + def test_check_output(self): + self.check_output() + + def test_check_output_with_weight(self): + self.with_weight = True + self.check_output() + + def test_check_grad(self): + self.with_weight = True + place = fluid.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.input_shape = [10, 10] + self.label_shape = [10] + + +class TestNLLLossOp2DWithReduce(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "nll_loss" + self.with_weight = False + input_np = np.random.uniform(0.1, 0.8, + self.input_shape).astype("float64") + label_np = np.random.randint(0, self.input_shape[1], + self.label_shape).astype("int64") + output_np, total_weight_np = nll_loss_2d(input_np, label_np) + self.inputs = {'X': input_np, 'Label': label_np} + if self.with_weight: + weight_np = np.random.uniform(0.1, 0.8, + self.input_shape[1]).astype("float64") + output_np, total_weight_np = nll_loss_2d( + input_np, label_np, weight=weight_np) + self.inputs['Weight'] = weight_np + + self.outputs = {'Out': output_np, 'Total_weight': total_weight_np} + self.attrs = {'reduction': 'mean', 'ignore_index': -100} + + def test_check_output(self): + self.check_output() + + def test_check_output_with_weight(self): + self.with_weight = True + self.check_output() + + def test_check_grad(self): + self.with_weight = True + place = fluid.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.input_shape = [5, 3, 5, 5] + self.label_shape = [5, 5, 5] + + +class TestNLLLossOp2DNoReduce(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "nll_loss" + self.with_weight = False + input_np = np.random.uniform(0.1, 0.8, + self.input_shape).astype("float64") + label_np = np.random.randint(0, self.input_shape[1], + self.label_shape).astype("int64") + output_np = nll_loss_2d(input_np, label_np, reduction='none') + total_weight_np = np.array([0]).astype('float64') + self.inputs = {'X': input_np, 'Label': label_np} + if self.with_weight: + weight_np = np.random.uniform(0.1, 0.8, + self.input_shape[1]).astype("float64") + output_np, total_weight_np = nll_loss_2d( + input_np, label_np, weight=weight_np, reduction='none') + self.inputs['Weight'] = weight_np + + self.outputs = {'Out': output_np, 'Total_weight': total_weight_np} + self.attrs = {'reduction': 'none', 'ignore_index': -100} + + def test_check_output(self): + self.check_output() + + def test_check_output_with_weight(self): + self.with_weight = True + self.check_output() + + def test_check_grad(self): + self.with_weight = True + place = fluid.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.input_shape = [5, 3, 5, 5] + self.label_shape = [5, 5, 5] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index ca755fdb725..2ad41f64a55 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -65,7 +65,7 @@ from .layer.loss import L1Loss #DEFINE_ALIAS from .layer import loss #DEFINE_ALIAS from .layer import conv #DEFINE_ALIAS from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFINE_ALIAS -# from .layer.loss import NLLLoss #DEFINE_ALIAS +from .layer.loss import NLLLoss #DEFINE_ALIAS from .layer.loss import BCELoss #DEFINE_ALIAS # from .layer.learning_rate import CosineDecay #DEFINE_ALIAS # from .layer.learning_rate import ExponentialDecay #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 65bb9215e29..8f20a1cde65 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -19,7 +19,7 @@ __all__ = [ 'CrossEntropyLoss', # 'MSELoss', 'L1Loss', - # 'NLLLoss', + 'NLLLoss', 'BCELoss' ] @@ -329,3 +329,145 @@ class BCELoss(fluid.dygraph.Layer): return fluid.layers.reduce_mean(out) else: return out + + +class NLLLoss(fluid.dygraph.Layer): + """ + This op accepts input and target label and returns negative log likelihood + cross error. It is useful to train a classification problem with C classes. + + The input for the loss is epected to contain log-probabilities of + each classes. It hs to be a Tensor of size either (batch_size, C) or + (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case. + The label for the loss should be a class index in the range [0, C-1] + where C is the number of classes. If ignore_index is specified, the + specified target value does not contribute to the input gradient. + + If the optional argument `weight` is provided, it should be a 1D Tensor + assigning weight to each of the classed. This is particularly useful + when you have an unbalanced training set. + + The loss is calculated as follows. + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \\text{weight}[c] \cdot \mathbb{1}\{c \\not= \\text{ignore\\_index}\}, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \\begin{cases} + \\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n}} l_n, & + \\text{if reduction} = \\text{'mean';}\\\\ + \\sum_{n=1}^N l_n, & + \\text{if reduction} = \\text{'sum'.} + \\end{cases} + + Parameters: + input (Variable): Input tensor, the data type is float32, float64. + label (Variable): Label tensor, the data type is int64_t. + weight (Variable, optional): Weight tensor, a manual rescaling weight given + to each class. If given, it has to be a Tensor of size `C`. Otherwise, + it treated as if having all ones. the data type is + float32, float64, Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + Default is ``'mean'``. + ignore_index (int64, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. + + Returns: + The tensor variable storing the nll_loss. + + Return type: Variable. + + Examples: + + .. code-block:: python + + # declarative mode + import paddle.fluid as fluid + import numpy as np + import paddle + + input_np = np.random.random(size=(10, 10)).astype(np.float32) + label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float32') + label = fluid.data(name='label', shape=[10], dtype='int64') + nll_loss = paddle.nn.loss.NLLLoss() + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + print(static_result) + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_np) + label = dg.to_variable(label_np) + output = nll_loss(input, label) + print(output.numpy()) + """ + + def __init__(self, weight=None, reduction='mean', ignore_index=-100): + super(NLLLoss, self).__init__() + self.weight = weight + self.reduction = reduction + self.ignore_index = ignore_index + + def forward(self, input, label): + dtype = self._helper.input_dtype(input) + + fluid.data_feeder.check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'nll_loss') + fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'], + 'nll_loss') + + if self.reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % self.reduction) + + x_shape = list(input.shape) + n = x_shape[0] + c = x_shape[1] + x_dims = len(x_shape) + if x_dims < 2: + raise ValueError('Expected 2 or more dimensions (got {})'.format( + x_dims)) + if x_dims != 2 and x_dims != 4: + input = fluid.layers.reshape(input, shape=[n, c, 1, -1]) + label = fluid.layers.reshape(label, shape=[n, 1, -1]) + out_shape = [n] + x_shape[2:] + + inputs = {'X': input, 'Label': label} + attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index} + + if self.weight is not None: + if isinstance(self.weight, fluid.framework.Variable): + inputs['Weight'] = self.weight + + out = self._helper.create_variable_for_type_inference(dtype=input.dtype) + total_weight = self._helper.create_variable_for_type_inference( + dtype=input.dtype) + outputs = {'Out': out, 'Total_weight': total_weight} + + self._helper.append_op( + type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs) + if x_dims != 2 and x_dims != 4 and self.reduction == 'none': + out = fluid.layers.reshape(out, shape=out_shape) + + return out -- GitLab