提交 a5899ca1 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_update_cluster_doc

if(NOT DEFINED SPHINX_THEME)
set(SPHINX_THEME default)
endif()
if(NOT DEFINED SPHINX_THEME_DIR)
set(SPHINX_THEME_DIR)
endif()
# configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_build")
# Sphinx cache with pickled ReST documents
set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees")
# HTML output director
set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html")
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/templates/conf.py.en.in"
"${BINARY_BUILD_DIR_EN}/conf.py"
@ONLY)
sphinx_add_target(paddle_docs
html
${BINARY_BUILD_DIR_EN}
${SPHINX_CACHE_DIR_EN}
${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_EN})
# configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build")
# Sphinx cache with pickled ReST documents
set(SPHINX_CACHE_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_doctrees")
# HTML output directory
set(SPHINX_HTML_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/html")
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/templates/conf.py.cn.in"
"${BINARY_BUILD_DIR_CN}/conf.py"
@ONLY)
sphinx_add_target(paddle_docs_cn
html
${BINARY_BUILD_DIR_CN}
${SPHINX_CACHE_DIR_CN}
${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_CN})
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(v2)
../../CONTRIBUTING.md
\ No newline at end of file
...@@ -121,7 +121,7 @@ html_theme = 'sphinx_rtd_theme' ...@@ -121,7 +121,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static'] #html_static_path = []
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc' htmlhelp_basename = project + 'doc'
......
...@@ -121,7 +121,7 @@ html_theme = 'sphinx_rtd_theme' ...@@ -121,7 +121,7 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['@PADDLE_SOURCE_DIR@/doc_theme/static'] #html_static_path = []
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc' htmlhelp_basename = project + 'doc'
......
if(NOT DEFINED SPHINX_THEME)
set(SPHINX_THEME default)
endif()
if(NOT DEFINED SPHINX_THEME_DIR)
set(SPHINX_THEME_DIR)
endif()
# configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_build")
# Sphinx cache with pickled ReST documents
set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees")
# HTML output director
set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html")
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.en.in"
"${BINARY_BUILD_DIR_EN}/conf.py"
@ONLY)
sphinx_add_target(paddle_docs
html
${BINARY_BUILD_DIR_EN}
${SPHINX_CACHE_DIR_EN}
${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_EN})
# configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build")
# Sphinx cache with pickled ReST documents
set(SPHINX_CACHE_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_doctrees")
# HTML output directory
set(SPHINX_HTML_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/html")
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.cn.in"
"${BINARY_BUILD_DIR_CN}/conf.py"
@ONLY)
sphinx_add_target(paddle_docs_cn
html
${BINARY_BUILD_DIR_CN}
${SPHINX_CACHE_DIR_CN}
${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_CN})
../../../CONTRIBUTING.md
\ No newline at end of file
...@@ -6,3 +6,4 @@ ...@@ -6,3 +6,4 @@
contribute_to_paddle_cn.md contribute_to_paddle_cn.md
write_docs_cn.rst write_docs_cn.rst
new_layer_cn.rst
...@@ -6,3 +6,4 @@ Development ...@@ -6,3 +6,4 @@ Development
contribute_to_paddle_en.md contribute_to_paddle_en.md
write_docs_en.rst write_docs_en.rst
new_layer_en.rst
================ ==================
实现新的网络层 如何实现新的网络层
================ ==================
这份教程展示了如何在PaddlePaddle中实现一个自定义的网络层。在这里我们使用全连接层作为例子来展示实现新网络层所需要的四个步骤。 这份教程展示了如何在PaddlePaddle中实现一个自定义的网络层。在这里我们使用全连接层作为例子来展示实现新网络层所需要的四个步骤。
......
...@@ -6,5 +6,6 @@ HOW TO ...@@ -6,5 +6,6 @@ HOW TO
cmd_parameter/index_en.rst cmd_parameter/index_en.rst
cluster/index_en.rst cluster/index_en.rst
capi/index_en.rst
rnn/index_en.rst rnn/index_en.rst
optimization/gpu_profiling_en.rst optimization/gpu_profiling_en.rst
...@@ -55,7 +55,7 @@ above profilers. ...@@ -55,7 +55,7 @@ above profilers.
:code:`paddle/math/test` 目录中的 :code:`test_GpuProfiler` 就是用于展示上述分析工具的用法。 :code:`paddle/math/test` 目录中的 :code:`test_GpuProfiler` 就是用于展示上述分析工具的用法。
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:linenos: :linenos:
...@@ -83,7 +83,7 @@ program crashes when CPU version of PaddlePaddle invokes them. ...@@ -83,7 +83,7 @@ program crashes when CPU version of PaddlePaddle invokes them.
1. 加入 :code:`REGISTER_TIMER_INFO` 和 :code:`printAllStatus` 函数(如高亮部分)。 1. 加入 :code:`REGISTER_TIMER_INFO` 和 :code:`printAllStatus` 函数(如高亮部分)。
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:emphasize-lines: 8-12,14 :emphasize-lines: 8-12,14
...@@ -130,7 +130,7 @@ nvprof 工具 ...@@ -130,7 +130,7 @@ nvprof 工具
1. 将 :code:`REGISTER_GPU_PROFILER` 函数加到代码中(参考强调部分)。 1. 将 :code:`REGISTER_GPU_PROFILER` 函数加到代码中(参考强调部分)。
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:emphasize-lines: 6-7 :emphasize-lines: 6-7
......
...@@ -54,7 +54,7 @@ In this tutorial, we will focus on nvprof and nvvp. ...@@ -54,7 +54,7 @@ In this tutorial, we will focus on nvprof and nvvp.
:code:`test_GpuProfiler` from :code:`paddle/math/tests` directory will be used to evaluate :code:`test_GpuProfiler` from :code:`paddle/math/tests` directory will be used to evaluate
above profilers. above profilers.
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:linenos: :linenos:
...@@ -80,7 +80,7 @@ As a simple example, consider the following: ...@@ -80,7 +80,7 @@ As a simple example, consider the following:
1. Add :code:`REGISTER_TIMER_INFO` and :code:`printAllStatus` functions (see the emphasize-lines). 1. Add :code:`REGISTER_TIMER_INFO` and :code:`printAllStatus` functions (see the emphasize-lines).
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:emphasize-lines: 8-12,14 :emphasize-lines: 8-12,14
...@@ -127,7 +127,7 @@ To use this command line profiler **nvprof**, you can simply issue the following ...@@ -127,7 +127,7 @@ To use this command line profiler **nvprof**, you can simply issue the following
1. Add :code:`REGISTER_GPU_PROFILER` function (see the emphasize-lines). 1. Add :code:`REGISTER_GPU_PROFILER` function (see the emphasize-lines).
.. literalinclude:: ../../../paddle/math/tests/test_GpuProfiler.cpp .. literalinclude:: ../../../../paddle/math/tests/test_GpuProfiler.cpp
:language: c++ :language: c++
:lines: 137-151 :lines: 137-151
:emphasize-lines: 6-7 :emphasize-lines: 6-7
......
...@@ -47,11 +47,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -47,11 +47,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(det_dims[1], 6UL, PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(DetectRes) [N, 6]."); "The shape is of Input(DetectRes) [N, 6].");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The rank of Input(Label) must be 2, " "The rank of Input(Label) must be 2, "
"the shape is [N, 6]."); "the shape is [N, 6].");
PADDLE_ENFORCE_EQ(label_dims[1], 6UL, PADDLE_ENFORCE_EQ(label_dims[1], 6, "The shape is of Input(Label) [N, 6].");
"The shape is of Input(Label) [N, 6].");
if (ctx->HasInput("PosCount")) { if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"), PADDLE_ENFORCE(ctx->HasInput("TruePos"),
...@@ -96,6 +95,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -96,6 +95,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"instance, the offsets in first dimension are called LoD, " "instance, the offsets in first dimension are called LoD, "
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
"means there is no ground-truth data."); "means there is no ground-truth data.");
AddInput("HasState",
"(Tensor<int>) A tensor with shape [1], 0 means ignoring input "
"states, which including PosCount, TruePos, FalsePos.")
.AsDispensable();
AddInput("PosCount", AddInput("PosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the " "(Tensor) A tensor with shape [Ncls, 1], store the "
"input positive example count of each class, Ncls is the count of " "input positive example count of each class, Ncls is the count of "
...@@ -145,7 +148,7 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -145,7 +148,7 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"(float) " "(float) "
"The lower bound jaccard overlap threshold of detection output and " "The lower bound jaccard overlap threshold of detection output and "
"ground-truth data.") "ground-truth data.")
.SetDefault(.3f); .SetDefault(.5f);
AddAttr<bool>("evaluate_difficult", AddAttr<bool>("evaluate_difficult",
"(bool, default true) " "(bool, default true) "
"Switch to control whether the difficult data is evaluated.") "Switch to control whether the difficult data is evaluated.")
......
...@@ -87,7 +87,13 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -87,7 +87,13 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
std::map<int, std::vector<std::pair<T, int>>> true_pos; std::map<int, std::vector<std::pair<T, int>>> true_pos;
std::map<int, std::vector<std::pair<T, int>>> false_pos; std::map<int, std::vector<std::pair<T, int>>> false_pos;
if (in_pos_count != nullptr) { auto* has_state = ctx.Input<framework::LoDTensor>("HasState");
int state = 0;
if (has_state) {
state = has_state->data<int>()[0];
}
if (in_pos_count != nullptr && state) {
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count, GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
true_pos, false_pos); true_pos, false_pos);
} }
...@@ -202,6 +208,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -202,6 +208,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
int* pos_count_data = output_pos_count.mutable_data<int>( int* pos_count_data = output_pos_count.mutable_data<int>(
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace()); framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
T* true_pos_data = output_true_pos.mutable_data<T>( T* true_pos_data = output_true_pos.mutable_data<T>(
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace()); framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
T* false_pos_data = output_false_pos.mutable_data<T>( T* false_pos_data = output_false_pos.mutable_data<T>(
......
...@@ -12,8 +12,8 @@ make -j `nproc` copy_paddle_pybind ...@@ -12,8 +12,8 @@ make -j `nproc` copy_paddle_pybind
make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs
# check websites for broken links # check websites for broken links
linkchecker doc/en/html/index.html linkchecker doc/v2/en/html/index.html
linkchecker doc/cn/html/index.html linkchecker doc/v2/cn/html/index.html
linkchecker doc/api/en/html/index.html linkchecker doc/api/en/html/index.html
# Parse Github URL # Parse Github URL
...@@ -55,8 +55,8 @@ function deploy_docs() { ...@@ -55,8 +55,8 @@ function deploy_docs() {
set +e set +e
rm -rf ${DIR}/doc ${DIR}/doc_cn ${DIR}/api_doc rm -rf ${DIR}/doc ${DIR}/doc_cn ${DIR}/api_doc
set -e set -e
cp -r ../doc/cn/html ${DIR}/doc_cn cp -r ../doc/v2/cn/html ${DIR}/doc_cn
cp -r ../doc/en/html ${DIR}/doc cp -r ../doc/v2/en/html ${DIR}/doc
cp -r ../doc/api/en/html ${DIR}/api_doc cp -r ../doc/api/en/html ${DIR}/api_doc
git add . git add .
} }
......
...@@ -18,11 +18,13 @@ import layers ...@@ -18,11 +18,13 @@ import layers
from framework import Program, Variable, program_guard from framework import Program, Variable, program_guard
import unique_name import unique_name
from layer_helper import LayerHelper from layer_helper import LayerHelper
from initializer import Constant
__all__ = [ __all__ = [
'Accuracy', 'Accuracy',
'ChunkEvaluator', 'ChunkEvaluator',
'EditDistance', 'EditDistance',
'DetectionMAP',
] ]
...@@ -285,3 +287,120 @@ class EditDistance(Evaluator): ...@@ -285,3 +287,120 @@ class EditDistance(Evaluator):
result = executor.run( result = executor.run(
eval_program, fetch_list=[avg_distance, avg_instance_error]) eval_program, fetch_list=[avg_distance, avg_instance_error])
return np.array(result[0]), np.array(result[1]) return np.array(result[0]), np.array(result[1])
class DetectionMAP(Evaluator):
"""
Calculate the detection mean average precision (mAP).
TODO (Dang Qingqing): update the following doc.
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
Args:
input (Variable): The detection results, which is a LoDTensor with shape
[M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1].
gt_difficult (Variable): Whether this ground truth is a difficult
bounding box (bbox), which is a LoDTensor [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
overlap_threshold (float): The threshold for deciding true/false
positive, 0.5 by defalut.
evaluate_difficult (bool): Whether to consider difficult ground truth
for evaluation, True by defalut.
ap_version (string): The average precision calculation ways, it must be
'integral' or '11point'. Please check
https://sanchom.wordpress.com/tag/average-precision/ for details.
- 11point: the 11-point interpolated average precision.
- integral: the natural integral of the precision-recall curve.
Example:
exe = fluid.executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_difficult, gt_box)
cur_map, accum_map = map_evaluator.get_map_var()
fetch = [cost, cur_map, accum_map]
for epoch in PASS_NUM:
map_evaluator.reset(exe)
for data in batches:
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
In the above example:
'cur_map_v' is the mAP of current mini-batch.
'accum_map_v' is the accumulative mAP of one pass.
"""
def __init__(self,
input,
gt_label,
gt_box,
gt_difficult,
overlap_threshold=0.5,
evaluate_difficult=True,
ap_version='integral'):
super(DetectionMAP, self).__init__("map_eval")
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
# calculate mean average precision (mAP) of current mini-batch
map = layers.detection_map(
input,
label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
self.create_state(dtype='int32', shape=None, suffix='accum_pos_count')
self.create_state(dtype='float32', shape=None, suffix='accum_true_pos')
self.create_state(dtype='float32', shape=None, suffix='accum_false_pos')
self.has_state = None
var = self.helper.create_variable(
persistable=True, dtype='int32', shape=[1])
self.helper.set_variable_initializer(
var, initializer=Constant(value=int(0)))
self.has_state = var
# calculate accumulative mAP
accum_map = layers.detection_map(
input,
label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
has_state=self.has_state,
input_states=self.states,
out_states=self.states,
ap_version=ap_version)
layers.fill_constant(
shape=self.has_state.shape,
value=1,
dtype=self.has_state.dtype,
out=self.has_state)
self.cur_map = map
self.accum_map = accum_map
def get_map_var(self):
return self.cur_map, self.accum_map
def reset(self, executor, reset_program=None):
if reset_program is None:
reset_program = Program()
with program_guard(main_program=reset_program):
var = _clone_var_(reset_program.current_block(), self.has_state)
layers.fill_constant(
shape=var.shape, value=0, dtype=var.dtype, out=var)
executor.run(reset_program)
...@@ -151,23 +151,34 @@ def detection_output(loc, ...@@ -151,23 +151,34 @@ def detection_output(loc,
@autodoc() @autodoc()
def detection_map(detect_res, def detection_map(detect_res,
label, label,
pos_count=None,
true_pos=None,
false_pos=None,
overlap_threshold=0.3, overlap_threshold=0.3,
evaluate_difficult=True, evaluate_difficult=True,
ap_type='integral'): has_state=None,
input_states=None,
out_states=None,
ap_version='integral'):
helper = LayerHelper("detection_map", **locals()) helper = LayerHelper("detection_map", **locals())
map_out = helper.create_tmp_variable(dtype='float32') def __create_var(type):
accum_pos_count_out = helper.create_tmp_variable(dtype='int32') return helper.create_tmp_variable(dtype=type)
accum_true_pos_out = helper.create_tmp_variable(dtype='float32')
accum_false_pos_out = helper.create_tmp_variable(dtype='float32') map_out = __create_var('float32')
accum_pos_count_out = out_states[0] if out_states else __create_var('int32')
accum_true_pos_out = out_states[1] if out_states else __create_var(
'float32')
accum_false_pos_out = out_states[2] if out_states else __create_var(
'float32')
pos_count = input_states[0] if input_states else None
true_pos = input_states[1] if input_states else None
false_pos = input_states[2] if input_states else None
helper.append_op( helper.append_op(
type="detection_map", type="detection_map",
inputs={ inputs={
'Label': label, 'Label': label,
'DetectRes': detect_res, 'DetectRes': detect_res,
'HasState': has_state,
'PosCount': pos_count, 'PosCount': pos_count,
'TruePos': true_pos, 'TruePos': true_pos,
'FalsePos': false_pos 'FalsePos': false_pos
...@@ -181,9 +192,9 @@ def detection_map(detect_res, ...@@ -181,9 +192,9 @@ def detection_map(detect_res,
attrs={ attrs={
'overlap_threshold': overlap_threshold, 'overlap_threshold': overlap_threshold,
'evaluate_difficult': evaluate_difficult, 'evaluate_difficult': evaluate_difficult,
'ap_type': ap_type 'ap_type': ap_version
}) })
return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out return map_out
def bipartite_match(dist_matrix, def bipartite_match(dist_matrix,
......
...@@ -274,7 +274,7 @@ def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None): ...@@ -274,7 +274,7 @@ def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
use_cuda, use_cuda,
parallel=parallel, parallel=parallel,
save_dirname=save_dirname) save_dirname=save_dirname)
infer(use_cuda, save_dirname) infer(word_dict, use_cuda, save_dirname)
class TestUnderstandSentiment(unittest.TestCase): class TestUnderstandSentiment(unittest.TestCase):
......
...@@ -158,26 +158,9 @@ class TestDetectionMAP(unittest.TestCase): ...@@ -158,26 +158,9 @@ class TestDetectionMAP(unittest.TestCase):
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out = layers.detection_map( map_out = layers.detection_map(detect_res=detect_res, label=label)
detect_res=detect_res, label=label)
self.assertIsNotNone(map_out) self.assertIsNotNone(map_out)
self.assertIsNotNone(accum_pos_count_out)
self.assertIsNotNone(accum_true_pos_out)
self.assertIsNotNone(accum_false_pos_out)
self.assertEqual(map_out.shape, (1, )) self.assertEqual(map_out.shape, (1, ))
map_out, accum_pos_count_out2, accum_true_pos_out2, accum_false_pos_out2 = layers.detection_map(
detect_res=detect_res, label=label)
self.assertIsNotNone(map_out)
self.assertIsNotNone(accum_pos_count_out2)
self.assertIsNotNone(accum_true_pos_out2)
self.assertIsNotNone(accum_false_pos_out2)
self.assertEqual(map_out.shape, (1, ))
self.assertEqual(accum_pos_count_out.shape,
accum_pos_count_out2.shape)
self.assertEqual(accum_true_pos_out.shape,
accum_true_pos_out2.shape)
self.assertEqual(accum_false_pos_out.shape,
accum_false_pos_out2.shape)
print(str(program)) print(str(program))
......
...@@ -34,10 +34,12 @@ class TestDetectionMAPOp(OpTest): ...@@ -34,10 +34,12 @@ class TestDetectionMAPOp(OpTest):
'int32') 'int32')
self.true_pos = np.array(self.true_pos).astype('float32') self.true_pos = np.array(self.true_pos).astype('float32')
self.false_pos = np.array(self.false_pos).astype('float32') self.false_pos = np.array(self.false_pos).astype('float32')
self.has_state = np.array([1]).astype('int32')
self.inputs = { self.inputs = {
'Label': (self.label, self.label_lod), 'Label': (self.label, self.label_lod),
'DetectRes': (self.detect, self.detect_lod), 'DetectRes': (self.detect, self.detect_lod),
'HasState': self.has_state,
'PosCount': self.class_pos_count, 'PosCount': self.class_pos_count,
'TruePos': (self.true_pos, self.true_pos_lod), 'TruePos': (self.true_pos, self.true_pos_lod),
'FalsePos': (self.false_pos, self.false_pos_lod) 'FalsePos': (self.false_pos, self.false_pos_lod)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册