Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
23433def
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
23433def
编写于
6月 07, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap_memcpy_with_dist
上级
15913d92
ff9b1a0f
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
422 addition
and
221 deletion
+422
-221
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-0
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+2
-2
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+4
-0
paddle/fluid/framework/details/graph_builder_factory.cc
paddle/fluid/framework/details/graph_builder_factory.cc
+47
-0
paddle/fluid/framework/details/graph_builder_factory.h
paddle/fluid/framework/details/graph_builder_factory.h
+67
-0
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+1
-9
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+1
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+2
-2
paddle/fluid/framework/details/reduce_op_handle.h
paddle/fluid/framework/details/reduce_op_handle.h
+2
-2
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+0
-58
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+1
-2
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+83
-0
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+67
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+15
-14
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-1
paddle/fluid/framework/tensor.cc
paddle/fluid/framework/tensor.cc
+98
-1
paddle/fluid/framework/tensor.h
paddle/fluid/framework/tensor.h
+15
-26
paddle/fluid/framework/tensor_impl.h
paddle/fluid/framework/tensor_impl.h
+0
-97
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+6
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+5
-5
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
23433def
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method
)
framework_proto glog lod_rank_table feed_fetch_method
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
multi_devices_graph_builder
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
graph_builder_factory
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
23433def
...
@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
...
@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
...
@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
...
@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
cc_library
(
graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
23433def
...
@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
...
@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
void
RunImpl
()
override
;
void
RunImpl
()
override
;
private:
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
#endif
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
23433def
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#pragma once
#pragma once
#include <string>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
...
@@ -29,6 +31,8 @@ struct BuildStrategy {
...
@@ -29,6 +31,8 @@ struct BuildStrategy {
ReduceStrategy
reduce_
{
ReduceStrategy
::
kAllReduce
};
ReduceStrategy
reduce_
{
ReduceStrategy
::
kAllReduce
};
GradientScaleStrategy
gradient_scale_
{
GradientScaleStrategy
::
kCoeffNumDevice
};
GradientScaleStrategy
gradient_scale_
{
GradientScaleStrategy
::
kCoeffNumDevice
};
std
::
string
debug_graphviz_path_
{
""
};
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/graph_builder_factory.cc
0 → 100644
浏览文件 @
23433def
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
#ifdef PADDLE_WITH_CUDA
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
nccl_ctxs_
,
strategy_
)
#else
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
strategy_
)
#endif
);
// NOLINT
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
new
GraphvizSSAGraphPrinter
());
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
)));
}
return
res
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/graph_builder_factory.h
0 → 100644
浏览文件 @
23433def
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
SSAGraphBuilderFactory
{
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
nccl_ctxs_
=
nccl_ctxs
;
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
string
loss_var_name_
;
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
23433def
...
@@ -30,10 +30,6 @@
...
@@ -30,10 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
#endif
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
...
@@ -149,6 +145,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -149,6 +145,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
std
::
unique_ptr
<
SSAGraph
>
MultiDevSSAGraphBuilder
::
Build
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
VLOG
(
3
)
<<
"Building ...."
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
all_vars
[
var
->
Name
()]
=
var
;
all_vars
[
var
->
Name
()]
=
var
;
...
@@ -315,11 +312,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -315,11 +312,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/
*/
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
if
(
VLOG_IS_ON
(
10
))
{
std
::
ofstream
fout
(
FLAGS_ssa_graph_path
);
PrintGraphviz
(
*
graph
,
fout
);
}
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
}
}
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
23433def
...
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
int
GetRemoteVarDevice
(
const
std
::
string
&
var_name
)
const
{
int
GetRemoteVarDevice
Id
(
const
std
::
string
&
var_name
)
const
override
{
auto
got
=
remote_vars_devices_
.
find
(
var_name
);
auto
got
=
remote_vars_devices_
.
find
(
var_name
);
if
(
got
!=
remote_vars_devices_
.
end
())
{
if
(
got
!=
remote_vars_devices_
.
end
())
{
return
got
->
second
;
return
got
->
second
;
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
浏览文件 @
23433def
...
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
...
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void
RunImpl
()
override
;
void
RunImpl
()
override
;
private:
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
const
platform
::
NCCLContextMap
&
nccl_ctxs_
;
const
platform
::
NCCLContextMap
&
nccl_ctxs_
;
};
};
...
...
paddle/fluid/framework/details/reduce_op_handle.h
浏览文件 @
23433def
...
@@ -32,8 +32,8 @@ namespace framework {
...
@@ -32,8 +32,8 @@ namespace framework {
namespace
details
{
namespace
details
{
struct
ReduceOpHandle
:
public
OpHandleBase
{
struct
ReduceOpHandle
:
public
OpHandleBase
{
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
23433def
...
@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
...
@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
op_handle
->
AddOutput
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
template
<
typename
Callback
>
void
IterAllVar
(
const
SSAGraph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
vars_
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
}
}
}
for
(
auto
&
var
:
graph
.
dep_vars_
)
{
callback
(
*
var
);
}
}
void
SSAGraphBuilder
::
PrintGraphviz
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
{
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
IterAllVar
(
graph
,
[
&
](
const
VarHandleBase
&
var
)
{
auto
*
var_ptr
=
&
var
;
auto
*
var_handle_ptr
=
dynamic_cast
<
const
VarHandle
*>
(
var_ptr
);
auto
*
dummy_ptr
=
dynamic_cast
<
const
DummyVarHandle
*>
(
var_ptr
);
size_t
cur_var_id
=
var_id
++
;
vars
[
var_ptr
]
=
cur_var_id
;
if
(
var_handle_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
var_handle_ptr
->
name_
<<
"
\\
n"
<<
var_handle_ptr
->
place_
<<
"
\\
n"
<<
var_handle_ptr
->
version_
<<
"
\"
]"
<<
std
::
endl
;
}
else
if
(
dummy_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
dummy
\"
]"
<<
std
::
endl
;
}
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
ops_
)
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
op
->
Inputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
op
->
Outputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
SSAGraph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
SSAGraph
*
graph
)
{
for
(
auto
&
op
:
graph
->
ops_
)
{
for
(
auto
&
op
:
graph
->
ops_
)
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
23433def
...
@@ -30,6 +30,7 @@ class SSAGraphBuilder {
...
@@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder
()
{}
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
=
0
;
virtual
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
=
0
;
virtual
int
GetRemoteVarDeviceId
(
const
std
::
string
&
var_name
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
...
@@ -55,8 +56,6 @@ class SSAGraphBuilder {
...
@@ -55,8 +56,6 @@ class SSAGraphBuilder {
const
platform
::
Place
&
place
,
size_t
place_offset
);
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
SSAGraph
*
graph
);
static
void
AddOutputToLeafOps
(
SSAGraph
*
graph
);
static
void
PrintGraphviz
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
);
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
0 → 100644
浏览文件 @
23433def
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
SSAGraph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
vars_
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
}
}
}
for
(
auto
&
var
:
graph
.
dep_vars_
)
{
callback
(
*
var
);
}
}
void
GraphvizSSAGraphPrinter
::
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
{
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
IterAllVar
(
graph
,
[
&
](
const
VarHandleBase
&
var
)
{
auto
*
var_ptr
=
&
var
;
auto
*
var_handle_ptr
=
dynamic_cast
<
const
VarHandle
*>
(
var_ptr
);
auto
*
dummy_ptr
=
dynamic_cast
<
const
DummyVarHandle
*>
(
var_ptr
);
size_t
cur_var_id
=
var_id
++
;
vars
[
var_ptr
]
=
cur_var_id
;
if
(
var_handle_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
var_handle_ptr
->
name_
<<
"
\\
n"
<<
var_handle_ptr
->
place_
<<
"
\\
n"
<<
var_handle_ptr
->
version_
<<
"
\"
]"
<<
std
::
endl
;
}
else
if
(
dummy_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
dummy
\"
]"
<<
std
::
endl
;
}
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
ops_
)
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
op
->
Inputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
op
->
Outputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_printer.h
0 → 100644
浏览文件 @
23433def
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iosfwd>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
SSAGraph
;
class
SSAGraphPrinter
{
public:
virtual
~
SSAGraphPrinter
()
{}
virtual
void
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
};
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
public:
void
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
{
auto
graph
=
builder_
->
Build
(
program
);
printer_
->
Print
(
*
graph
,
stream_ref_
);
return
graph
;
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
23433def
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -101,23 +102,23 @@ ParallelExecutor::ParallelExecutor(
...
@@ -101,23 +102,23 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
#ifdef PADDLE_WITH_CUDA
details
::
SSAGraphBuilderFactory
builder_factory
(
builder_
.
reset
(
new
details
::
MultiDevSSAGraphBuilder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
(),
build_strategy
));
#else
builder_
.
reset
(
new
details
::
MultiDevSSAGraphBuilder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
)
)
;
build_strategy
);
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
#endif
#endif
auto
graph
=
builder_
->
Build
(
main_program
);
builder_
.
reset
(
builder_factory
.
Create
().
get
());
if
(
builder_
.
get
()
==
nullptr
)
{
VLOG
(
3
)
<<
"builder is null."
;
}
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
builder_
->
Build
(
main_program
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
...
@@ -155,8 +156,8 @@ void ParallelExecutor::BCastParamsToGPUs(
...
@@ -155,8 +156,8 @@ void ParallelExecutor::BCastParamsToGPUs(
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
place
);
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
place
);
if
(
builder_
.
get
()
!=
nullptr
&&
if
(
builder_
.
get
()
!=
nullptr
&&
builder_
->
GetRemoteVarDevice
(
var
)
!=
-
1
)
{
builder_
->
GetRemoteVarDevice
Id
(
var
)
!=
-
1
)
{
int
place_id
=
builder_
->
GetRemoteVarDevice
(
var
);
int
place_id
=
builder_
->
GetRemoteVarDevice
Id
(
var
);
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
place_id
,
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
place_id
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
else
{
}
else
{
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
23433def
...
@@ -70,7 +70,7 @@ class ParallelExecutor {
...
@@ -70,7 +70,7 @@ class ParallelExecutor {
private:
private:
ParallelExecutorPrivate
*
member_
;
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
MultiDev
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
details
::
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/tensor.cc
浏览文件 @
23433def
...
@@ -15,5 +15,102 @@ limitations under the License. */
...
@@ -15,5 +15,102 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{}
namespace
framework
{
extern
size_t
SizeOfType
(
std
::
type_index
type
);
void
Tensor
::
check_memory_size
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
holder_
,
"Tensor holds no memory. Call Tensor::mutable_data first."
);
PADDLE_ENFORCE_LE
(
numel
()
*
SizeOfType
(
type
()),
memory_size
(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.
\n
"
"or maybe the required data-type mismatches the data already stored."
);
}
size_t
Tensor
::
memory_size
()
const
{
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
}
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
PADDLE_ENFORCE_GE
(
numel
(),
0
,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first."
);
int64_t
size
=
numel
()
*
SizeOfType
(
type
);
/* some versions of boost::variant don't have operator!= */
if
(
holder_
==
nullptr
||
!
(
holder_
->
place
()
==
place
)
||
holder_
->
size
()
<
size
+
offset_
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CPUPlace
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
size
,
type
));
}
else
if
(
platform
::
is_gpu_place
(
place
)
||
platform
::
is_cuda_pinned_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode."
);
}
#else
if
(
platform
::
is_gpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CUDAPlace
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
size
,
type
));
}
else
if
(
platform
::
is_cuda_pinned_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CUDAPinnedPlace
>
(
boost
::
get
<
platform
::
CUDAPinnedPlace
>
(
place
),
size
,
type
));
}
}
#endif
offset_
=
0
;
}
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
this
->
holder_
!=
nullptr
,
"Cannot invoke mutable data if current hold nothing."
);
return
mutable_data
(
place
,
holder_
->
type
());
}
Tensor
&
Tensor
::
ShareDataWith
(
const
Tensor
&
src
)
{
src
.
check_memory_size
();
*
this
=
src
;
return
*
this
;
}
Tensor
Tensor
::
Slice
(
int
begin_idx
,
int
end_idx
)
const
{
check_memory_size
();
PADDLE_ENFORCE_GE
(
begin_idx
,
0
,
"The start row index must be greater than 0."
);
PADDLE_ENFORCE_LE
(
end_idx
,
dims_
[
0
],
"The end row index is out of bound."
);
PADDLE_ENFORCE_LT
(
begin_idx
,
end_idx
,
"The start row index must be lesser than the end row index."
);
if
(
dims_
[
0
]
==
1
)
{
return
*
this
;
}
else
{
size_t
base
=
numel
()
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
dst
.
set_layout
(
layout_
);
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
SizeOfType
(
type
());
return
dst
;
}
}
Tensor
&
Tensor
::
Resize
(
const
DDim
&
dims
)
{
dims_
=
dims
;
return
*
this
;
}
const
DDim
&
Tensor
::
dims
()
const
{
return
dims_
;
}
int64_t
Tensor
::
numel
()
const
{
return
product
(
dims_
);
}
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/tensor.h
浏览文件 @
23433def
...
@@ -54,26 +54,24 @@ class Tensor {
...
@@ -54,26 +54,24 @@ class Tensor {
/*! Return a pointer to mutable memory block. */
/*! Return a pointer to mutable memory block. */
template
<
typename
T
>
template
<
typename
T
>
inline
T
*
data
();
T
*
data
();
/*! Return a pointer to constant memory block. */
/*! Return a pointer to constant memory block. */
template
<
typename
T
>
template
<
typename
T
>
inline
const
T
*
data
()
const
;
const
T
*
data
()
const
;
inline
bool
IsInitialized
()
const
;
bool
IsInitialized
()
const
;
inline
void
switch_place
(
platform
::
Place
new_place
);
/**
/**
* @brief Return a pointer to mutable memory block.
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
* @note If not exist, then allocation.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
inline
T
*
mutable_data
(
platform
::
Place
place
);
T
*
mutable_data
(
platform
::
Place
place
);
inline
void
*
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
);
void
*
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
);
inline
void
*
mutable_data
(
platform
::
Place
place
);
void
*
mutable_data
(
platform
::
Place
place
);
/**
/**
* @brief Return a pointer to mutable memory block.
* @brief Return a pointer to mutable memory block.
...
@@ -84,19 +82,19 @@ class Tensor {
...
@@ -84,19 +82,19 @@ class Tensor {
* @note If not exist, then allocation.
* @note If not exist, then allocation.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
inline
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
);
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
);
/*! Return the dimensions of the memory block. */
/*! Return the dimensions of the memory block. */
inline
const
DDim
&
dims
()
const
;
const
DDim
&
dims
()
const
;
/*! Return the numel of the memory block. */
/*! Return the numel of the memory block. */
in
line
in
t64_t
numel
()
const
;
int64_t
numel
()
const
;
/*! Resize the dimensions of the memory block. */
/*! Resize the dimensions of the memory block. */
inline
Tensor
&
Resize
(
const
DDim
&
dims
);
Tensor
&
Resize
(
const
DDim
&
dims
);
/*! The internal of two tensors share the same memory block. */
/*! The internal of two tensors share the same memory block. */
inline
Tensor
&
ShareDataWith
(
const
Tensor
&
src
);
Tensor
&
ShareDataWith
(
const
Tensor
&
src
);
/**
/**
* @brief Return a sub-tensor of the given tensor.
* @brief Return a sub-tensor of the given tensor.
...
@@ -106,7 +104,7 @@ class Tensor {
...
@@ -106,7 +104,7 @@ class Tensor {
* @param[in] end_idx The index of the end row(exclusive) to slice.
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
* The index number begins from 0.
*/
*/
inline
Tensor
Slice
(
int
begin_idx
,
int
end_idx
)
const
;
Tensor
Slice
(
int
begin_idx
,
int
end_idx
)
const
;
platform
::
Place
place
()
const
{
platform
::
Place
place
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
...
@@ -123,11 +121,11 @@ class Tensor {
...
@@ -123,11 +121,11 @@ class Tensor {
// memory size returns the holding memory size in byte.
// memory size returns the holding memory size in byte.
size_t
memory_size
()
const
;
size_t
memory_size
()
const
;
inline
void
check_memory_size
()
const
;
void
check_memory_size
()
const
;
inline
DataLayout
layout
()
const
{
return
layout_
;
}
DataLayout
layout
()
const
{
return
layout_
;
}
inline
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
private:
private:
/**
/**
...
@@ -210,15 +208,6 @@ class Tensor {
...
@@ -210,15 +208,6 @@ class Tensor {
size_t
offset_
;
size_t
offset_
;
};
};
inline
void
Tensor
::
switch_place
(
platform
::
Place
new_place
)
{
if
(
holder_
->
place
()
==
new_place
)
{
return
;
}
// TODO(tonyyang-svail): do memcpy here.
PADDLE_THROW
(
"Not Implemented"
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/tensor_impl.h
浏览文件 @
23433def
...
@@ -20,21 +20,6 @@ limitations under the License. */
...
@@ -20,21 +20,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
extern
size_t
SizeOfType
(
std
::
type_index
type
);
inline
void
Tensor
::
check_memory_size
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
holder_
,
"Tensor holds no memory. Call Tensor::mutable_data first."
);
PADDLE_ENFORCE_LE
(
numel
()
*
SizeOfType
(
type
()),
memory_size
(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.
\n
"
"or maybe the required data-type mismatches the data already stored."
);
}
inline
size_t
Tensor
::
memory_size
()
const
{
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
}
template
<
typename
T
>
template
<
typename
T
>
inline
const
T
*
Tensor
::
data
()
const
{
inline
const
T
*
Tensor
::
data
()
const
{
check_memory_size
();
check_memory_size
();
...
@@ -73,88 +58,6 @@ inline T* Tensor::mutable_data(platform::Place place) {
...
@@ -73,88 +58,6 @@ inline T* Tensor::mutable_data(platform::Place place) {
return
reinterpret_cast
<
T
*>
(
mutable_data
(
place
,
typeid
(
T
)));
return
reinterpret_cast
<
T
*>
(
mutable_data
(
place
,
typeid
(
T
)));
}
}
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
PADDLE_ENFORCE_GE
(
numel
(),
0
,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first."
);
int64_t
size
=
numel
()
*
SizeOfType
(
type
);
/* some versions of boost::variant don't have operator!= */
if
(
holder_
==
nullptr
||
!
(
holder_
->
place
()
==
place
)
||
holder_
->
size
()
<
size
+
offset_
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CPUPlace
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
size
,
type
));
}
else
if
(
platform
::
is_gpu_place
(
place
)
||
platform
::
is_cuda_pinned_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode."
);
}
#else
if
(
platform
::
is_gpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CUDAPlace
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
size
,
type
));
}
else
if
(
platform
::
is_cuda_pinned_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CUDAPinnedPlace
>
(
boost
::
get
<
platform
::
CUDAPinnedPlace
>
(
place
),
size
,
type
));
}
}
#endif
offset_
=
0
;
}
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
this
->
holder_
!=
nullptr
,
"Cannot invoke mutable data if current hold nothing."
);
return
mutable_data
(
place
,
holder_
->
type
());
}
inline
Tensor
&
Tensor
::
ShareDataWith
(
const
Tensor
&
src
)
{
src
.
check_memory_size
();
*
this
=
src
;
return
*
this
;
}
inline
Tensor
Tensor
::
Slice
(
int
begin_idx
,
int
end_idx
)
const
{
check_memory_size
();
PADDLE_ENFORCE_GE
(
begin_idx
,
0
,
"The start row index must be greater than 0."
);
PADDLE_ENFORCE_LE
(
end_idx
,
dims_
[
0
],
"The end row index is out of bound."
);
PADDLE_ENFORCE_LT
(
begin_idx
,
end_idx
,
"The start row index must be lesser than the end row index."
);
if
(
dims_
[
0
]
==
1
)
{
return
*
this
;
}
else
{
size_t
base
=
numel
()
/
dims_
[
0
];
Tensor
dst
;
dst
.
holder_
=
holder_
;
dst
.
set_layout
(
layout_
);
DDim
dst_dims
=
dims_
;
dst_dims
[
0
]
=
end_idx
-
begin_idx
;
dst
.
Resize
(
dst_dims
);
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
SizeOfType
(
type
());
return
dst
;
}
}
inline
Tensor
&
Tensor
::
Resize
(
const
DDim
&
dims
)
{
dims_
=
dims
;
return
*
this
;
}
inline
const
DDim
&
Tensor
::
dims
()
const
{
return
dims_
;
}
inline
int64_t
Tensor
::
numel
()
const
{
return
product
(
dims_
);
}
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
Tensor
res
;
Tensor
res
;
res
.
ShareDataWith
(
src
);
res
.
ShareDataWith
(
src
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
23433def
...
@@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle.
[](
BuildStrategy
&
self
,
[](
BuildStrategy
&
self
,
BuildStrategy
::
GradientScaleStrategy
strategy
)
{
BuildStrategy
::
GradientScaleStrategy
strategy
)
{
self
.
gradient_scale_
=
strategy
;
self
.
gradient_scale_
=
strategy
;
})
.
def_property
(
"debug_graphviz_path"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
debug_graphviz_path_
;
},
[](
BuildStrategy
&
self
,
const
std
::
string
&
path
)
{
self
.
debug_graphviz_path_
=
path
;
});
});
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
23433def
...
@@ -1182,19 +1182,19 @@ def conv2d(input,
...
@@ -1182,19 +1182,19 @@ def conv2d(input,
- Input:
- Input:
Input shape:
$(N, C_{in}, H_{in}, W_{in})$
Input shape:
:math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape:
$(C_{out}, C_{in}, H_f, W_f)$
Filter shape:
:math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
- Output:
Output shape:
$(N, C_{out}, H_{out}, W_{out})$
Output shape:
:math:`(N, C_{out}, H_{out}, W_{out})`
Where
Where
.. math::
.. math::
H_{out}&=
\\
frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
\\\\
H_{out}&=
\\
frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
\\\\
W_{out}&=
\\
frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
W_{out}&=
\\
frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
Args:
input(Variable): The input image with [N, C, H, W] format.
input(Variable): The input image with [N, C, H, W] format.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录