Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e21b3c27
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e21b3c27
编写于
4月 17, 2020
作者:
L
lijianshe02
提交者:
GitHub
4月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add nll_loss op test=develop (#23758)
* add nll_loss op test=develop
上级
40aa14ec
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
2086 addition
and
2 deletion
+2086
-2
paddle/fluid/operators/nll_loss_op.cc
paddle/fluid/operators/nll_loss_op.cc
+268
-0
paddle/fluid/operators/nll_loss_op.cu
paddle/fluid/operators/nll_loss_op.cu
+488
-0
paddle/fluid/operators/nll_loss_op.h
paddle/fluid/operators/nll_loss_op.h
+303
-0
python/paddle/fluid/tests/unittests/test_nll_loss.py
python/paddle/fluid/tests/unittests/test_nll_loss.py
+883
-0
python/paddle/nn/__init__.py
python/paddle/nn/__init__.py
+1
-1
python/paddle/nn/layer/loss.py
python/paddle/nn/layer/loss.py
+143
-1
未找到文件。
paddle/fluid/operators/nll_loss_op.cc
0 → 100644
浏览文件 @
e21b3c27
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nll_loss_op.h"
#include <memory>
#include <string>
namespace
paddle
{
namespace
operators
{
class
NLLLossOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Label"
),
"Input"
,
"Label"
,
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Total_weight"
),
"Output"
,
"Total_weight"
,
"NLLLoss"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
reduction
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"reduction"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
()
==
2
||
x_dims
.
size
()
==
4
,
true
,
platform
::
errors
::
InvalidArgument
(
"The tensor rank of Input(X) must be 2 or 4."
));
bool
contain_unknown_dim
=
framework
::
contain_unknown_dim
(
x_dims
)
||
framework
::
contain_unknown_dim
(
label_dims
);
bool
check
=
ctx
->
IsRuntime
()
||
!
contain_unknown_dim
;
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
0
],
label_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"ShapeError: Expected input batch_size to match label batch_size,"
"But received: the Input(x) batch_size is [%s], the Input(label) "
" batch_size is [%s]."
,
x_dims
[
0
],
label_dims
[
0
]));
if
(
ctx
->
HasInput
(
"Weight"
))
{
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input(Weight) should be a 1D tensor."
));
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
w_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"Input(Weight) Tensor's size should match"
"to the class numer."
));
}
}
if
(
x_dims
.
size
()
==
2
)
{
if
(
reduction
==
"none"
)
{
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
]});
}
else
{
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
}
}
else
if
(
x_dims
.
size
()
==
4
)
{
PADDLE_ENFORCE_EQ
(
label_dims
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
"The tensor rank of Input(Label) must be 3."
));
auto
input0
=
x_dims
[
0
];
auto
input2
=
x_dims
[
2
];
auto
input3
=
x_dims
[
3
];
auto
label0
=
label_dims
[
0
];
auto
label1
=
label_dims
[
1
];
auto
label2
=
label_dims
[
2
];
PADDLE_ENFORCE_EQ
(
input0
==
label0
&&
input2
==
label1
&&
input3
==
label2
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(X) tensor shape should "
"match to Input(Label) tensor "
"shape."
));
if
(
reduction
==
"none"
)
{
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
x_dims
[
2
],
x_dims
[
3
]});
}
else
{
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
}
}
ctx
->
SetOutputDim
(
"Total_weight"
,
{
1
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
}
};
class
NLLLossOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor, default Tensor<float>) A tensor whose last dimension "
"size is equal to the number of classes. It is expected to "
"contain log-probabilities of each class. "
"The X tensor's shape has to be either [batch_size, C] or"
"[batch_size, C, dim1, ..., dimK] in with K >= 1 in the case "
" K-dimensional loss."
);
AddInput
(
"Label"
,
"(Tensor, default Tensor<int64_t>) A tensor which represents the "
"the ground truth. It contains the class index in the range "
"[0, C-1] where C = number of classes. The Lable tensor's "
"shape has to be (batch_size), or "
"(batch_size, dim1, ..., dimK) "
"with K >= 1 in the case K-dimensional loss."
);
AddInput
(
"Weight"
,
"(Tensor, optional) A tensor should be a 1D tensor assigning "
"weight to each of the classes. It's shape must be [C], where "
"C is the class number."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor, default Tensor<float>) A tensor that represents the "
"NLL loss."
);
AddOutput
(
"Total_weight"
,
"(Tensor, default Tensor<float>) A tensor saves the total"
"weight value in the forward process."
);
AddAttr
<
int64_t
>
(
"ignore_index"
,
"(int64_t, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient."
)
.
SetDefault
(
-
100
);
AddAttr
<
std
::
string
>
(
"reduction"
,
"(string, default mean), Specifies the reduction to apply"
"to the output. The options include
\"
none
\"
,
\"
mean
\"
,"
"
\"
sum
\"
."
)
.
SetDefault
(
"mean"
);
AddComment
(
R"DOC(
NLL(Negative Log Likelihood) Loss Operator.
This operator computes the NLL loss according to the inputs.
The loss can be described as:
$Out[i] = -X[Label[i]]*Weight[Label[i]]$
It can also be used for higher dimension inputs, such as 2D images, by
providing an input of shape (batch_size, C, d1, d2, ..., dK), with
K >= 1, where K is the number of dimensions, and a Label of
appropriate shape. In the case of images, it computes NLL loss
per-pixel.
)DOC"
);
}
};
class
NLLLossGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Label"
),
"Input"
,
"Label"
,
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"NLLLoss"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"NLLLoss"
);
auto
reduction
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"reduction"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
label_dims
=
ctx
->
GetInputDim
(
"Label"
);
auto
dout_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
bool
contain_unknown_dim
=
framework
::
contain_unknown_dim
(
x_dims
)
||
framework
::
contain_unknown_dim
(
dout_dims
);
bool
check
=
ctx
->
IsRuntime
()
||
!
contain_unknown_dim
;
if
(
check
)
{
auto
batch_size
=
x_dims
[
0
];
if
(
x_dims
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
dout_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of Input(Out@Grad) must be 1"
));
if
(
reduction
==
"none"
)
{
PADDLE_ENFORCE_EQ
(
dout_dims
[
0
],
batch_size
,
platform
::
errors
::
InvalidArgument
(
"The unreduced size ofInput(Out@Grad) must be the "
"same as batch_size."
));
}
else
{
PADDLE_ENFORCE_EQ
(
dout_dims
[
0
],
1
,
platform
::
errors
::
InvalidArgument
(
"The reduced size of Input(Out@Grad) must be 1"
));
}
}
else
if
(
x_dims
.
size
()
==
4
)
{
if
(
reduction
==
"none"
)
{
PADDLE_ENFORCE_EQ
(
dout_dims
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of Input(Out@Grad) must be 3,But got [%s]."
,
dout_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
dout_dims
[
0
]
==
label_dims
[
0
]
&&
dout_dims
[
1
]
==
label_dims
[
1
]
&&
dout_dims
[
2
]
==
label_dims
[
2
],
true
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of Input(Out@Grad) must be match "
"to Input(Label) dimensions."
));
}
else
{
PADDLE_ENFORCE_EQ
(
dout_dims
[
0
],
1
,
platform
::
errors
::
InvalidArgument
(
"The reduced size of Input(Out@Grad) must be 1"
));
}
}
}
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
x_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
}
};
template
<
typename
T
>
class
NLLLossGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"nll_loss_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Label"
,
this
->
Input
(
"Label"
));
op
->
SetInput
(
"Total_weight"
,
this
->
Output
(
"Total_weight"
));
if
(
this
->
HasInput
(
"Weight"
))
{
op
->
SetInput
(
"Weight"
,
this
->
Input
(
"Weight"
));
}
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
nll_loss
,
ops
::
NLLLossOp
,
ops
::
NLLLossOpMaker
,
ops
::
NLLLossGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
NLLLossGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
nll_loss_grad
,
ops
::
NLLLossGradOp
);
REGISTER_OP_CPU_KERNEL
(
nll_loss
,
ops
::
NLLLossOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
NLLLossOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
nll_loss_grad
,
ops
::
NLLLossGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
NLLLossGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/nll_loss_op.cu
0 → 100644
浏览文件 @
e21b3c27
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/nll_loss_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
static
const
int
NTHREADS
=
32
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template
<
typename
T
>
__global__
void
GPUNLLLossForward1D_no_reduce
(
T
*
out_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
ignore_index
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
batch_size
)
{
const
int64_t
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
i
]
=
0
;
continue
;
}
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
out_data
[
i
]
=
-
x_data
[
i
*
n_classes
+
cur_label
]
*
cur_weight
;
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossForward1D_with_reduce
(
T
*
out_data
,
T
*
total_weight_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
size_average
,
const
int64_t
ignore_index
)
{
__shared__
T
sharedInputs
[
NTHREADS
],
sharedWeights
[
NTHREADS
];
sharedInputs
[
threadIdx
.
x
]
=
0
;
sharedWeights
[
threadIdx
.
x
]
=
0
;
int
i
;
for
(
i
=
threadIdx
.
x
;
i
<
batch_size
;
i
+=
NTHREADS
)
{
const
auto
cur_label
=
label_data
[
i
];
if
(
cur_label
!=
ignore_index
)
{
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
sharedInputs
[
threadIdx
.
x
]
-=
x_data
[
i
*
n_classes
+
cur_label
]
*
cur_weight
;
sharedWeights
[
threadIdx
.
x
]
+=
cur_weight
;
}
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
*
out_data
=
*
total_weight_data
=
0
;
T
output_val
=
0
;
T
total_weight_val
=
0
;
for
(
i
=
0
;
i
<
NTHREADS
;
++
i
)
{
output_val
+=
sharedInputs
[
i
];
total_weight_val
+=
sharedWeights
[
i
];
}
*
total_weight_data
=
total_weight_val
;
*
out_data
=
output_val
;
if
(
size_average
&&
*
total_weight_data
!=
0
)
{
*
out_data
=
output_val
/
total_weight_val
;
}
}
}
// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
// (1, 2), (3, 4), (5, 6), (7, 8), then the return in threadVals for thread 0
// is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20)
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template
<
typename
T
,
typename
ReduceOp
,
int
N
>
__device__
void
reduceNValuesInBlock
(
T
*
smem
,
T
threadVals
[
N
],
const
unsigned
int
numVals
,
ReduceOp
reduceOp
,
T
init
)
{
if
(
numVals
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
threadVals
[
i
]
=
init
;
}
return
;
}
// We store each of the N values contiguously, so if N = 2, all values for
// the first threadVal for each thread in the block are stored followed by
// all of the values for the second threadVal for each thread in the block
if
(
threadIdx
.
x
<
numVals
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
smem
[
i
*
numVals
+
threadIdx
.
x
]
=
threadVals
[
i
];
}
}
__syncthreads
();
// Number of lanes in the final reduction --> this is used to determine
// where to put the outputs of each of the n things we are reducing. If
// nLP = 32, then we have the 32 outputs for the first threadVal,
// followed by the 32 outputs for the second threadVal, etc.
const
unsigned
int
numLanesParticipating
=
min
(
numVals
,
warpSize
);
if
(
numVals
>
warpSize
&&
((
threadIdx
.
x
/
warpSize
)
==
0
))
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
threadVals
[
i
]
=
threadIdx
.
x
<
numVals
?
threadVals
[
i
]
:
init
;
}
for
(
int
i
=
warpSize
+
threadIdx
.
x
;
i
<
numVals
;
i
+=
warpSize
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
threadVals
[
j
]
=
reduceOp
(
threadVals
[
j
],
smem
[
j
*
numVals
+
i
]);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
smem
[
i
*
numLanesParticipating
+
threadIdx
.
x
]
=
threadVals
[
i
];
}
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
numLanesParticipating
==
32
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
#pragma unroll
for
(
int
j
=
1
;
j
<
32
;
++
j
)
{
threadVals
[
i
]
=
reduceOp
(
threadVals
[
i
],
smem
[
i
*
32
+
j
]);
}
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
for
(
int
j
=
1
;
j
<
numLanesParticipating
;
++
j
)
{
threadVals
[
i
]
=
reduceOp
(
threadVals
[
i
],
smem
[
i
*
numVals
+
j
]);
}
}
}
}
}
// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
// return the reduced value
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
template
<
typename
T
,
typename
ReduceOp
>
__device__
T
reduceBlock
(
T
*
smem
,
const
unsigned
int
numVals
,
T
threadVal
,
ReduceOp
reduceOp
,
T
init
)
{
reduceNValuesInBlock
<
T
,
ReduceOp
,
1
>
(
smem
,
&
threadVal
,
numVals
,
reduceOp
,
init
);
return
threadVal
;
}
template
<
typename
T
>
__global__
void
GPUNLLLossForward2D_no_reduce
(
T
*
out_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
in_dim2
,
const
int64_t
in_dim3
,
const
int64_t
ignore_index
)
{
const
int64_t
map_size
=
in_dim2
*
in_dim3
;
const
int64_t
sample_size
=
n_classes
*
map_size
;
const
int64_t
out_numel
=
batch_size
*
map_size
;
CUDA_1D_KERNEL_LOOP
(
i
,
out_numel
)
{
const
int64_t
b
=
i
%
batch_size
;
const
int64_t
h
=
(
i
/
batch_size
)
%
in_dim2
;
const
int64_t
w
=
(
i
/
(
batch_size
*
in_dim2
))
%
in_dim3
;
const
int64_t
index
=
b
*
map_size
+
h
*
in_dim3
+
w
;
const
int64_t
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
index
]
=
0
;
continue
;
}
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
out_data
[
index
]
=
-
x_data
[
b
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
]
*
cur_weight
;
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossForward2D_with_reduce
(
T
*
out_data
,
T
*
total_weight_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
map_nelem
,
const
int64_t
blocks_per_sample
,
const
int64_t
ignore_index
)
{
__shared__
T
partial_sums
[
kNumCUDAThreads
];
int64_t
i
;
T
input_sum
=
0
;
T
acc_weight
=
0
;
*
out_data
=
0
;
*
total_weight_data
=
0
;
int64_t
sample
=
blockIdx
.
x
/
blocks_per_sample
;
int64_t
toffset
=
sample
*
map_nelem
;
int64_t
ioffset
=
sample
*
map_nelem
*
n_classes
;
int64_t
step
=
blockDim
.
x
*
blocks_per_sample
;
for
(
i
=
(
blockIdx
.
x
%
blocks_per_sample
)
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
map_nelem
;
i
+=
step
)
{
const
int64_t
cur_label
=
label_data
[
toffset
+
i
];
if
(
cur_label
!=
ignore_index
)
{
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
input_sum
-=
x_data
[
ioffset
+
i
+
map_nelem
*
cur_label
]
*
cur_weight
;
acc_weight
+=
cur_weight
;
}
}
input_sum
=
reduceBlock
(
partial_sums
,
blockDim
.
x
,
input_sum
,
thrust
::
plus
<
T
>
(),
(
T
)
0
);
__syncthreads
();
acc_weight
=
reduceBlock
(
partial_sums
,
blockDim
.
x
,
acc_weight
,
thrust
::
plus
<
T
>
(),
(
T
)
0
);
if
(
threadIdx
.
x
==
0
)
{
paddle
::
platform
::
CudaAtomicAdd
(
total_weight_data
,
acc_weight
);
paddle
::
platform
::
CudaAtomicAdd
(
out_data
,
input_sum
);
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossForward2D_size_average
(
T
*
out_data
,
T
*
total_weight_data
)
{
if
(
*
total_weight_data
!=
0
)
{
*
out_data
/=
*
total_weight_data
;
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossBackward1D_no_reduce
(
T
*
dx_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
dout_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
ignore_index
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
batch_size
)
{
const
int64_t
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
dx_data
[
i
*
n_classes
+
cur_label
]
=
-
dout_data
[
i
]
*
cur_weight
;
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossBackward1D_with_reduce
(
T
*
dx_data
,
const
T
*
total_weight_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
dout_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
size_average
,
const
int64_t
ignore_index
)
{
if
(
*
total_weight_data
<=
0
)
{
return
;
}
int
i
;
const
T
norm
=
size_average
?
(
T
)(
1
/
*
total_weight_data
)
:
(
T
)
1
;
for
(
i
=
threadIdx
.
x
;
i
<
batch_size
;
i
+=
NTHREADS
)
{
const
int64_t
cur_label
=
label_data
[
i
];
if
(
cur_label
!=
ignore_index
)
{
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
dx_data
[
i
*
n_classes
+
cur_label
]
=
-
cur_weight
*
dout_data
[
0
]
*
norm
;
}
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossBackward2D_no_reduce
(
T
*
dx_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
dout_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
in_dim2
,
const
int64_t
in_dim3
,
const
int64_t
ignore_index
)
{
const
int64_t
map_size
=
in_dim2
*
in_dim3
;
const
int64_t
sample_size
=
n_classes
*
map_size
;
const
int64_t
out_numel
=
batch_size
*
map_size
;
CUDA_1D_KERNEL_LOOP
(
i
,
out_numel
)
{
const
int64_t
b
=
i
%
batch_size
;
const
int64_t
h
=
(
i
/
batch_size
)
%
in_dim2
;
const
int64_t
w
=
(
i
/
(
batch_size
*
in_dim2
))
%
in_dim3
;
const
int64_t
index
=
b
*
map_size
+
h
*
in_dim3
+
w
;
const
int64_t
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
T
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
;
dx_data
[
b
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
]
=
-
dout_data
[
index
]
*
cur_weight
;
}
}
template
<
typename
T
>
__global__
void
GPUNLLLossBackward2D_with_reduce
(
T
*
dx_data
,
const
T
*
total_weight_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
dout_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
map_nelem
,
const
int64_t
blocks_per_sample
,
const
int64_t
size_average
,
const
int64_t
ignore_index
)
{
if
(
*
total_weight_data
<=
0
)
{
return
;
}
int64_t
i
;
const
T
norm
=
size_average
?
(
T
)(
1
/
*
total_weight_data
)
:
(
T
)
1
;
int
sample
=
blockIdx
.
x
/
blocks_per_sample
;
int
step
=
blockDim
.
x
*
blocks_per_sample
;
int
toffset
=
sample
*
map_nelem
;
int
ioffset
=
sample
*
map_nelem
*
n_classes
;
for
(
i
=
(
blockIdx
.
x
%
blocks_per_sample
)
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
map_nelem
;
i
+=
step
)
{
const
int64_t
cur_label
=
label_data
[
toffset
+
i
];
if
(
cur_label
!=
ignore_index
)
{
dx_data
[
ioffset
+
i
+
map_nelem
*
cur_label
]
=
-
(
weight_data
?
weight_data
[
cur_label
]
:
(
T
)
1
)
*
norm
*
dout_data
[
0
];
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
NLLLossCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
total_weight
=
ctx
.
Output
<
Tensor
>
(
"Total_weight"
);
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
auto
x_data
=
x
->
data
<
T
>
();
auto
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
total_weight_data
=
total_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
cudaMemset
(
total_weight_data
,
0
,
sizeof
(
T
));
auto
x_dims
=
x
->
dims
();
auto
batch_size
=
x_dims
[
0
];
auto
n_classes
=
x_dims
[
1
];
int64_t
size_average
=
(
int64_t
)(
reduction
==
"mean"
);
if
(
x_dims
.
size
()
==
2
)
{
int
blocks
=
NumBlocks
(
batch_size
);
int
threads
=
kNumCUDAThreads
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
if
(
reduction
==
"none"
)
{
GPUNLLLossForward1D_no_reduce
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
ignore_index
);
}
else
{
GPUNLLLossForward1D_with_reduce
<
T
><<<
1
,
NTHREADS
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
total_weight_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
size_average
,
ignore_index
);
}
}
else
if
(
x_dims
.
size
()
==
4
)
{
const
auto
in_dim2
=
x_dims
[
2
];
const
auto
in_dim3
=
x_dims
[
3
];
const
auto
map_size
=
in_dim2
*
in_dim3
;
const
auto
out_numel
=
batch_size
*
in_dim2
*
in_dim3
;
int
blocks
=
NumBlocks
(
out_numel
);
int
threads
=
kNumCUDAThreads
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
if
(
reduction
==
"none"
)
{
GPUNLLLossForward2D_no_reduce
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
in_dim2
,
in_dim3
,
ignore_index
);
}
else
{
int
blocks_per_sample
=
NumBlocks
(
map_size
)
/
128
;
blocks_per_sample
=
(
blocks_per_sample
==
0
)
?
1
:
blocks_per_sample
;
int
total_blocks
=
blocks_per_sample
*
batch_size
;
GPUNLLLossForward2D_with_reduce
<
T
><<<
total_blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
total_weight_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
map_size
,
blocks_per_sample
,
ignore_index
);
if
(
size_average
)
{
GPUNLLLossForward2D_size_average
<
T
><<<
1
,
1
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
total_weight_data
);
}
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
NLLLossGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
total_weight
=
ctx
.
Input
<
Tensor
>
(
"Total_weight"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dout_data
=
dout
->
data
<
T
>
();
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
auto
total_weight_data
=
total_weight
->
data
<
T
>
();
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
cudaMemset
(
dx_data
,
0
,
dx
->
numel
()
*
sizeof
(
T
));
int64_t
size_average
=
(
int64_t
)(
reduction
==
"mean"
);
auto
x_dims
=
x
->
dims
();
auto
batch_size
=
x_dims
[
0
];
auto
n_classes
=
x_dims
[
1
];
if
(
x_dims
.
size
()
==
2
)
{
int
blocks
=
NumBlocks
(
batch_size
);
int
threads
=
kNumCUDAThreads
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
if
(
reduction
==
"none"
)
{
GPUNLLLossBackward1D_no_reduce
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
dx_data
,
label_data
,
weight_data
,
dout_data
,
batch_size
,
n_classes
,
ignore_index
);
}
else
{
GPUNLLLossBackward1D_with_reduce
<
T
><<<
1
,
NTHREADS
,
0
,
dev_ctx
.
stream
()
>>>
(
dx_data
,
total_weight_data
,
label_data
,
weight_data
,
dout_data
,
batch_size
,
n_classes
,
size_average
,
ignore_index
);
}
}
else
if
(
x_dims
.
size
()
==
4
)
{
const
auto
in_dim2
=
x_dims
[
2
];
const
auto
in_dim3
=
x_dims
[
3
];
const
auto
map_size
=
in_dim2
*
in_dim3
;
const
auto
out_numel
=
batch_size
*
in_dim2
*
in_dim3
;
int
blocks
=
NumBlocks
(
out_numel
);
int
threads
=
kNumCUDAThreads
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
if
(
reduction
==
"none"
)
{
GPUNLLLossBackward2D_no_reduce
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
dx_data
,
label_data
,
weight_data
,
dout_data
,
batch_size
,
n_classes
,
in_dim2
,
in_dim3
,
ignore_index
);
}
else
{
int
blocks_per_sample
=
NumBlocks
(
map_size
)
/
128
;
blocks_per_sample
=
(
blocks_per_sample
==
0
)
?
1
:
blocks_per_sample
;
int
total_blocks
=
blocks_per_sample
*
batch_size
;
GPUNLLLossBackward2D_with_reduce
<
T
><<<
total_blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
dx_data
,
total_weight_data
,
label_data
,
weight_data
,
dout_data
,
batch_size
,
n_classes
,
map_size
,
blocks_per_sample
,
size_average
,
ignore_index
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
nll_loss
,
ops
::
NLLLossCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
NLLLossCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
nll_loss_grad
,
ops
::
NLLLossGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
NLLLossGradCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/nll_loss_op.h
0 → 100644
浏览文件 @
e21b3c27
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
static
void
nll_loss_1D
(
T
*
out_data
,
T
*
total_weight_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
std
::
string
reduction
,
const
int64_t
ignore_index
)
{
if
(
reduction
==
"none"
)
{
for
(
int64_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
const
auto
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
i
]
=
0
;
continue
;
}
PADDLE_ENFORCE_EQ
(
cur_label
>=
0
&&
cur_label
<
n_classes
,
true
,
platform
::
errors
::
InvalidArgument
(
"label should not be out of bounds."
));
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
out_data
[
i
]
=
-
x_data
[
i
*
n_classes
+
cur_label
]
*
cur_weight
;
}
return
;
}
T
output_val
=
0
;
T
total_weight_val
=
0
;
for
(
int64_t
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
auto
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
i
]
=
0
;
continue
;
}
PADDLE_ENFORCE_EQ
(
cur_label
>=
0
&&
cur_label
<
n_classes
,
true
,
platform
::
errors
::
InvalidArgument
(
"label should not be out of bounds."
));
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
total_weight_val
+=
cur_weight
;
output_val
-=
x_data
[
i
*
n_classes
+
cur_label
]
*
cur_weight
;
}
if
(
reduction
==
"mean"
&&
total_weight_val
!=
0
)
{
output_val
/=
total_weight_val
;
}
*
out_data
=
output_val
;
*
total_weight_data
=
total_weight_val
;
}
template
<
typename
T
>
static
void
nll_loss_2D
(
T
*
out_data
,
T
*
total_weight_data
,
const
T
*
x_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
in_dim2
,
const
int64_t
in_dim3
,
const
std
::
string
reduction
,
const
int64_t
ignore_index
)
{
const
auto
map_size
=
in_dim2
*
in_dim3
;
const
auto
sample_size
=
n_classes
*
map_size
;
if
(
reduction
==
"none"
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
h
=
0
;
h
<
in_dim2
;
h
++
)
{
for
(
int
w
=
0
;
w
<
in_dim3
;
w
++
)
{
const
auto
index
=
i
*
map_size
+
h
*
in_dim3
+
w
;
const
auto
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
index
]
=
0
;
continue
;
}
PADDLE_ENFORCE_EQ
(
cur_label
>=
0
&&
cur_label
<
n_classes
,
true
,
platform
::
errors
::
InvalidArgument
(
"label should nor be out of bounds."
));
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
out_data
[
index
]
=
-
x_data
[
i
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
]
*
cur_weight
;
}
}
}
return
;
}
T
output_val
=
0
;
T
total_weight_val
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
h
=
0
;
h
<
in_dim2
;
h
++
)
{
for
(
int
w
=
0
;
w
<
in_dim3
;
w
++
)
{
const
auto
index
=
i
*
map_size
+
h
*
in_dim3
+
w
;
const
auto
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
out_data
[
index
]
=
0
;
continue
;
}
PADDLE_ENFORCE_EQ
(
cur_label
>=
0
&&
cur_label
<
n_classes
,
true
,
platform
::
errors
::
InvalidArgument
(
"label should nor be out of bounds."
));
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
total_weight_val
+=
cur_weight
;
output_val
-=
x_data
[
i
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
]
*
cur_weight
;
}
}
}
if
(
reduction
==
"mean"
&&
total_weight_val
!=
0
)
{
output_val
/=
total_weight_val
;
}
*
out_data
=
output_val
;
*
total_weight_data
=
total_weight_val
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
NLLLossOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
total_weight
=
ctx
.
Output
<
Tensor
>
(
"Total_weight"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
x_data
=
x
->
data
<
T
>
();
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
auto
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
total_weight_data
=
total_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
*
total_weight_data
=
0
;
auto
x_dims
=
x
->
dims
();
const
auto
batch_size
=
x_dims
[
0
];
const
auto
n_classes
=
x_dims
[
1
];
if
(
x_dims
.
size
()
==
2
)
{
nll_loss_1D
<
T
>
(
out_data
,
total_weight_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
reduction
,
ignore_index
);
}
else
if
(
x_dims
.
size
()
==
4
)
{
const
auto
in_dim2
=
x_dims
[
2
];
const
auto
in_dim3
=
x_dims
[
3
];
nll_loss_2D
<
T
>
(
out_data
,
total_weight_data
,
x_data
,
label_data
,
weight_data
,
batch_size
,
n_classes
,
in_dim2
,
in_dim3
,
reduction
,
ignore_index
);
}
}
};
template
<
typename
T
>
static
void
nll_loss_grad_1D
(
T
*
dx_data
,
const
T
*
dout_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
total_weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
std
::
string
reduction
,
const
int64_t
ignore_index
)
{
if
(
reduction
==
"none"
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
auto
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
dx_data
[
i
*
n_classes
+
cur_label
]
=
-
dout_data
[
i
]
*
cur_weight
;
}
return
;
}
const
T
dout_val
=
*
dout_data
;
const
T
total_weight_val
=
*
total_weight_data
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
auto
cur_label
=
label_data
[
i
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
dx_data
[
i
*
n_classes
+
cur_label
]
=
-
dout_val
*
cur_weight
;
if
(
reduction
==
"mean"
)
{
dx_data
[
i
*
n_classes
+
cur_label
]
/=
total_weight_val
;
}
}
}
template
<
typename
T
>
static
void
nll_loss_grad_2D
(
T
*
dx_data
,
const
T
*
dout_data
,
const
int64_t
*
label_data
,
const
T
*
weight_data
,
const
T
*
total_weight_data
,
const
int64_t
batch_size
,
const
int64_t
n_classes
,
const
int64_t
in_dim2
,
const
int64_t
in_dim3
,
const
std
::
string
reduction
,
const
int64_t
ignore_index
)
{
const
auto
map_size
=
in_dim2
*
in_dim3
;
const
auto
sample_size
=
n_classes
*
map_size
;
if
(
reduction
==
"none"
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
h
=
0
;
h
<
in_dim2
;
h
++
)
{
for
(
int
w
=
0
;
w
<
in_dim3
;
w
++
)
{
const
auto
index
=
i
*
map_size
+
h
*
in_dim3
+
w
;
const
auto
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
dx_data
[
i
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
]
=
-
cur_weight
*
dout_data
[
index
];
}
}
}
return
;
}
const
T
dout_val
=
*
dout_data
;
const
T
total_weight_val
=
*
total_weight_data
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
h
=
0
;
h
<
in_dim2
;
h
++
)
{
for
(
int
w
=
0
;
w
<
in_dim3
;
w
++
)
{
const
auto
index
=
i
*
map_size
+
h
*
in_dim3
+
w
;
const
auto
cur_label
=
label_data
[
index
];
if
(
cur_label
==
ignore_index
)
{
continue
;
}
const
auto
cur_weight
=
weight_data
?
weight_data
[
cur_label
]
:
static_cast
<
T
>
(
1
);
const
auto
dx_index
=
i
*
sample_size
+
cur_label
*
map_size
+
h
*
in_dim3
+
w
;
dx_data
[
dx_index
]
=
-
dout_val
*
cur_weight
;
if
(
reduction
==
"mean"
)
{
dx_data
[
dx_index
]
/=
total_weight_val
;
}
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
NLLLossGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
total_weight
=
ctx
.
Input
<
Tensor
>
(
"Total_weight"
);
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
auto
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dout_data
=
dout
->
data
<
T
>
();
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
auto
total_weight_data
=
total_weight
->
data
<
T
>
();
memset
(
dx_data
,
0
,
dx
->
numel
()
*
sizeof
(
T
));
const
auto
x_dims
=
x
->
dims
();
const
auto
batch_size
=
x_dims
[
0
];
const
auto
n_classes
=
x_dims
[
1
];
if
(
x_dims
.
size
()
==
2
)
{
nll_loss_grad_1D
(
dx_data
,
dout_data
,
label_data
,
weight_data
,
total_weight_data
,
batch_size
,
n_classes
,
reduction
,
ignore_index
);
}
else
if
(
x_dims
.
size
()
==
4
)
{
const
auto
in_dim2
=
x_dims
[
2
];
const
auto
in_dim3
=
x_dims
[
3
];
nll_loss_grad_2D
(
dx_data
,
dout_data
,
label_data
,
weight_data
,
total_weight_data
,
batch_size
,
n_classes
,
in_dim2
,
in_dim3
,
reduction
,
ignore_index
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_nll_loss.py
0 → 100644
浏览文件 @
e21b3c27
此差异已折叠。
点击以展开。
python/paddle/nn/__init__.py
浏览文件 @
e21b3c27
...
...
@@ -65,7 +65,7 @@ from .layer.loss import L1Loss #DEFINE_ALIAS
from
.layer
import
loss
#DEFINE_ALIAS
from
.layer
import
conv
#DEFINE_ALIAS
from
.layer.conv
import
Conv2D
,
Conv2DTranspose
,
Conv3D
,
Conv3DTranspose
#DEFINE_ALIAS
# from .layer.loss import NLLLoss
#DEFINE_ALIAS
from
.layer.loss
import
NLLLoss
#DEFINE_ALIAS
from
.layer.loss
import
BCELoss
#DEFINE_ALIAS
# from .layer.learning_rate import CosineDecay #DEFINE_ALIAS
# from .layer.learning_rate import ExponentialDecay #DEFINE_ALIAS
...
...
python/paddle/nn/layer/loss.py
浏览文件 @
e21b3c27
...
...
@@ -19,7 +19,7 @@ __all__ = [
'CrossEntropyLoss'
,
# 'MSELoss',
'L1Loss'
,
#
'NLLLoss',
'NLLLoss'
,
'BCELoss'
]
...
...
@@ -329,3 +329,145 @@ class BCELoss(fluid.dygraph.Layer):
return
fluid
.
layers
.
reduce_mean
(
out
)
else
:
return
out
class
NLLLoss
(
fluid
.
dygraph
.
Layer
):
"""
This op accepts input and target label and returns negative log likelihood
cross error. It is useful to train a classification problem with C classes.
The input for the loss is epected to contain log-probabilities of
each classes. It hs to be a Tensor of size either (batch_size, C) or
(batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
The label for the loss should be a class index in the range [0, C-1]
where C is the number of classes. If ignore_index is specified, the
specified target value does not contribute to the input gradient.
If the optional argument `weight` is provided, it should be a 1D Tensor
assigning weight to each of the classed. This is particularly useful
when you have an unbalanced training set.
The loss is calculated as follows.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^
\\
top, \quad
l_n = - w_{y_n} x_{n,y_n}, \quad
w_{c} =
\\
text{weight}[c] \cdot \mathbb{1}\{c
\\
not=
\\
text{ignore
\\
_index}\},
where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then
.. math::
\ell(x, y) =
\\
begin{cases}
\\
sum_{n=1}^N
\\
frac{1}{
\\
sum_{n=1}^N w_{y_n}} l_n, &
\\
text{if reduction} =
\\
text{'mean';}
\\\\
\\
sum_{n=1}^N l_n, &
\\
text{if reduction} =
\\
text{'sum'.}
\\
end{cases}
Parameters:
input (Variable): Input tensor, the data type is float32, float64.
label (Variable): Label tensor, the data type is int64_t.
weight (Variable, optional): Weight tensor, a manual rescaling weight given
to each class. If given, it has to be a Tensor of size `C`. Otherwise,
it treated as if having all ones. the data type is
float32, float64, Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient.
Returns:
The tensor variable storing the nll_loss.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
input_np = np.random.random(size=(10, 10)).astype(np.float32)
label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[10, 10], dtype='float32')
label = fluid.data(name='label', shape=[10], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
static_result = exe.run(
prog,
feed={"input": input_np,
"label": label_np},
fetch_list=[res])
print(static_result)
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_np)
label = dg.to_variable(label_np)
output = nll_loss(input, label)
print(output.numpy())
"""
def
__init__
(
self
,
weight
=
None
,
reduction
=
'mean'
,
ignore_index
=-
100
):
super
(
NLLLoss
,
self
).
__init__
()
self
.
weight
=
weight
self
.
reduction
=
reduction
self
.
ignore_index
=
ignore_index
def
forward
(
self
,
input
,
label
):
dtype
=
self
.
_helper
.
input_dtype
(
input
)
fluid
.
data_feeder
.
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'nll_loss'
)
fluid
.
data_feeder
.
check_variable_and_dtype
(
label
,
'label'
,
[
'int64'
],
'nll_loss'
)
if
self
.
reduction
not
in
[
'sum'
,
'mean'
,
'none'
]:
raise
ValueError
(
"The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed."
%
self
.
reduction
)
x_shape
=
list
(
input
.
shape
)
n
=
x_shape
[
0
]
c
=
x_shape
[
1
]
x_dims
=
len
(
x_shape
)
if
x_dims
<
2
:
raise
ValueError
(
'Expected 2 or more dimensions (got {})'
.
format
(
x_dims
))
if
x_dims
!=
2
and
x_dims
!=
4
:
input
=
fluid
.
layers
.
reshape
(
input
,
shape
=
[
n
,
c
,
1
,
-
1
])
label
=
fluid
.
layers
.
reshape
(
label
,
shape
=
[
n
,
1
,
-
1
])
out_shape
=
[
n
]
+
x_shape
[
2
:]
inputs
=
{
'X'
:
input
,
'Label'
:
label
}
attrs
=
{
'reduction'
:
self
.
reduction
,
'ignore_index'
:
self
.
ignore_index
}
if
self
.
weight
is
not
None
:
if
isinstance
(
self
.
weight
,
fluid
.
framework
.
Variable
):
inputs
[
'Weight'
]
=
self
.
weight
out
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
total_weight
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
outputs
=
{
'Out'
:
out
,
'Total_weight'
:
total_weight
}
self
.
_helper
.
append_op
(
type
=
'nll_loss'
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
if
x_dims
!=
2
and
x_dims
!=
4
and
self
.
reduction
==
'none'
:
out
=
fluid
.
layers
.
reshape
(
out
,
shape
=
out_shape
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录