未验证 提交 8f659d43 编写于 作者: T Tao Luo 提交者: GitHub

Split some APIs from nn.py to loss.py (#21117)

* Split some APIs from nn.py to loss.py

test=develop

* fix test_detection unit-test

test=develop
上级 4a544762
...@@ -28,6 +28,8 @@ from . import device ...@@ -28,6 +28,8 @@ from . import device
from .device import * from .device import *
from . import math_op_patch from . import math_op_patch
from .math_op_patch import * from .math_op_patch import *
from . import loss
from .loss import *
from . import detection from . import detection
from .detection import * from .detection import *
from . import metric_op from . import metric_op
...@@ -50,6 +52,7 @@ __all__ += metric_op.__all__ ...@@ -50,6 +52,7 @@ __all__ += metric_op.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
__all__ += distributions.__all__ __all__ += distributions.__all__
__all__ += sequence_lod.__all__ __all__ += sequence_lod.__all__
__all__ += loss.__all__
__all__ += rnn.__all__ __all__ += rnn.__all__
from .rnn import * from .rnn import *
...@@ -21,6 +21,7 @@ from .layer_function_generator import generate_layer_fn ...@@ -21,6 +21,7 @@ from .layer_function_generator import generate_layer_fn
from .layer_function_generator import autodoc, templatedoc from .layer_function_generator import autodoc, templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable
from .loss import softmax_with_cross_entropy
from . import tensor from . import tensor
from . import nn from . import nn
from . import ops from . import ops
...@@ -1540,7 +1541,7 @@ def ssd_loss(location, ...@@ -1540,7 +1541,7 @@ def ssd_loss(location,
target_label = tensor.cast(x=target_label, dtype='int64') target_label = tensor.cast(x=target_label, dtype='int64')
target_label = __reshape_to_2d(target_label) target_label = __reshape_to_2d(target_label)
target_label.stop_gradient = True target_label.stop_gradient = True
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = softmax_with_cross_entropy(confidence, target_label)
# 3. Mining hard examples # 3. Mining hard examples
actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2]) actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2])
actual_shape.stop_gradient = True actual_shape.stop_gradient = True
...@@ -1594,7 +1595,7 @@ def ssd_loss(location, ...@@ -1594,7 +1595,7 @@ def ssd_loss(location,
target_label = __reshape_to_2d(target_label) target_label = __reshape_to_2d(target_label)
target_label = tensor.cast(x=target_label, dtype='int64') target_label = tensor.cast(x=target_label, dtype='int64')
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = __reshape_to_2d(target_conf_weight) target_conf_weight = __reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight conf_loss = conf_loss * target_conf_weight
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册