提交 e760df4b 编写于 作者: A andyjpaddle

add sar dict

上级 89c9f363
...@@ -9,11 +9,14 @@ from paddle import nn ...@@ -9,11 +9,14 @@ from paddle import nn
class SARLoss(nn.Layer): class SARLoss(nn.Layer):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(SARLoss, self).__init__() super(SARLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96) self.loss_func = paddle.nn.loss.CrossEntropyLoss(
reduction="mean", ignore_index=92)
def forward(self, predicts, batch): def forward(self, predicts, batch):
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets predict = predicts[:, :
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation -1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype(
"int64")[:, 1:] # ignore first index of target in loss calculation
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
1], predict.shape[2] 1], predict.shape[2]
assert len(label.shape) == len(list(predict.shape)) - 1, \ assert len(label.shape) == len(list(predict.shape)) - 1, \
......
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
_
`
~
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册