提交 60229c1e 编写于 作者: D Dang Qingqing

Follow comments.

test=develop
上级 2939fc9f
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
""" """
Fluid Metrics Fluid Metrics
The metrics are accomplished via Python natively.
""" """
from __future__ import print_function from __future__ import print_function
...@@ -24,6 +22,12 @@ import copy ...@@ -24,6 +22,12 @@ import copy
import warnings import warnings
import six import six
from .layer_helper import LayerHelper
from .initializer import Constant
from . import unique_name
from .framework import Program, Variable, program_guard
from . import layers
__all__ = [ __all__ = [
'MetricBase', 'MetricBase',
'CompositeMetric', 'CompositeMetric',
...@@ -598,7 +602,7 @@ class DetectionMAP(object): ...@@ -598,7 +602,7 @@ class DetectionMAP(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
exe = fluid.executor(place) exe = fluid.Executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input, map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_box, gt_difficult) gt_label, gt_box, gt_difficult)
cur_map, accum_map = map_evaluator.get_map_var() cur_map, accum_map = map_evaluator.get_map_var()
...@@ -624,9 +628,6 @@ class DetectionMAP(object): ...@@ -624,9 +628,6 @@ class DetectionMAP(object):
overlap_threshold=0.5, overlap_threshold=0.5,
evaluate_difficult=True, evaluate_difficult=True,
ap_version='integral'): ap_version='integral'):
from . import layers
from .layer_helper import LayerHelper
from .initializer import Constant
self.helper = LayerHelper('map_eval') self.helper = LayerHelper('map_eval')
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
...@@ -692,7 +693,6 @@ class DetectionMAP(object): ...@@ -692,7 +693,6 @@ class DetectionMAP(object):
shape(tuple|list): the shape of state shape(tuple|list): the shape of state
Returns: State variable Returns: State variable
""" """
from . import unique_name
state = self.helper.create_variable( state = self.helper.create_variable(
name="_".join([unique_name.generate(self.helper.name), suffix]), name="_".join([unique_name.generate(self.helper.name), suffix]),
persistable=True, persistable=True,
...@@ -717,8 +717,6 @@ class DetectionMAP(object): ...@@ -717,8 +717,6 @@ class DetectionMAP(object):
reset_program(Program|None): a single Program for reset process. reset_program(Program|None): a single Program for reset process.
If None, will create a Program. If None, will create a Program.
""" """
from .framework import Program, Variable, program_guard
from . import layers
def _clone_var_(block, var): def _clone_var_(block, var):
assert isinstance(var, Variable) assert isinstance(var, Variable)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
class TestMetricsDetectionMap(unittest.TestCase):
def test_detection_map(self):
program = fluid.Program()
with program_guard(program):
detect_res = fluid.layers.data(
name='detect_res',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
label = fluid.layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='float32')
box = fluid.layers.data(
name='bbox',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
map_eval = fluid.metrics.DetectionMAP(
detect_res, label, box, class_num=21)
cur_map, accm_map = map_eval.get_map_var()
self.assertIsNotNone(cur_map)
self.assertIsNotNone(accm_map)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册