BaiXuePrincess / Paddle · forked from PaddlePaddle / Paddle

Commit a135fec1, authored May 07, 2018 by typhoonzero

    Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into gen_nccl_id_op

Parents: 0f86397d, f43b71b2

Showing 19 changed files with 435 additions and 72 deletions (+435 −72)
Changed files:

- benchmark/cluster/vgg16/vgg16_fluid.py (+22 −6)
- paddle/fluid/inference/tensorrt/CMakeLists.txt (+1 −0)
- paddle/fluid/inference/tensorrt/io_converter.cc (+57 −0)
- paddle/fluid/inference/tensorrt/io_converter.h (+66 −0)
- paddle/fluid/inference/tensorrt/test_io_converter.cc (+53 −0)
- paddle/fluid/inference/tests/book/CMakeLists.txt (+5 −0)
- paddle/fluid/inference/utils/singleton.h (+73 −0)
- paddle/fluid/operators/detail/send_recv.proto (+4 −0)
- paddle/fluid/operators/detail/sendrecvop_utils.cc (+8 −0)
- paddle/fluid/operators/detail/variable_response.cc (+21 −1)
- paddle/fluid/operators/listen_and_serv_op.cc (+3 −0)
- paddle/fluid/platform/profiler.cc (+26 −9)
- paddle/fluid/platform/profiler.h (+8 −0)
- python/paddle/fluid/__init__.py (+1 −0)
- python/paddle/fluid/framework.py (+15 −6)
- python/paddle/fluid/metrics.py (+14 −12)
- python/paddle/fluid/trainer.py (+2 −1)
- python/setup.py.in (+2 −1)
- tools/timeline.py (+54 −36)
benchmark/cluster/vgg16/vgg16_fluid.py (+22 −6)

```diff
@@ -80,6 +80,8 @@ parser.add_argument(
     type=str,
     default="",
     help="Comma-separated list of hostname:port pairs")
+parser.add_argument(
+    "--profile", action='store_true', help="If set, profile a few steps.")
 
 # Flags for defining the tf.train.Server
 parser.add_argument(
@@ -183,8 +185,8 @@ def main():
             start_time = time.time()
             num_samples = 0
             train_pass_acc.reset()
-            for batch_id, data in enumerate(train_reader()):
-                ts = time.time()
+
+            def run_step(batch_id, data):
                 img_data = np.array(
                     map(lambda x: x[0].reshape(data_shape), data)).astype(
                         "float32")
@@ -196,13 +198,27 @@ def main():
                     feed={"pixel": img_data,
                           "label": y_data},
                     fetch_list=[avg_cost, batch_acc, batch_size])
+                return loss, acc, b_size
+
+            if args.profile and args.task_index == 0:
+                # warmup.
+                for batch_id, data in enumerate(train_reader()):
+                    if batch_id > 5: break
+                    run_step(batch_id, data)
+                with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
+                    for batch_id, data in enumerate(train_reader()):
+                        if batch_id > 5: break
+                        run_step(batch_id, data)
+
+            for batch_id, data in enumerate(train_reader()):
+                ts = time.time()
+                loss, acc, b_size = run_step(batch_id, data)
                 iters += 1
                 num_samples += len(data)
                 train_pass_acc.add(value=acc, weight=b_size)
                 print(
-                    "Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
-                    "Speed = %.2f img/s " % (args.task_index, pass_id, iters,
-                                             loss, acc,
-                                             len(data) / (time.time() - ts))
+                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
+                    "Speed = %.2f img/s" % (pass_id, iters, loss, acc,
+                                            len(data) / (time.time() - ts))
                 )  # The accuracy is the accumulation of batches, but not the current batch.
```
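The warmup-then-profile pattern added here is reusable on its own. A minimal sketch follows; the `run_step` callable is a stand-in for one `exe.run(...)` step, and the `profiler.profiler('All', 'total', '/tmp/profile_vgg')` call mirrors the hunk above:

```python
import paddle.fluid.profiler as profiler

def profile_few_steps(train_reader, run_step):
    # Warm up a few batches first so startup cost does not pollute the profile.
    for batch_id, data in enumerate(train_reader()):
        if batch_id > 5:
            break
        run_step(batch_id, data)
    # Profile the next few batches; the report is written to /tmp/profile_vgg.
    with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
        for batch_id, data in enumerate(train_reader()):
            if batch_id > 5:
                break
            run_step(batch_id, data)
```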
paddle/fluid/inference/tensorrt/CMakeLists.txt (+1 −0)

```diff
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
 add_subdirectory(convert)
```
paddle/fluid/inference/tensorrt/io_converter.cc (new file 100644, +57 −0)

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using platform::is_gpu_place;
using platform::is_cpu_place;

class DefaultInputConverter : public EngineInputConverter {
 public:
  DefaultInputConverter() {}
  // NOTE out is GPU memory.
  virtual void operator()(const LoDTensor& in, void* out,
                          size_t max_size) override {
    PADDLE_ENFORCE(out != nullptr);
    PADDLE_ENFORCE_LE(in.memory_size(), max_size);
    const auto& place = in.place();
    if (is_cpu_place(place)) {
      PADDLE_ENFORCE(stream_ != nullptr);
      PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data<float>(),
                                           in.memory_size(),
                                           cudaMemcpyHostToDevice, *stream_));
    } else if (is_gpu_place(place)) {
      PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data<float>(),
                                           in.memory_size(),
                                           cudaMemcpyHostToHost, *stream_));
    } else {
      PADDLE_THROW("Unknown device for converter");
    }
    cudaStreamSynchronize(*stream_);
  }
};

REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
```
paddle/fluid/inference/tensorrt/io_converter.h (new file 100644, +66 −0)

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <unordered_map>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using framework::LoDTensor;

/*
 * Convert Input from Fluid to an Engine.
 * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
 * most cases just need to copy the data.
 */
class EngineInputConverter {
 public:
  EngineInputConverter() {}

  virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}

  void SetStream(cudaStream_t* stream) { stream_ = stream; }

  static void Run(const std::string& in_op_type, const LoDTensor& in,
                  void* out, size_t max_size, cudaStream_t* stream) {
    PADDLE_ENFORCE(stream != nullptr);
    auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
    PADDLE_ENFORCE_NOT_NULL(converter);
    converter->SetStream(stream);
    (*converter)(in, out, max_size);
  }

  virtual ~EngineInputConverter() {}

 protected:
  cudaStream_t* stream_{nullptr};
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__)     \
  struct trt_input_##in_op_type__##_converter {                          \
    trt_input_##in_op_type__##_converter() {                             \
      ::paddle::inference::Registry<EngineInputConverter>::Register<     \
          Converter__>(#in_op_type__);                                   \
    }                                                                    \
  };                                                                     \
  trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
```
paddle/fluid/inference/tensorrt/test_io_converter.cc (new file 100644, +53 −0)

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h"

#include <gtest/gtest.h>

namespace paddle {
namespace inference {
namespace tensorrt {

class EngineInputConverterTester : public ::testing::Test {
 public:
  void SetUp() override { tensor.Resize({10, 10}); }

  framework::LoDTensor tensor;
};

TEST_F(EngineInputConverterTester, DefaultCPU) {
  void* buffer;
  tensor.mutable_data<float>(platform::CPUPlace());
  ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);

  cudaStream_t stream;
  EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
                            &stream);
}

TEST_F(EngineInputConverterTester, DefaultGPU) {
  void* buffer;
  tensor.mutable_data<float>(platform::CUDAPlace());
  ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);

  cudaStream_t stream;
  EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
                            &stream);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
```
paddle/fluid/inference/tests/book/CMakeLists.txt (+5 −0)

```diff
@@ -24,6 +24,11 @@ function(inference_test TARGET_NAME)
   endforeach()
 endfunction(inference_test)
 
+####################
+# Inference tests here depend on fluid/tests/book. If users want to run
+# individual test with ctest, they need to run tests in fluid/tests/book
+# first to generate saved model.
+####################
 # This unittest is buggy!
 #inference_test(fit_a_line)
 inference_test(image_classification ARGS vgg resnet)
```
paddle/fluid/inference/utils/singleton.h (new file 100644, +73 −0)

```cpp
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <unordered_map>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {

// NOTE not thread-safe.
template <typename T>
struct Singleton {
  static T& Global() {
    static T* x = new T;
    return *x;
  }

  Singleton() = delete;
  Singleton& operator=(const Singleton&) = delete;
};

/*
 * An registor for any type.
 * NOTE not thread-safe.
 */
template <typename ItemParent>
struct Registry {
  static Registry& Global() {
    static auto* x = new Registry<ItemParent>;
    return *x;
  }

  template <typename ItemChild>
  static void Register(const std::string& name) {
    PADDLE_ENFORCE_EQ(items_.count(name), 0);
    items_[name] = new ItemChild;
  }

  static ItemParent* Lookup(const std::string& name) {
    auto it = items_.find(name);
    if (it == items_.end()) return nullptr;
    return it->second;
  }

  ~Registry() {
    for (auto& item : items_) {
      delete item.second;
    }
  }

 private:
  Registry() = default;
  static std::unordered_map<std::string, ItemParent*> items_;
};

template <typename ItemParent>
std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;

}  // namespace inference
}  // namespace paddle
```
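For readers who want the shape of the pattern without the template machinery, here is a rough Python analogue of the `Registry` plus `REGISTER_TENSORRT_INPUT_CONVERTER` pair; everything below is illustrative only and not part of the Paddle codebase:

```python
class Registry(object):
    # name -> instance, analogous to Registry<ItemParent>::items_.
    _items = {}

    @classmethod
    def register(cls, name, item_cls):
        # Mirrors PADDLE_ENFORCE_EQ(items_.count(name), 0): no duplicates.
        assert name not in cls._items
        cls._items[name] = item_cls()

    @classmethod
    def lookup(cls, name):
        # Returns None when missing, like the nullptr return in Lookup().
        return cls._items.get(name)

class DefaultInputConverter(object):
    def __call__(self, tensor_in, buf_out, max_size):
        pass  # would copy `tensor_in` into `buf_out`

# The REGISTER_* macro does this at static-initialization time in C++;
# in Python it would simply run at import time.
Registry.register("mul", DefaultInputConverter)
converter = Registry.lookup("mul")
```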
paddle/fluid/operators/detail/send_recv.proto (+4 −0)

```diff
@@ -70,6 +70,10 @@ message VariableMessage {
   bytes rows = 9;
   // Look up table block execution output variable name.
   string out_varname = 10;
+  // If true, the ps server will start profiling, the ps
+  // server stops profiling and generates a profile to /tmp/profile_ps_*
+  // when profile switches from true to false.
+  bool profile = 11;
 }
 
 message VoidMessage {}
```
paddle/fluid/operators/detail/sendrecvop_utils.cc (+8 −0)

```diff
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/detail/bytebuffer_stream.h"
 #include "paddle/fluid/operators/detail/proto_encoder_helper.h"
 #include "paddle/fluid/operators/detail/variable_response.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -48,6 +49,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   void* payload = nullptr;
   size_t payload_size = 0;
   ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  // Note: normally the profiler is enabled in 1 trainer, hence only
+  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
+  // servers the trainer's profiling state so that PS can follow the
+  // trainer.
+  if (platform::ShouldSendProfileState()) {
+    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  }
   e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
     e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
```
paddle/fluid/operators/detail/variable_response.cc (+21 −1)

```diff
@@ -20,6 +20,7 @@
 #ifdef PADDLE_WITH_CUDA
 #include <nccl.h>
 #endif
+#include "paddle/fluid/platform/profiler.h"
 
 #include "paddle/fluid/operators/detail/send_recv.pb.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
@@ -446,7 +447,26 @@ int VariableResponse::Parse(Source* source) {
         meta_.set_out_varname(temp);
         break;
       }
+      case sendrecv::VariableMessage::kProfileFieldNumber: {
+        bool profiling;
+        if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) {
+          return tag;
+        }
+        meta_.set_profile(profiling);
+        int64_t listener_id = platform::ListenerId();
+        if (listener_id <= 0) {
+          break;
+        }
+        if (profiling && !platform::IsProfileEnabled()) {
+          platform::EnableProfiler(platform::ProfilerState::kCPU);
+        } else if (!profiling && platform::IsProfileEnabled()) {
+          // TODO(panyx0718): Should we allow to customize file dir.
+          platform::DisableProfiler(
+              platform::EventSortingKey::kDefault,
+              string::Sprintf("/tmp/profile_ps_%lld", listener_id));
+        }
+        break;
+      }
       default: {
         // Unknown tag, return unknown error.
         return -1;
```
paddle/fluid/operators/listen_and_serv_op.cc (+3 −0)

```diff
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/operators/listen_and_serv_op.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const platform::Place &dev_place) const {
+  // Mark this as PS that it should decide profiling by listening from trainer.
+  platform::SetProfileListener();
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &dev_ctx = *pool.Get(dev_place);
   framework::Scope &recv_scope = scope.NewScope();
```
paddle/fluid/platform/profiler.cc (+26 −9)

```diff
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/profiler.h"
 #include <sys/time.h>
 #include <time.h>
 #include <algorithm>
 #include <iomanip>
+#include <limits>
 #include <map>
+#include <mutex>  // NOLINT
+#include <random>
 #include <string>
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
@@ -33,6 +36,9 @@ namespace platform {
 struct EventList;
 
+static int64_t profiler_lister_id = 0;
+static bool should_send_profile_state = false;
+
 // The profiler state, the initial value is ProfilerState::kDisabled
 static ProfilerState g_state = ProfilerState::kDisabled;
 // The thread local event list only can be accessed by the specific thread
@@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) {
   PADDLE_ENFORCE(state != ProfilerState::kDisabled,
                  "Can't enbale profling, since the input state is ",
                  "ProfilerState::kDisabled");
-  PADDLE_ENFORCE(g_state == ProfilerState::kDisabled,
-                 "The profiling state should be disabled when calling ",
-                 "EnableProfiler.");
+  if (state == g_state) {
+    return;
+  }
   g_state = state;
   if (g_state == ProfilerState::kAll) {
+    should_send_profile_state = true;
     GetDeviceTracer()->Enable();
   }
 #ifdef PADDLE_WITH_CUDA
   if (g_state == ProfilerState::kCUDA) {
     // Generate some dummy events first to reduce the startup overhead.
@@ -435,8 +440,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
 void DisableProfiler(EventSortingKey sorted_key,
                      const std::string& profile_path) {
-  PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
-                 "Can't disable profiling, since it's not starting.");
+  if (g_state == ProfilerState::kDisabled) return;
   // Mark the profiling stop.
   Mark("_stop_profiler_", nullptr);
@@ -444,12 +448,25 @@ void DisableProfiler(EventSortingKey sorted_key,
   ParseEvents(all_events, sorted_key);
   ResetProfiler();
   DeviceTracer* tracer = GetDeviceTracer();
-  if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) {
+  if (tracer->IsEnabled()) {
     tracer->Disable();
     tracer->GenProfile(profile_path);
   }
   g_state = ProfilerState::kDisabled;
+  should_send_profile_state = true;
 }
 
+bool IsProfileEnabled() { return g_state != ProfilerState::kDisabled; }
+
+bool ShouldSendProfileState() { return should_send_profile_state; }
+
+void SetProfileListener() {
+  std::mt19937 rng;
+  rng.seed(std::random_device()());
+  std::uniform_int_distribution<std::mt19937::result_type> dist6(
+      1, std::numeric_limits<int64_t>::max());
+  profiler_lister_id = dist6(rng);
+}
+
+int64_t ListenerId() { return profiler_lister_id; }
+
 }  // namespace platform
 }  // namespace paddle
```
paddle/fluid/platform/profiler.h (+8 −0)

```diff
@@ -114,5 +114,13 @@ void ResetProfiler();
 void DisableProfiler(EventSortingKey sorted_key,
                      const std::string& profile_path);
 
+// Test if the profiler is currently enabled.
+bool IsProfileEnabled();
+// Whether the trainer should send profiling state to PS.
+bool ShouldSendProfileState();
+// Mark current process as PS by assigning a lister id.
+void SetProfileListener();
+int64_t ListenerId();
+
 }  // namespace platform
 }  // namespace paddle
```
python/paddle/fluid/__init__.py (+1 −0)

```diff
@@ -60,6 +60,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
     'io',
     'initializer',
     'layers',
+    'transpiler',
     'nets',
     'optimizer',
     'learning_rate_decay',
```
python/paddle/fluid/framework.py (+15 −6)

```diff
@@ -1042,13 +1042,14 @@ class Program(object):
         Returns(Program):
             The cloned Program object.
         """
-        p = Program()
         if for_test:
-            p.desc = core.inference_optimize(self.desc)
+            p = self.inference_optimize()
         else:
+            p = Program()
             p.desc = core.ProgramDesc(self.desc)
-        p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
-        p.sync_with_cpp()
+            p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
+            p.sync_with_cpp()
         p.copy_param_info_from(self)
         return p
@@ -1087,8 +1088,16 @@ class Program(object):
         return res
 
     def inference_optimize(self):
+        # this is an alternative implement before
+        # core.inference_optimize being fixed.
         res = Program()
-        res.desc = core.inference_optimize(self.desc)
+        res.desc = core.ProgramDesc(self.desc)
+        for i in xrange(res.desc.num_blocks()):
+            block = res.desc.block(i)
+            for j in xrange(block.op_size()):
+                op = block.op(j)
+                if op.has_attr('is_test'):
+                    op.set_attr('is_test', True)
         res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
         res.sync_with_cpp()
         return res
```
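The user-visible entry point for this change is `Program.clone(for_test=True)`. A minimal sketch of the usual call site (standard Fluid API; network definition elided):

```python
import paddle.fluid as fluid

# Define the network on the default main program first (elided here), then
# clone it for evaluation. With this change, cloning for test goes through
# the Python-side inference_optimize(), which sets every op's `is_test`
# attribute to True instead of calling core.inference_optimize.
main_program = fluid.default_main_program()
test_program = main_program.clone(for_test=True)
```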
python/paddle/fluid/metrics.py (+14 −12)

```diff
@@ -251,7 +251,7 @@ class EditDistance(MetricBase):
         self.instance_error += seq_num - seq_right_count
         self.total_distance += total_distance
 
-    def eval():
+    def eval(self):
         if self.seq_num == 0:
             raise ValueError(
                 "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance."
@@ -280,6 +280,7 @@ class DetectionMAP(MetricBase):
         super(DetectionMAP, self).__init__(name)
         # the current map value
         self.value = .0
+        self.weight = .0
 
     def update(self, value, weight):
         if not _is_number_or_matrix_(value):
@@ -340,8 +341,8 @@ class Auc(MetricBase):
             raise ValueError("The 'predictions' must be a numpy ndarray.")
 
         kepsilon = 1e-7  # to account for floating point imprecisions
-        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
-                      for i in range(num_thresholds - 2)]
+        thresholds = [(i + 1) * 1.0 / (self._num_thresholds - 1)
+                      for i in range(self._num_thresholds - 2)]
         thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
 
         # caculate TP, FN, TN, FP count
@@ -358,19 +359,20 @@ class Auc(MetricBase):
                     fp += 1
                 else:
                     tn += 1
-            tp_list[idx_thresh] += tp
-            fn_list[idx_thresh] += fn
-            tn_list[idx_thresh] += tn
-            fp_list[idx_thresh] += fp
+            self.tp_list[idx_thresh] += tp
+            self.fn_list[idx_thresh] += fn
+            self.tn_list[idx_thresh] += tn
+            self.fp_list[idx_thresh] += fp
 
     def eval(self):
         epsilon = self._epsilon
         num_thresholds = self._num_thresholds
-        tpr = (tp_list.astype("float32") + epsilon) / (
-            tp_list + fn_list + epsilon)
-        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
-        rec = (tp_list.astype("float32") + epsilon) / (
-            tp_list + fp_list + epsilon)
+        tpr = (self.tp_list.astype("float32") + epsilon) / (
+            self.tp_list + self.fn_list + epsilon)
+        fpr = self.fp_list.astype("float32") / (
+            self.fp_list + self.tn_list + epsilon)
+        rec = (self.tp_list.astype("float32") + epsilon) / (
+            self.tp_list + self.fp_list + epsilon)
         x = fpr[:num_thresholds - 1] - fpr[1:]
         y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
```
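What `Auc.eval()` computes from the accumulated counters is a trapezoid-rule area under the ROC curve. Below is a small numpy sketch with made-up counter values; the final `np.sum(x * y)` is an assumption, since the hunk cuts off after `y`:

```python
import numpy as np

epsilon = 1e-6
# Per-threshold counters as accumulated by update() (made-up values):
# higher thresholds yield fewer positives, so tp and fp decrease.
tp_list = np.array([8, 6, 3])
fn_list = np.array([0, 2, 5])
tn_list = np.array([1, 5, 9])
fp_list = np.array([9, 5, 1])

tpr = (tp_list.astype("float32") + epsilon) / (tp_list + fn_list + epsilon)
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
# Trapezoid rule over the ROC points: width is the drop in FPR between
# adjacent thresholds, height is the mean TPR of the two endpoints.
x = fpr[:-1] - fpr[1:]
y = (tpr[:-1] + tpr[1:]) / 2.0
auc = np.sum(x * y)  # assumed final step; ~0.58 for the values above
```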
python/paddle/fluid/trainer.py (+2 −1)

```diff
@@ -19,10 +19,11 @@ import executor
 import data_feeder
 import contextlib
 import io
+import transpiler
 
 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
-import distribute_transpiler
+from transpiler import distribute_transpiler
 
 __all__ = [
     'Trainer',
```
python/setup.py.in (+2 −1)

```diff
@@ -68,7 +68,8 @@ packages=['paddle',
           'paddle.fluid',
           'paddle.fluid.proto',
           'paddle.fluid.proto.profiler',
-          'paddle.fluid.layers']
+          'paddle.fluid.layers',
+          'paddle.fluid.transpiler']
 
 if '${WITH_FLUID_ONLY}'== 'OFF':
     packages+=['paddle.proto',
```
tools/timeline.py (+54 −36)

```diff
@@ -22,7 +22,11 @@ import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    '--profile_path', type=str, default='', help='Input profile file name.')
+    '--profile_path',
+    type=str,
+    default='',
+    help='Input profile file name. If there are multiple file, the format '
+    'should be trainer1=file1,trainer2=file2,ps=file3')
 parser.add_argument(
     '--timeline_path', type=str, default='', help='Output timeline file name.')
 args = parser.parse_args()
@@ -108,8 +112,8 @@ class _ChromeTraceFormatter(object):
 class Timeline(object):
-    def __init__(self, profile_pb):
-        self._profile_pb = profile_pb
+    def __init__(self, profile_dict):
+        self._profile_dict = profile_dict
         self._pid = 0
         self._devices = dict()
         self._chrome_trace = _ChromeTraceFormatter()
@@ -120,27 +124,29 @@ class Timeline(object):
         return cur_pid
 
     def _allocate_pids(self):
-        for event in self._profile_pb.events:
-            if event.type == profiler_pb2.Event.CPU:
-                if (event.device_id, "CPU") not in self._devices:
-                    pid = self._allocate_pid()
-                    self._devices[(event.device_id, "CPU")] = pid
-                    self._chrome_trace.emit_pid("cpu:block:%d" %
-                                                (event.device_id), pid)
-            elif event.type == profiler_pb2.Event.GPUKernel:
-                if (event.device_id, "GPUKernel") not in self._devices:
-                    pid = self._allocate_pid()
-                    self._devices[(event.device_id, "GPUKernel")] = pid
-                    self._chrome_trace.emit_pid("gpu:%d" % (event.device_id),
-                                                pid)
+        for k, profile_pb in self._profile_dict.iteritems():
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    if (k, event.device_id, "CPU") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "CPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "%s:cpu:block:%d" % (k, event.device_id), pid)
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    if (k, event.device_id, "GPUKernel") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "GPUKernel")] = pid
+                        self._chrome_trace.emit_pid(
+                            "%s:gpu:%d" % (k, event.device_id), pid)
 
     def _allocate_events(self):
-        for event in self._profile_pb.events:
-            if event.type == profiler_pb2.Event.CPU:
-                type = "CPU"
-            elif event.type == profiler_pb2.Event.GPUKernel:
-                type = "GPUKernel"
-            pid = self._devices[(event.device_id, type)]
-            args = {'name': event.name}
-            if event.memcopy.bytes > 0:
-                args = {'mem_bytes': event.memcopy.bytes}
+        for k, profile_pb in self._profile_dict.iteritems():
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    type = "CPU"
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    type = "GPUKernel"
+                pid = self._devices[(k, event.device_id, type)]
+                args = {'name': event.name}
+                if event.memcopy.bytes > 0:
+                    args = {'mem_bytes': event.memcopy.bytes}
@@ -163,11 +169,23 @@ timeline_path = '/tmp/timeline'
 if args.timeline_path:
     timeline_path = args.timeline_path
 
-with open(profile_path, 'r') as f:
-    profile_s = f.read()
-    profile_pb = profiler_pb2.Profile()
-    profile_pb.ParseFromString(profile_s)
+profile_paths = profile_path.split(',')
+profile_dict = dict()
+if len(profile_path) == 1:
+    with open(profile_path, 'r') as f:
+        profile_s = f.read()
+        profile_pb = profiler_pb2.Profile()
+        profile_pb.ParseFromString(profile_s)
+        profile_dict['trainer'] = profile_pb
+else:
+    for profile_path in profile_paths:
+        k, v = profile_path.split('=')
+        with open(v, 'r') as f:
+            profile_s = f.read()
+            profile_pb = profiler_pb2.Profile()
+            profile_pb.ParseFromString(profile_s)
+            profile_dict[k] = profile_pb
 
-tl = Timeline(profile_pb)
+tl = Timeline(profile_dict)
 with open(timeline_path, 'w') as f:
     f.write(tl.generate_chrome_trace())
```
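With the extended `--profile_path` format, one profile per role can be merged into a single Chrome trace. A sketch of a driver script follows; the profile file paths are placeholders:

```python
import subprocess

# One profile per role: trainer profiles written by fluid.profiler, plus the
# /tmp/profile_ps_* file the PS now generates (paths here are placeholders).
profiles = {
    "trainer1": "/tmp/profile_t1",
    "trainer2": "/tmp/profile_t2",
    "ps": "/tmp/profile_ps_12345",
}
profile_path = ",".join("%s=%s" % (k, v) for k, v in profiles.items())

# Matches the new trainer1=file1,trainer2=file2,ps=file3 format.
subprocess.check_call([
    "python", "tools/timeline.py",
    "--profile_path", profile_path,
    "--timeline_path", "/tmp/timeline",
])
# Load /tmp/timeline in chrome://tracing to view the merged timeline.
```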