Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
aaf818f5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aaf818f5
编写于
2月 23, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/add_fwd_block_id
上级
65058cfb
bf9ed4a9
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
759 addition
and
308 deletion
+759
-308
doc/design/parallel_do.md
doc/design/parallel_do.md
+2
-2
paddle/fluid/framework/block_desc.cc
paddle/fluid/framework/block_desc.cc
+19
-0
paddle/fluid/framework/block_desc.h
paddle/fluid/framework/block_desc.h
+2
-0
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+73
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-1
paddle/fluid/framework/var_desc.cc
paddle/fluid/framework/var_desc.cc
+52
-2
paddle/fluid/framework/var_desc.h
paddle/fluid/framework/var_desc.h
+4
-0
paddle/fluid/framework/var_type.h
paddle/fluid/framework/var_type.h
+6
-0
paddle/fluid/operators/elementwise_add_op.h
paddle/fluid/operators/elementwise_add_op.h
+5
-57
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+256
-2
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+12
-48
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+6
-11
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+11
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-0
python/paddle/trainer_config_helpers/layer_math.py
python/paddle/trainer_config_helpers/layer_math.py
+1
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr
...r_config_helpers/tests/configs/protostr/math_ops.protostr
+1
-2
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+2
-0
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+221
-177
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+72
-0
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
...ests/book_distribute/notest_recognize_digits_conv_dist.py
+9
-4
未找到文件。
doc/design/parallel_do.md
浏览文件 @
aaf818f5
...
@@ -24,7 +24,7 @@ A vanilla implementation of parallel_do can be shown as the following (`|` means
...
@@ -24,7 +24,7 @@ A vanilla implementation of parallel_do can be shown as the following (`|` means
```
```
In the forward pass
In the forward pass
| Split input onto different devices
| Split input onto different devices
| Copy parameter
to
onto different devices
| Copy parameter onto different devices
|||| Compute forward pass in parallel
|||| Compute forward pass in parallel
| Merge output from different devices
| Merge output from different devices
...
@@ -87,7 +87,7 @@ block2 {
...
@@ -87,7 +87,7 @@ block2 {
}
}
```
```
## P
ro
formance Imporvement
## P
er
formance Imporvement
There are serial places we can make this parallel_do faster.
There are serial places we can make this parallel_do faster.
...
...
paddle/fluid/framework/block_desc.cc
浏览文件 @
aaf818f5
...
@@ -44,6 +44,25 @@ bool BlockDesc::HasVar(const std::string &name) const {
...
@@ -44,6 +44,25 @@ bool BlockDesc::HasVar(const std::string &name) const {
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
}
}
VarDesc
*
BlockDesc
::
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
if
(
!
this
->
HasVar
(
old_name
))
{
return
nullptr
;
}
need_update_
=
true
;
auto
*
var
=
this
->
Var
(
old_name
);
VarDesc
*
new_var
=
new
VarDesc
(
*
(
var
->
Proto
()));
new_var
->
SetName
(
new_name
);
vars_
[
new_name
].
reset
(
new_var
);
// rename inputs and outputs
for
(
const
auto
&
op
:
ops_
)
{
auto
*
it
=
op
.
get
();
it
->
Rename
(
old_name
,
new_name
);
}
vars_
.
erase
(
old_name
);
return
new_var
;
}
VarDesc
*
BlockDesc
::
FindVarRecursive
(
const
std
::
string
&
name
)
const
{
VarDesc
*
BlockDesc
::
FindVarRecursive
(
const
std
::
string
&
name
)
const
{
if
(
name
==
kEmptyVarName
)
return
nullptr
;
if
(
name
==
kEmptyVarName
)
return
nullptr
;
...
...
paddle/fluid/framework/block_desc.h
浏览文件 @
aaf818f5
...
@@ -57,6 +57,8 @@ class BlockDesc {
...
@@ -57,6 +57,8 @@ class BlockDesc {
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
VarDesc
*
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
&
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
VarDesc
&
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
...
...
paddle/fluid/framework/channel.h
浏览文件 @
aaf818f5
...
@@ -15,6 +15,8 @@ limitations under the License. */
...
@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <stddef.h> // for size_t
#include <stddef.h> // for size_t
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
...
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
ch
->
Close
();
ch
->
Close
();
}
}
/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us in TypeHiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class
ChannelHolder
{
public:
template
<
typename
T
>
void
Reset
(
size_t
buffer_size
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
buffer_size
));
}
template
<
typename
T
>
bool
Send
(
T
*
data
)
{
if
(
!
IsInitialized
())
return
false
;
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)));
// Static cast should be safe because we have ensured that types are same
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
return
channel
!=
nullptr
?
channel
->
Send
(
data
)
:
false
;
}
template
<
typename
T
>
bool
Receive
(
T
*
data
)
{
if
(
!
IsInitialized
())
return
false
;
PADDLE_ENFORCE_EQ
(
holder_
->
Type
(),
std
::
type_index
(
typeid
(
T
)));
Channel
<
T
>*
channel
=
static_cast
<
Channel
<
T
>*>
(
holder_
->
Ptr
());
return
channel
!=
nullptr
?
channel
->
Receive
(
data
)
:
false
;
}
void
close
()
{
if
(
IsInitialized
())
holder_
->
Close
();
}
inline
bool
IsInitialized
()
const
{
return
holder_
!=
nullptr
;
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct
Placeholder
{
virtual
~
Placeholder
()
{}
virtual
const
std
::
type_index
Type
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
void
Close
()
const
=
0
;
std
::
type_info
type_
;
};
template
<
typename
T
>
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
size_t
buffer_size
)
:
type_
(
std
::
type_index
(
typeid
(
T
)))
{
channel_
.
reset
(
MakeChannel
<
T
>
(
buffer_size
));
}
virtual
const
std
::
type_index
Type
()
const
{
return
type_
;
}
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
channel_
.
get
());
}
virtual
void
Close
()
{
if
(
channel_
)
channel_
->
Close
();
}
std
::
unique_ptr
<
Channel
<
T
>*>
channel_
;
const
std
::
type_index
type_
;
};
// Pointer to a PlaceholderImpl object
std
::
unique_ptr
<
Placeholder
>
holder_
;
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
aaf818f5
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <set>
#include <set>
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
...
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
...
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
var
->
GetMutable
<
platform
::
PlaceList
>
();
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
CHANNEL
)
{
var
->
GetMutable
<
ChannelHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
NCCL_COM
)
{
}
else
if
(
var_type
==
proto
::
VarType
::
NCCL_COM
)
{
// GetMutable will be called in ncclInit
// GetMutable will be called in ncclInit
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Variable type %d is not in "
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]"
,
"LOD_RANK_TABLE, PLACE_LIST, READER,
CHANNEL,
NCCL_COM]"
,
var_type
);
var_type
);
}
}
}
}
...
...
paddle/fluid/framework/var_desc.cc
浏览文件 @
aaf818f5
...
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
...
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}
}
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
mutable_channel_desc
()
->
set_data_type
(
data_type
);
break
;
default:
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
}
}
void
VarDesc
::
SetDataTypes
(
void
VarDesc
::
SetDataTypes
(
...
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
...
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
}
}
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
return
tensor_desc
().
data_type
();
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
channel_desc
().
data_type
();
break
;
default:
return
tensor_desc
().
data_type
();
}
}
}
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
...
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
...
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return
res
;
return
res
;
}
}
void
VarDesc
::
SetCapacity
(
int64_t
capacity
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
desc_
.
mutable_type
()
->
mutable_channel
()
->
set_capacity
(
capacity
);
break
;
default:
PADDLE_THROW
(
"Setting 'capacity' is not supported by the type of var %s."
,
this
->
Name
());
}
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
().
type
())
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
case
proto
::
VarType
::
LOD_TENSOR
:
...
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
...
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
}
}
const
proto
::
VarType
::
ChannelDesc
&
VarDesc
::
channel_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
desc_
.
type
().
channel
();
default:
PADDLE_THROW
(
"Getting 'channel_desc' is not supported by the type of var %s."
,
this
->
Name
());
}
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var's type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
...
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
}
}
proto
::
VarType
::
ChannelDesc
*
VarDesc
::
mutable_channel_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
CHANNEL
:
return
desc_
.
mutable_type
()
->
mutable_channel
();
default:
PADDLE_THROW
(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s."
,
this
->
Name
());
}
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
PADDLE_ENFORCE
(
desc_
.
type
().
has_type
(),
"The var type hasn't been set."
);
...
...
paddle/fluid/framework/var_desc.h
浏览文件 @
aaf818f5
...
@@ -85,6 +85,8 @@ class VarDesc {
...
@@ -85,6 +85,8 @@ class VarDesc {
void
SetDataTypes
(
void
SetDataTypes
(
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
);
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
);
void
SetCapacity
(
int64_t
capacity
);
proto
::
VarType
::
Type
GetDataType
()
const
;
proto
::
VarType
::
Type
GetDataType
()
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
...
@@ -106,8 +108,10 @@ class VarDesc {
...
@@ -106,8 +108,10 @@ class VarDesc {
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
private:
private:
const
proto
::
VarType
::
ChannelDesc
&
channel_desc
()
const
;
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
const
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
proto
::
VarType
::
ChannelDesc
*
mutable_channel_desc
();
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
...
...
paddle/fluid/framework/var_type.h
浏览文件 @
aaf818f5
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
...
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return
proto
::
VarType_Type_SELECTED_ROWS
;
return
proto
::
VarType_Type_SELECTED_ROWS
;
}
else
if
(
type
.
hash_code
()
==
typeid
(
ReaderHolder
).
hash_code
())
{
}
else
if
(
type
.
hash_code
()
==
typeid
(
ReaderHolder
).
hash_code
())
{
return
proto
::
VarType_Type_READER
;
return
proto
::
VarType_Type_READER
;
}
else
if
(
type
.
hash_code
()
==
typeid
(
ChannelHolder
).
hash_code
())
{
return
proto
::
VarType_Type_CHANNEL
;
}
else
{
}
else
{
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
PADDLE_THROW
(
"ToVarType:Unsupported type %s"
,
type
.
name
());
}
}
...
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
...
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case
proto
::
VarType_Type_READER
:
case
proto
::
VarType_Type_READER
:
visitor
(
var
.
Get
<
ReaderHolder
>
());
visitor
(
var
.
Get
<
ReaderHolder
>
());
return
;
return
;
case
proto
::
VarType_Type_CHANNEL
:
visitor
(
var
.
Get
<
ChannelHolder
>
());
return
;
default:
default:
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
PADDLE_THROW
(
"Not supported visit type, %d"
,
ToVarType
(
var
.
Type
()));
}
}
...
...
paddle/fluid/operators/elementwise_add_op.h
浏览文件 @
aaf818f5
...
@@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
...
@@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
ElementwiseAddGradFunctor
{
struct
IdentityGrad
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
typename
dY
,
typename
dZ
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
;
}
}
};
template
<
typename
T
>
struct
ElementwiseAddBroadCastGradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
typename
dY
,
typename
dZ
,
typename
Pre
,
typename
N
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
,
Pre
pre
,
N
n
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
pre
,
n
))
.
sum
(
Eigen
::
array
<
int
,
1
>
{{
0
}});
}
}
};
template
<
typename
T
>
struct
ElementwiseAddBroadCast2GradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
typename
dY
,
typename
dZ
,
typename
Pre
,
typename
N
,
typename
Post
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
,
Pre
pre
,
N
n
,
Post
post
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
.
reshape
(
Eigen
::
DSizes
<
int
,
3
>
(
pre
,
n
,
post
))
.
sum
(
Eigen
::
array
<
int
,
2
>
{{
0
,
2
}});
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
...
@@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
...
@@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseGradCompute
<
DeviceContext
,
T
,
ElementwiseAddGradFunctor
<
T
>
,
ElemwiseGradCompute
<
DeviceContext
,
T
,
IdentityGrad
<
T
>
,
IdentityGrad
<
T
>>
(
ElementwiseAddBroadCastGradFunctor
<
T
>
,
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
IdentityGrad
<
T
>
(),
ElementwiseAddBroadCast2GradFunctor
<
T
>>
(
IdentityGrad
<
T
>
());
ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
);
}
}
};
};
...
...
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
aaf818f5
...
@@ -20,9 +20,11 @@ limitations under the License. */
...
@@ -20,9 +20,11 @@ limitations under the License. */
#ifdef __NVCC__
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_adaptor.h>
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
#endif
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -33,10 +35,10 @@ namespace operators {
...
@@ -33,10 +35,10 @@ namespace operators {
* For example:
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1,
12,1).broadcast(2,12,
5)
* x.shape(2, 12, 5) * y.shape(1,
12, 1).broadcast(2, 12,
5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* pre=2*3, n=4*5, post=1
* x.shape(
2, 3, 20) * y.shape(1,1,20).broadcast(2,3,20
)
* x.shape(
6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1
)
*/
*/
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
...
@@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
...
@@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR
(
Div
,
EIGEN_DIV
);
EIGEN_FUNCTOR
(
Div
,
EIGEN_DIV
);
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
struct
ElemwiseGradNoBroadcast
{
const
T
*
x_
;
const
T
*
y_
;
const
T
*
out_
;
const
T
*
dout_
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
if
(
dx_
!=
nullptr
)
{
dx_
[
i
]
=
dx_op_
(
x_
[
i
],
y_
[
i
],
out_
[
i
],
dout_
[
i
]);
}
if
(
dy_
!=
nullptr
)
{
dy_
[
i
]
=
dx_op_
(
x_
[
i
],
y_
[
i
],
out_
[
i
],
dout_
[
i
]);
}
}
DX_OP
dx_op_
;
DY_OP
dy_op_
;
T
*
dx_
;
T
*
dy_
;
};
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast1CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
for
(
int
i
=
0
;
i
<
h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
w
;
++
j
)
{
int
x_offset
=
i
*
w
+
j
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
T
tmp
=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
if
(
i
==
0
)
{
dy
[
j
]
=
tmp
;
}
else
{
dy
[
j
]
+=
tmp
;
}
}
}
}
}
#ifdef __NVCC__
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
extern
__shared__
char
shm_buffer
[];
T
*
shm
=
reinterpret_cast
<
T
*>
(
shm_buffer
);
int
j
=
blockIdx
.
x
;
int
i
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
shm
[
tid
]
=
0
;
do
{
int
x_offset
=
i
*
w
+
j
;
if
(
dx
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
)
{
shm
[
tid
]
+=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
i
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
while
(
i
<
h
);
if
(
dy
)
{
__syncthreads
();
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
// Sum, could be optimized
if
(
threadIdx
.
x
==
0
)
{
for
(
int
k
=
1
;
k
<
h
;
++
k
)
{
shm
[
0
]
+=
shm
[
k
];
}
dy
[
j
]
=
shm
[
0
];
}
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast1CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
h
);
int
gird_size
=
w
;
int
shared_mem_size
=
block_size
*
sizeof
(
T
);
ElemwiseGradBroadcast1CUDAKernel
<<<
gird_size
,
block_size
,
shared_mem_size
,
stream
>>>
(
x
,
y
,
out
,
dout
,
h
,
w
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
for
(
int
k
=
0
;
k
<
post
;
++
k
)
{
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
T
tmp
=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
if
(
i
==
0
&&
k
==
0
)
{
dy
[
j
]
=
tmp
;
}
else
{
dy
[
j
]
+=
tmp
;
}
}
}
}
}
}
#ifdef __NVCC__
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast2CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
tid
=
threadIdx
.
x
;
int
j
=
blockIdx
.
x
;
extern
__shared__
char
shm_buffer
[];
T
*
shm
=
reinterpret_cast
<
T
*>
(
shm_buffer
);
shm
[
tid
]
=
0
;
int
ttid
=
tid
;
while
(
true
)
{
int
i
=
ttid
/
post
;
int
k
=
ttid
%
post
;
if
(
i
>=
pre
)
break
;
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
shm
[
tid
]
+=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
ttid
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
if
(
dy
)
{
__syncthreads
();
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
// Sum, could be optimized
if
(
tid
==
0
)
{
for
(
int
i
=
1
;
i
<
h
;
++
i
)
{
shm
[
0
]
+=
shm
[
i
];
}
dy
[
j
]
=
shm
[
0
];
}
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
pre
*
post
);
int
gird_size
=
n
;
int
shared_mem_size
=
block_size
*
sizeof
(
T
);
ElemwiseGradBroadcast2CUDAKernel
<<<
gird_size
,
block_size
,
shared_mem_size
,
stream
>>>
(
x
,
y
,
out
,
dout
,
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
void
ElemwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
dout
,
int
axis
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
,
DX_OP
dx_op
,
DY_OP
dy_op
)
{
if
(
x
.
dims
()
==
y
.
dims
())
{
size_t
N
=
static_cast
<
size_t
>
(
framework
::
product
(
x
.
dims
()));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
N
);
for_range
(
ElemwiseGradNoBroadcast
<
T
,
DX_OP
,
DY_OP
>
{
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())});
}
else
{
// Y is a scalar
auto
x_dim
=
x
.
dims
();
auto
y_dim
=
y
.
dims
();
if
(
y_dim
.
size
()
==
1
&&
y_dim
[
0
]
==
1
)
{
// y is a scalar
auto
extended_dims
=
framework
::
vectorize
(
x_dim
);
extended_dims
.
push_back
(
1
);
x_dim
=
framework
::
make_ddim
(
extended_dims
);
}
axis
=
(
axis
==
-
1
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
);
int
pre
,
n
,
post
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
pre
,
n
,
post
);
if
(
post
==
1
)
{
int
h
=
pre
;
int
w
=
n
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcast1CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast1CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcast2CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast2CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
typename
broadcastfunctor
,
typename
broadcast2functor
>
typename
broadcastfunctor
,
typename
broadcast2functor
>
void
ElementwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
void
ElementwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
aaf818f5
...
@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
server_thread_
->
join
();
server_thread_
->
join
();
}
}
std
::
string
GetGradVarNameForTrainer
(
const
std
::
string
&
varname
)
const
{
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
grads_counter_
[
varname
]
=
0
;
}
return
string
::
Sprintf
(
"%s.trainer_%d"
,
varname
,
grads_counter_
[
varname
]
++
);
}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
@@ -91,8 +84,7 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -91,8 +84,7 @@ class ListenAndServOp : public framework::OperatorBase {
// FIXME(Yancey1989): initialize rpc server with lazy mode.
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
ins
=
Inputs
(
"X"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
...
@@ -109,40 +101,24 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -109,40 +101,24 @@ class ListenAndServOp : public framework::OperatorBase {
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
rpc_service_
->
SetCond
(
0
);
size_t
recv_var_cnt
=
0
;
size_t
recv_var_cnt
=
0
;
size_t
update_param_cnt
=
0
;
int
batch_barrier
=
0
;
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
while
(
batch_barrier
!=
fan_in
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad
_var_name
=
v
.
first
;
auto
recv
_var_name
=
v
.
first
;
if
(
grad
_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
if
(
recv
_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
exit_flag
=
true
;
break
;
break
;
}
else
if
(
grad
_var_name
==
BATCH_BARRIER_MESSAGE
)
{
}
else
if
(
recv
_var_name
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"recv batch barrier message"
;
VLOG
(
3
)
<<
"recv batch barrier message"
;
batch_barrier
++
;
batch_barrier
++
;
continue
;
continue
;
}
else
{
}
else
{
// receive a variable
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
recv_var_cnt
++
;
recv_var_cnt
++
;
auto
it
=
auto
*
var
=
recv_scope
.
FindVar
(
recv_var_name
);
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
update_param_cnt
++
;
VLOG
(
3
)
<<
"received grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
}
else
{
VLOG
(
3
)
<<
"received variable: "
<<
grad_var_name
<<
" no need to update param"
;
}
if
(
fan_in
>
1
&&
!
param_var_name
.
empty
())
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
var
==
nullptr
)
{
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
grad
_var_name
;
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv
_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
PADDLE_THROW
(
"Can not find server side var"
);
}
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
...
@@ -151,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -151,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
}
}
}
}
VLOG
(
3
)
<<
"recv "
<<
recv_var_cnt
<<
" parmeters for one barrier."
;
if
(
exit_flag
)
{
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
rpc_service_
->
ShutDown
();
}
}
VLOG
(
3
)
<<
"run optimize graph..."
;
try
{
try
{
executor
.
Run
(
*
program
,
&
recv_scope
,
block
->
ID
(),
/*global_block*/
executor
.
Run
(
*
program
,
&
recv_scope
,
block
->
ID
(),
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
}
// Reset the received sparse variables, the sum operator would not
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// mini-batch.
// TO
OD
(Yancey1989): move the reset action into an operator, we couldn't
// TO
DO
(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
// have any hide logic in the operator.
for
(
auto
&
var
:
sparse_vars
)
{
for
(
auto
&
var
:
sparse_vars
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
update_param_cnt
);
// FIXME(typhoonzero): use another condition to sync wait clients get.
grads_counter_
.
clear
(
);
rpc_service_
->
WaitClientGet
(
ins
.
size
()
);
sparse_vars
.
clear
();
sparse_vars
.
clear
();
}
// while(true)
}
// while(true)
}
}
...
@@ -181,13 +154,13 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -181,13 +154,13 @@ class ListenAndServOp : public framework::OperatorBase {
protected:
protected:
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
};
};
class
ListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
ListenAndServOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
ListenAndServOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor) Variables that server recv."
).
AsDuplicable
();
AddComment
(
R"DOC(
AddComment
(
R"DOC(
ListenAndServ operator
ListenAndServ operator
...
@@ -201,16 +174,7 @@ from send_op and send back variables to recv_op.
...
@@ -201,16 +174,7 @@ from send_op and send back variables to recv_op.
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"BlockID to run on server side."
);
"BlockID to run on server side."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
"ParamList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
}
}
};
};
...
...
paddle/fluid/platform/float16.h
浏览文件 @
aaf818f5
...
@@ -74,17 +74,12 @@ struct PADDLE_ALIGN(2) float16 {
...
@@ -74,17 +74,12 @@ struct PADDLE_ALIGN(2) float16 {
// The following defaulted special class member functions
// The following defaulted special class member functions
// are added to make float16 pass the std::is_trivial test
// are added to make float16 pass the std::is_trivial test
HOSTDEVICE
inline
float16
()
=
default
;
float16
()
=
default
;
float16
(
const
float16
&
o
)
=
default
;
HOSTDEVICE
inline
float16
(
const
float16
&
)
=
default
;
float16
&
operator
=
(
const
float16
&
o
)
=
default
;
float16
(
float16
&&
o
)
=
default
;
HOSTDEVICE
inline
float16
&
operator
=
(
const
float16
&
)
=
default
;
float16
&
operator
=
(
float16
&&
o
)
=
default
;
~
float16
()
=
default
;
HOSTDEVICE
inline
float16
(
float16
&&
)
=
default
;
HOSTDEVICE
inline
float16
&
operator
=
(
float16
&&
)
=
default
;
HOSTDEVICE
inline
~
float16
()
=
default
;
// Constructors
// Constructors
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_CUDA_FP16
...
...
paddle/fluid/pybind/protobuf.cc
浏览文件 @
aaf818f5
...
@@ -172,6 +172,14 @@ void BindBlockDesc(py::module &m) {
...
@@ -172,6 +172,14 @@ void BindBlockDesc(py::module &m) {
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
std
::
string
name
=
byte_name
;
std
::
string
name
=
byte_name
;
return
self
.
HasVar
(
name
);
return
self
.
HasVar
(
name
);
},
py
::
return_value_policy
::
reference
)
.
def
(
"rename_var"
,
[](
BlockDesc
&
self
,
const
py
::
bytes
&
byte_name
,
const
py
::
bytes
&
byte_name_new
)
{
std
::
string
name
=
byte_name
;
std
::
string
new_name
=
byte_name_new
;
self
.
RenameVar
(
name
,
new_name
);
})
})
.
def
(
"has_var_recursive"
,
.
def
(
"has_var_recursive"
,
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
...
@@ -200,7 +208,7 @@ void BindVarDsec(py::module &m) {
...
@@ -200,7 +208,7 @@ void BindVarDsec(py::module &m) {
py
::
class_
<
VarDesc
>
var_desc
(
m
,
"VarDesc"
,
""
);
py
::
class_
<
VarDesc
>
var_desc
(
m
,
"VarDesc"
,
""
);
var_desc
var_desc
.
def
(
"name"
,
.
def
(
"name"
,
[](
const
VarDesc
&
self
)
{
[](
VarDesc
&
self
)
{
py
::
bytes
name
=
self
.
Name
();
py
::
bytes
name
=
self
.
Name
();
return
name
;
return
name
;
},
},
...
@@ -210,6 +218,7 @@ void BindVarDsec(py::module &m) {
...
@@ -210,6 +218,7 @@ void BindVarDsec(py::module &m) {
.
def
(
"set_shapes"
,
&
VarDesc
::
SetShapes
)
.
def
(
"set_shapes"
,
&
VarDesc
::
SetShapes
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"set_dtypes"
,
&
VarDesc
::
SetDataTypes
)
.
def
(
"set_dtypes"
,
&
VarDesc
::
SetDataTypes
)
.
def
(
"set_capacity"
,
&
VarDesc
::
SetCapacity
)
.
def
(
"shape"
,
&
VarDesc
::
GetShape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shape"
,
&
VarDesc
::
GetShape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
VarDesc
::
GetShapes
,
py
::
return_value_policy
::
reference
)
.
def
(
"shapes"
,
&
VarDesc
::
GetShapes
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
,
py
::
return_value_policy
::
reference
)
...
@@ -240,6 +249,7 @@ void BindVarDsec(py::module &m) {
...
@@ -240,6 +249,7 @@ void BindVarDsec(py::module &m) {
.
value
(
"STEP_SCOPES"
,
proto
::
VarType
::
STEP_SCOPES
)
.
value
(
"STEP_SCOPES"
,
proto
::
VarType
::
STEP_SCOPES
)
.
value
(
"LOD_RANK_TABLE"
,
proto
::
VarType
::
LOD_RANK_TABLE
)
.
value
(
"LOD_RANK_TABLE"
,
proto
::
VarType
::
LOD_RANK_TABLE
)
.
value
(
"LOD_TENSOR_ARRAY"
,
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
.
value
(
"LOD_TENSOR_ARRAY"
,
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
.
value
(
"CHANNEL"
,
proto
::
VarType
::
CHANNEL
)
.
value
(
"PLACE_LIST"
,
proto
::
VarType
::
PLACE_LIST
)
.
value
(
"PLACE_LIST"
,
proto
::
VarType
::
PLACE_LIST
)
.
value
(
"READER"
,
proto
::
VarType
::
READER
)
.
value
(
"READER"
,
proto
::
VarType
::
READER
)
.
value
(
"NCCL_COM"
,
proto
::
VarType
::
NCCL_COM
);
.
value
(
"NCCL_COM"
,
proto
::
VarType
::
NCCL_COM
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
aaf818f5
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <mutex> // for call_once
#include <mutex> // for call_once
#include <unordered_map>
#include <unordered_map>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
python/paddle/trainer_config_helpers/layer_math.py
浏览文件 @
aaf818f5
...
@@ -75,7 +75,7 @@ LayerOutput.__add__ = add
...
@@ -75,7 +75,7 @@ LayerOutput.__add__ = add
def
sub
(
layeroutput
,
other
):
def
sub
(
layeroutput
,
other
):
if
is_compatible_with
(
other
,
float
):
if
is_compatible_with
(
other
,
float
):
return
slope_intercept_layer
(
input
=
layeroutput
,
intercept
=
other
)
return
slope_intercept_layer
(
input
=
layeroutput
,
intercept
=
-
other
)
if
not
isinstance
(
other
,
LayerOutput
):
if
not
isinstance
(
other
,
LayerOutput
):
logger
.
fatal
(
"LayerOutput can only be subtracted with"
logger
.
fatal
(
"LayerOutput can only be subtracted with"
" another Layeroutput or a number"
)
" another Layeroutput or a number"
)
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr
浏览文件 @
aaf818f5
...
@@ -230,7 +230,7 @@ layers {
...
@@ -230,7 +230,7 @@ layers {
input_layer_name: "__mixed_1__"
input_layer_name: "__mixed_1__"
}
}
slope: 1.0
slope: 1.0
intercept: 2
intercept:
-
2
}
}
layers {
layers {
name: "__slope_intercept_layer_4__"
name: "__slope_intercept_layer_4__"
...
@@ -411,4 +411,3 @@ sub_models {
...
@@ -411,4 +411,3 @@ sub_models {
output_layer_names: "__mixed_3__"
output_layer_names: "__mixed_3__"
is_recurrent_layer_group: false
is_recurrent_layer_group: false
}
}
python/paddle/v2/dataset/common.py
浏览文件 @
aaf818f5
...
@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
...
@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
retry
=
0
retry
=
0
retry_limit
=
3
retry_limit
=
3
while
not
(
os
.
path
.
exists
(
filename
)
and
md5file
(
filename
)
==
md5sum
):
while
not
(
os
.
path
.
exists
(
filename
)
and
md5file
(
filename
)
==
md5sum
):
if
os
.
path
.
exists
(
filename
):
print
"file md5"
,
md5file
(
filename
),
md5sum
if
retry
<
retry_limit
:
if
retry
<
retry_limit
:
retry
+=
1
retry
+=
1
else
:
else
:
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
aaf818f5
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
framework
import
framework
from
framework
import
Program
,
default_main_program
,
Parameter
,
Variable
from
framework
import
Program
,
default_main_program
,
default_startup_program
,
Parameter
,
Variable
import
optimizer
import
optimizer
from
layer_helper
import
LayerHelper
from
layer_helper
import
LayerHelper
from
distributed_spliter
import
*
from
distributed_spliter
import
*
...
@@ -133,6 +133,7 @@ class DistributeTranspiler:
...
@@ -133,6 +133,7 @@ class DistributeTranspiler:
def
transpile
(
self
,
def
transpile
(
self
,
optimize_ops
,
optimize_ops
,
params_grads
,
params_grads
,
trainer_id
,
program
=
None
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
trainers
=
1
,
...
@@ -146,13 +147,37 @@ class DistributeTranspiler:
...
@@ -146,13 +147,37 @@ class DistributeTranspiler:
Use different methods to split trainable variables to different
Use different methods to split trainable variables to different
parameter servers.
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param optimize_ops: op list of optimization, should be the
:param optimize_ops: op list of optimization, should be the
return value of Optimizer.minimize
return value of Optimizer.minimize
:type optimize_ops: list
:type optimize_ops: list
:param program: program to optimize, default is default_main_program
:param params_grads: list of tuple(weight, gradient)
:type params_grads: list
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:type pservers: string
:return: return a list of programs
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
"""
"""
assert
(
callable
(
split_method
))
assert
(
callable
(
split_method
))
if
program
is
None
:
if
program
is
None
:
...
@@ -160,25 +185,19 @@ class DistributeTranspiler:
...
@@ -160,25 +185,19 @@ class DistributeTranspiler:
self
.
program
=
program
self
.
program
=
program
self
.
trainers
=
trainers
self
.
trainers
=
trainers
self
.
optimize_ops
=
optimize_ops
self
.
optimize_ops
=
optimize_ops
# steps to transpile:
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
# like Kubernetes, we should port this to use etcd later when developing
# 2. modify trainer program add split_op to each Grad.
# fluid distributed training with fault-tolerance.
# 3. append send_op to trainer.
self
.
trainer_id
=
trainer_id
# 4. append concat_op to trainer to update local weights.
# 5. create new program for parameter server.
# 6. create parameter server program by split_method generated endpoint->VarBlock
pserver_endpoints
=
pservers
.
split
(
","
)
pserver_endpoints
=
pservers
.
split
(
","
)
# step1
# step1
param_list
=
[
pg
[
0
]
for
pg
in
params_grads
]
param_list
=
[
pg
[
0
]
for
pg
in
params_grads
]
grad_list
=
[
pg
[
1
]
for
pg
in
params_grads
]
grad_list
=
[
pg
[
1
]
for
pg
in
params_grads
]
# TODO: add split selected rows support
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
grad_blocks
=
split_dense_variable
(
grad_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
param_blocks
=
split_dense_variable
(
param_list
,
len
(
pserver_endpoints
))
# step2
# step2
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
grad_var_mapping
=
self
.
_append_split_op
(
program
,
grad_blocks
)
# step3
# step3
send_inputs
=
[]
send_inputs
=
[]
send_outputs
=
[]
send_outputs
=
[]
...
@@ -211,7 +230,7 @@ class DistributeTranspiler:
...
@@ -211,7 +230,7 @@ class DistributeTranspiler:
shape
=
[
0
])
shape
=
[
0
])
# create send_op
# create send_op
send_op
=
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"send"
,
type
=
"send"
,
inputs
=
{
"X"
:
send_inputs
},
inputs
=
{
"X"
:
send_inputs
},
outputs
=
{
"Out"
:
send_outputs
,
outputs
=
{
"Out"
:
send_outputs
,
...
@@ -223,14 +242,158 @@ class DistributeTranspiler:
...
@@ -223,14 +242,158 @@ class DistributeTranspiler:
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
continue
continue
orig_param
=
program
.
global_block
().
vars
[
varname
]
orig_param
=
program
.
global_block
().
vars
[
varname
]
concat
=
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"concat"
,
type
=
"concat"
,
inputs
=
{
"X"
:
splited_var
},
inputs
=
{
"X"
:
splited_var
},
outputs
=
{
"Out"
:
[
orig_param
]},
outputs
=
{
"Out"
:
[
orig_param
]},
attrs
=
{
"axis"
:
0
})
attrs
=
{
"axis"
:
0
})
def
_create_vars_from_blocklist
(
self
,
program
,
block_list
):
def
get_trainer_program
(
self
):
# Create respective variables using the block_list
# remove optimize ops and add a send op to main_program
self
.
program
.
global_block
().
delete_ops
(
self
.
optimize_ops
)
return
self
.
program
def
get_pserver_program
(
self
,
endpoint
):
"""
Get pserver side program using the endpoint.
NOTE: assume blocks of the same variable is not distributed
on the same pserver, only change param/grad varnames for
trainers to fetch.
"""
# step1
pserver_program
=
Program
()
# step2
recv_inputs
=
[]
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
suff_idx
=
v
.
name
.
find
(
".trainer_"
)
if
suff_idx
>=
0
:
orig_var_name
=
v
.
name
[:
suff_idx
]
pserver_program
.
global_block
().
create_var
(
name
=
orig_var_name
,
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
print
(
"create origin var: "
,
orig_var_name
)
for
trainer_id
in
xrange
(
self
.
trainers
):
var
=
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
orig_var_name
,
trainer_id
),
persistable
=
False
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
recv_inputs
.
append
(
var
)
print
(
"create per trainer var: "
,
var
.
name
)
# step3
optimize_block
=
pserver_program
.
create_block
(
0
)
# step 4
# Create a union-find data struct from optimize ops,
# If two ops are connected, we could add these two ops
# into one set.
ufind
=
self
.
_create_ufind
(
self
.
optimize_ops
)
# step 4.2
# Iterate through the ops and append optimize op which
# located on current pserver
opt_op_on_pserver
=
[]
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
if
self
.
_is_opt_op
(
op
)
and
self
.
_is_opt_op_on_pserver
(
endpoint
,
op
):
opt_op_on_pserver
.
append
(
op
)
# step 4.3
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
for
_
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
if
ufind
.
is_connected
(
op
,
opt_op
):
if
self
.
_is_opt_op
(
op
):
self
.
_append_pserver_ops
(
optimize_block
,
op
,
endpoint
)
else
:
self
.
_append_pserver_non_opt_ops
(
optimize_block
,
op
)
break
# step5 append the listen_and_serv op
pserver_program
.
global_block
().
append_op
(
type
=
"listen_and_serv"
,
inputs
=
{
'X'
:
recv_inputs
},
outputs
=
{},
attrs
=
{
"OptimizeBlock"
:
optimize_block
,
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainers
})
pserver_program
.
sync_with_cpp
()
return
pserver_program
def
get_startup_program
(
self
,
endpoint
,
pserver_program
):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
were split to several blocks.
"""
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
for
idx
,
splited_param
in
enumerate
(
params
):
pname
=
splited_param
.
name
if
same_or_split_var
(
pname
,
varname
)
and
varname
!=
pname
:
return
pname
,
splited_param
.
shape
return
""
,
[]
# 1. create vars in pserver program to startup program
pserver_vars
=
pserver_program
.
global_block
().
vars
created_var_map
=
dict
()
for
_
,
var
in
pserver_vars
.
iteritems
():
tmpvar
=
s_prog
.
global_block
().
create_var
(
name
=
var
.
name
,
persistable
=
var
.
persistable
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
)
created_var_map
[
var
.
name
]
=
tmpvar
# 2. rename op outputs
for
op
in
orig_s_prog
.
global_block
().
ops
:
new_inputs
=
dict
()
new_outputs
=
dict
()
# do not append startup op if var is not on this pserver
op_on_pserver
=
False
for
key
in
op
.
output_names
:
newname
,
_
=
_get_splited_name_and_shape
(
op
.
output
(
key
)[
0
])
if
newname
:
op_on_pserver
=
True
new_outputs
[
key
]
=
created_var_map
[
newname
]
elif
op
.
output
(
key
)[
0
]
in
pserver_vars
:
op_on_pserver
=
True
new_outputs
[
key
]
=
pserver_vars
[
op
.
output
(
key
)[
0
]]
# most startup program ops have no inputs
new_inputs
=
self
.
_get_input_map_from_op
(
pserver_vars
,
op
)
if
op_on_pserver
:
if
op
.
type
in
[
"gaussian_random"
,
"fill_constant"
,
"uniform_random"
]:
op
.
attrs
[
"shape"
]
=
new_outputs
[
"Out"
].
shape
s_prog
.
global_block
().
append_op
(
type
=
op
.
type
,
inputs
=
new_inputs
,
outputs
=
new_outputs
,
attrs
=
op
.
attrs
)
return
s_prog
# ====================== private transpiler functions =====================
def
_create_vars_from_blocklist
(
self
,
program
,
block_list
,
add_trainer_suffix
=
False
):
"""
NOTE: only grads need to be named for different trainers, use
add_trainer_suffix to rename the grad vars.
"""
block_map
=
dict
()
block_map
=
dict
()
var_mapping
=
dict
()
var_mapping
=
dict
()
for
block_str
in
block_list
:
for
block_str
in
block_list
:
...
@@ -239,11 +402,20 @@ class DistributeTranspiler:
...
@@ -239,11 +402,20 @@ class DistributeTranspiler:
block_map
[
varname
]
=
[]
block_map
[
varname
]
=
[]
block_map
[
varname
].
append
((
long
(
offset
),
long
(
size
)))
block_map
[
varname
].
append
((
long
(
offset
),
long
(
size
)))
for
varname
,
splited
in
block_map
.
iteritems
():
for
varname
,
splited
in
block_map
.
iteritems
():
orig_var
=
program
.
global_block
().
vars
[
varname
]
orig_var
=
program
.
global_block
().
var
(
varname
)
var_mapping
[
varname
]
=
[]
if
len
(
splited
)
==
1
:
if
len
(
splited
)
==
1
:
var_mapping
[
varname
]
=
[
orig_var
]
if
add_trainer_suffix
:
new_var_name
=
"%s.trainer_%d"
%
\
(
orig_var
.
name
,
self
.
trainer_id
)
program
.
global_block
().
rename_var
(
varname
,
new_var_name
)
var_mapping
[
varname
]
=
\
[
program
.
global_block
().
var
(
new_var_name
)]
else
:
var_mapping
[
varname
]
=
\
[
program
.
global_block
().
var
(
orig_var
.
name
)]
continue
continue
var_mapping
[
varname
]
=
[]
orig_shape
=
orig_var
.
shape
orig_shape
=
orig_var
.
shape
orig_dim1_flatten
=
1
orig_dim1_flatten
=
1
if
len
(
orig_shape
)
>=
2
:
if
len
(
orig_shape
)
>=
2
:
...
@@ -255,13 +427,21 @@ class DistributeTranspiler:
...
@@ -255,13 +427,21 @@ class DistributeTranspiler:
splited_shape
=
[
rows
]
splited_shape
=
[
rows
]
if
len
(
orig_shape
)
>=
2
:
if
len
(
orig_shape
)
>=
2
:
splited_shape
.
extend
(
orig_shape
[
1
:])
splited_shape
.
extend
(
orig_shape
[
1
:])
new_var_name
=
""
if
add_trainer_suffix
:
new_var_name
=
"%s.block%d.trainer_%d"
%
\
(
varname
,
i
,
self
.
trainer_id
)
else
:
new_var_name
=
"%s.block%d"
%
\
(
varname
,
i
)
var
=
program
.
global_block
().
create_var
(
var
=
program
.
global_block
().
create_var
(
name
=
"%s.block%d"
%
(
varname
,
i
)
,
name
=
new_var_name
,
persistable
=
False
,
persistable
=
False
,
dtype
=
orig_var
.
dtype
,
dtype
=
orig_var
.
dtype
,
type
=
orig_var
.
type
,
type
=
orig_var
.
type
,
shape
=
splited_shape
)
# flattend splited var
shape
=
splited_shape
)
# flattend splited var
var_mapping
[
varname
].
append
(
var
)
var_mapping
[
varname
].
append
(
var
)
program
.
global_block
().
sync_with_cpp
()
return
var_mapping
return
var_mapping
def
_clone_var
(
self
,
block
,
var
):
def
_clone_var
(
self
,
block
,
var
):
...
@@ -272,13 +452,12 @@ class DistributeTranspiler:
...
@@ -272,13 +452,12 @@ class DistributeTranspiler:
dtype
=
var
.
dtype
,
dtype
=
var
.
dtype
,
type
=
var
.
type
,
type
=
var
.
type
,
lod_level
=
var
.
lod_level
,
lod_level
=
var
.
lod_level
,
# HACK: let all param in pserver be persistable so the child
# program in recv can get them
persistable
=
True
)
persistable
=
True
)
def
_append_split_op
(
self
,
program
,
gradblocks
):
def
_append_split_op
(
self
,
program
,
gradblocks
):
# Split variables that need to be split and append respective ops
# Split variables that need to be split and append respective ops
var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
gradblocks
)
var_mapping
=
self
.
_create_vars_from_blocklist
(
program
,
gradblocks
,
add_trainer_suffix
=
True
)
for
varname
,
splited_vars
in
var_mapping
.
iteritems
():
for
varname
,
splited_vars
in
var_mapping
.
iteritems
():
# variable that don't need to split have empty splited_vars
# variable that don't need to split have empty splited_vars
if
len
(
splited_vars
)
<=
1
:
if
len
(
splited_vars
)
<=
1
:
...
@@ -308,24 +487,6 @@ class DistributeTranspiler:
...
@@ -308,24 +487,6 @@ class DistributeTranspiler:
"[LOD_TENSOR, SELECTED_ROWS]"
)
"[LOD_TENSOR, SELECTED_ROWS]"
)
return
var_mapping
return
var_mapping
def
get_trainer_program
(
self
):
# remove optimize ops and add a send op to main_program
self
.
program
.
global_block
().
delete_ops
(
self
.
optimize_ops
)
return
self
.
program
def
_create_var_for_trainers
(
self
,
block
,
var
,
trainers
):
# For each trainer, create the necessary variables
var_list
=
[]
for
i
in
xrange
(
trainers
):
var_each
=
block
.
create_var
(
name
=
"%s.trainer_%d"
%
(
var
.
name
,
i
),
persistable
=
var
.
persistable
,
dtype
=
var
.
dtype
,
type
=
var
.
type
,
shape
=
var
.
shape
)
var_list
.
append
(
var_each
)
return
var_list
def
_get_optimizer_input_shape
(
self
,
op_type
,
varkey
,
orig_shape
,
def
_get_optimizer_input_shape
(
self
,
op_type
,
varkey
,
orig_shape
,
param_shape
):
param_shape
):
"""
"""
...
@@ -353,6 +514,13 @@ class DistributeTranspiler:
...
@@ -353,6 +514,13 @@ class DistributeTranspiler:
pass
pass
return
orig_shape
return
orig_shape
def
_orig_varname
(
self
,
varname
):
suff_idx
=
varname
.
find
(
".trainer_"
)
orig_var_name
=
""
if
suff_idx
>=
0
:
orig_var_name
=
varname
[:
suff_idx
]
return
orig_var_name
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
program
=
optimize_block
.
program
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
pserver_block
=
program
.
global_block
()
...
@@ -363,18 +531,23 @@ class DistributeTranspiler:
...
@@ -363,18 +531,23 @@ class DistributeTranspiler:
if
key
==
"Grad"
:
if
key
==
"Grad"
:
grad_block
=
None
grad_block
=
None
for
g
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
for
g
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
if
same_or_split_var
(
g
.
name
,
opt_op
.
input
(
key
)[
0
]):
if
same_or_split_var
(
self
.
_orig_varname
(
g
.
name
),
opt_op
.
input
(
key
)[
0
]):
grad_block
=
g
grad_block
=
g
break
break
if
not
grad_block
:
if
not
grad_block
:
# do not append this op if current endpoint
# do not append this op if current endpoint
# is not dealing with this grad block
# is not dealing with this grad block
return
return
merged_var
=
pserver_block
.
vars
[
grad_block
.
name
]
merged_var
=
\
# append merging ops if trainers > 1
pserver_block
.
vars
[
self
.
_orig_varname
(
grad_block
.
name
)]
if
self
.
trainers
>
1
:
if
self
.
trainers
>
1
:
vars2merge
=
self
.
_create_var_for_trainers
(
vars2merge
=
[]
pserver_block
,
grad_block
,
self
.
trainers
)
for
i
in
xrange
(
self
.
trainers
):
per_trainer_name
=
"%s.trainer_%d"
%
\
(
self
.
_orig_varname
(
grad_block
.
name
),
i
)
vars2merge
.
append
(
pserver_block
.
vars
[
per_trainer_name
])
optimize_block
.
append_op
(
optimize_block
.
append_op
(
type
=
"sum"
,
type
=
"sum"
,
inputs
=
{
"X"
:
vars2merge
},
inputs
=
{
"X"
:
vars2merge
},
...
@@ -517,77 +690,6 @@ class DistributeTranspiler:
...
@@ -517,77 +690,6 @@ class DistributeTranspiler:
return
False
return
False
return
False
return
False
def
get_pserver_program
(
self
,
endpoint
):
"""
Get pserver side program using the endpoint
NOTE: assume blocks of the same variable is not distributed
on the same pserver, only change param/grad varnames for
trainers to fetch. For each pserver endpoint, server side
program must be a sub-set of the original optimization program.
"""
# step5
pserver_program
=
Program
()
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program
.
global_block
().
create_var
(
name
=
v
.
name
,
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
for
trainer_id
in
xrange
(
self
.
trainers
):
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
v
.
name
,
trainer_id
),
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
# step6
optimize_block
=
pserver_program
.
create_block
(
0
)
# step 6.1
# Create a union-find data struct by optimize ops,
# If two ops are connected, we could add these two ops
# into one set.
ufind
=
self
.
_create_ufind
(
self
.
optimize_ops
)
# step 6.2
# Iterate through the ops and append optimize op which
# located on current pserver
opt_op_on_pserver
=
[]
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
if
self
.
_is_opt_op
(
op
)
and
self
.
_is_opt_op_on_pserver
(
endpoint
,
op
):
opt_op_on_pserver
.
append
(
op
)
# step 6.3
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
for
_
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
if
ufind
.
is_connected
(
op
,
opt_op
):
if
self
.
_is_opt_op
(
op
):
self
.
_append_pserver_ops
(
optimize_block
,
op
,
endpoint
)
else
:
self
.
_append_pserver_non_opt_ops
(
optimize_block
,
op
)
break
# Append the listen_and_serv op
pserver_program
.
global_block
().
append_op
(
type
=
"listen_and_serv"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"OptimizeBlock"
:
optimize_block
,
"endpoint"
:
endpoint
,
"ParamList"
:
[
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
],
"GradList"
:
[
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]
],
"Fanin"
:
self
.
trainers
})
pserver_program
.
sync_with_cpp
()
return
pserver_program
def
_get_input_map_from_op
(
self
,
varmap
,
op
):
def
_get_input_map_from_op
(
self
,
varmap
,
op
):
iomap
=
dict
()
iomap
=
dict
()
for
key
in
op
.
input_names
:
for
key
in
op
.
input_names
:
...
@@ -611,61 +713,3 @@ class DistributeTranspiler:
...
@@ -611,61 +713,3 @@ class DistributeTranspiler:
else
:
else
:
iomap
[
key
]
=
vars
iomap
[
key
]
=
vars
return
iomap
return
iomap
def
get_startup_program
(
self
,
endpoint
,
pserver_program
):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
were split to several blocks.
"""
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
for
idx
,
splited_param
in
enumerate
(
params
):
pname
=
splited_param
.
name
if
same_or_split_var
(
pname
,
varname
)
and
varname
!=
pname
:
return
pname
,
splited_param
.
shape
return
""
,
[]
# 1. create vars in pserver program to startup program
pserver_vars
=
pserver_program
.
global_block
().
vars
created_var_map
=
dict
()
for
_
,
var
in
pserver_vars
.
iteritems
():
tmpvar
=
s_prog
.
global_block
().
create_var
(
name
=
var
.
name
,
persistable
=
var
.
persistable
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
)
created_var_map
[
var
.
name
]
=
tmpvar
# 2. rename op outputs
for
op
in
orig_s_prog
.
global_block
().
ops
:
new_inputs
=
dict
()
new_outputs
=
dict
()
# do not append startup op if var is not on this pserver
op_on_pserver
=
False
for
key
in
op
.
output_names
:
newname
,
_
=
_get_splited_name_and_shape
(
op
.
output
(
key
)[
0
])
if
newname
:
op_on_pserver
=
True
new_outputs
[
key
]
=
created_var_map
[
newname
]
elif
op
.
output
(
key
)[
0
]
in
pserver_vars
:
op_on_pserver
=
True
new_outputs
[
key
]
=
pserver_vars
[
op
.
output
(
key
)[
0
]]
# most startup program ops have no inputs
new_inputs
=
self
.
_get_input_map_from_op
(
pserver_vars
,
op
)
if
op_on_pserver
:
if
op
.
type
in
[
"gaussian_random"
,
"fill_constant"
,
"uniform_random"
]:
op
.
attrs
[
"shape"
]
=
new_outputs
[
"Out"
].
shape
s_prog
.
global_block
().
append_op
(
type
=
op
.
type
,
inputs
=
new_inputs
,
outputs
=
new_outputs
,
attrs
=
op
.
attrs
)
return
s_prog
python/paddle/v2/fluid/framework.py
浏览文件 @
aaf818f5
...
@@ -286,6 +286,10 @@ class Variable(object):
...
@@ -286,6 +286,10 @@ class Variable(object):
def
name
(
self
):
def
name
(
self
):
return
self
.
desc
.
name
()
return
self
.
desc
.
name
()
@
name
.
setter
def
name
(
self
,
new_name
):
self
.
desc
.
set_name
(
new_name
)
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
# convert to tuple, make it as same as numpy API.
# convert to tuple, make it as same as numpy API.
...
@@ -531,6 +535,12 @@ class Operator(object):
...
@@ -531,6 +535,12 @@ class Operator(object):
"""
"""
return
self
.
desc
.
input
(
name
)
return
self
.
desc
.
input
(
name
)
def
rename_input
(
self
,
old_name
,
new_name
):
self
.
desc
.
rename_input
(
old_name
,
new_name
)
def
rename_output
(
self
,
old_name
,
new_name
):
self
.
desc
.
rename_output
(
old_name
,
new_name
)
@
property
@
property
def
input_names
(
self
):
def
input_names
(
self
):
"""
"""
...
@@ -540,6 +550,14 @@ class Operator(object):
...
@@ -540,6 +550,14 @@ class Operator(object):
"""
"""
return
self
.
desc
.
input_names
()
return
self
.
desc
.
input_names
()
@
property
def
input_arg_names
(
self
):
return
self
.
desc
.
input_arg_names
()
@
property
def
output_arg_names
(
self
):
return
self
.
desc
.
output_arg_names
()
def
output
(
self
,
name
):
def
output
(
self
,
name
):
"""
"""
Get output arguments by the output parameter name
Get output arguments by the output parameter name
...
@@ -741,6 +759,60 @@ class Block(object):
...
@@ -741,6 +759,60 @@ class Block(object):
def
has_var
(
self
,
name
):
def
has_var
(
self
,
name
):
return
name
in
self
.
vars
return
name
in
self
.
vars
def
rename_var
(
self
,
name
,
new_name
):
"""
Rename variable in vars and ops' inputs and outputs
"""
if
not
self
.
has_var
(
name
):
raise
ValueError
(
"var %s is not in current"
%
name
)
v
=
self
.
var
(
name
)
stop_gradient
=
None
trainable
=
None
optimize_attr
=
None
regularizer
=
None
gradient_clip_attr
=
None
error_clip
=
None
if
type
(
v
)
==
Parameter
:
stop_gradient
=
v
.
stop_gradient
trainable
=
v
.
trainable
optimize_attr
=
v
.
optimize_attr
regularizer
=
v
.
regularizer
gradient_clip_attr
=
v
.
gradient_clip_attr
error_clip
=
v
.
error_clip
elif
type
(
v
)
==
Variable
:
error_clip
=
v
.
error_clip
stop_gradient
=
v
.
stop_gradient
else
:
raise
ValueError
(
"unsupported var type: %s"
,
type
(
v
))
self
.
desc
.
rename_var
(
name
,
new_name
)
d
=
self
.
desc
.
find_var
(
new_name
)
var
=
None
if
type
(
v
)
==
Parameter
:
var
=
Parameter
(
self
,
d
.
shape
(),
d
.
dtype
(),
name
=
new_name
,
stop_gradient
=
stop_gradient
,
trainable
=
trainable
,
optimize_attr
=
optimize_attr
,
regularizer
=
regularizer
,
gradient_clip_attr
=
gradient_clip_attr
,
error_clip
=
error_clip
)
elif
type
(
v
)
==
Variable
:
var
=
Variable
(
self
,
name
=
new_name
,
error_clip
=
error_clip
,
stop_gradient
=
stop_gradient
)
# rename the python side, sync_with_cpp will only add
# new vars/ops to python side.
self
.
vars
[
new_name
]
=
var
del
self
.
vars
[
name
]
self
.
sync_with_cpp
()
def
create_parameter
(
self
,
*
args
,
**
kwargs
):
def
create_parameter
(
self
,
*
args
,
**
kwargs
):
global_block
=
self
.
program
.
global_block
()
global_block
=
self
.
program
.
global_block
()
param
=
Parameter
(
global_block
,
*
args
,
**
kwargs
)
param
=
Parameter
(
global_block
,
*
args
,
**
kwargs
)
...
...
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
浏览文件 @
aaf818f5
...
@@ -58,14 +58,19 @@ trainers = int(os.getenv("TRAINERS")) # total trainer count
...
@@ -58,14 +58,19 @@ trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# current pserver endpoint
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# current pserver endpoint
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
"TRAINER"
)
# get the training role: trainer/pserver
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
t
=
fluid
.
DistributeTranspiler
()
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
optimize_ops
,
params_grads
,
0
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
training_role
==
"PSERVER"
:
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_startup
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录