Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
61343fbf
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
61343fbf
编写于
5月 10, 2018
作者:
W
Wu Yi
提交者:
GitHub
5月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10531 from typhoonzero/refine_grpc_serde_code
Refine serde code
上级
6d371e45
796a448c
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
172 addition
and
178 deletion
+172
-178
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+105
-113
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+3
-3
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+20
-20
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+44
-42
未找到文件。
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
61343fbf
...
...
@@ -29,60 +29,26 @@ namespace paddle {
namespace
operators
{
namespace
detail
{
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
// When using GPU, need to free the copied CPU buffer
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
using
VarMsg
=
sendrecv
::
VariableMessage
;
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
void
*
payload
=
nullptr
;
size_t
payload_size
;
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if
(
platform
::
ShouldSendProfileState
())
{
e
.
WriteBool
(
VarMsg
::
kProfileFieldNumber
,
platform
::
IsProfileEnabled
());
}
e
.
WriteString
(
VarMsg
::
kVarnameFieldNumber
,
name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
0
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
}
if
(
!
out_name
.
empty
())
{
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
}
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
void
GetTensorPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
framework
::
ToDataType
(
tensor
.
type
()));
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
tensor
.
type
())));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
auto
lod
=
tensor
.
lod
();
// std::vector<Vector<size_t>>
const
framework
::
LoD
lod
=
tensor
.
lod
();
if
(
lod
.
size
()
>
0
)
{
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
lod
.
size
());
request
->
set_lod_level
(
lod
.
size
());
for
(
auto
&
each
:
lod
)
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kLodFieldNumber
,
2
+
// tag + varintlength of submessage
1
+
// kLodDataFieldNumber
each
.
size
());
// auto copied from GPU
VarMsg
::
LodData
*
lod_inner
=
request
->
add_lod
();
for
(
auto
&
d
:
each
)
{
e
.
WriteUint64
(
VarMsg
::
LodData
::
kLodDataFieldNumber
,
d
);
lod_inner
->
add_lod_data
(
d
);
}
}
}
...
...
@@ -90,68 +56,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
}
else
{
payload
=
tensor
.
data
<
void
>
();
*
payload
=
tensor
.
data
<
void
>
();
}
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
break
;
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
{
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
*
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
}
void
GetSelectedRowsPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
framework
::
ToDataType
(
slr
->
value
().
type
()));
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
slr
->
value
().
type
())));
request
->
set_lod_level
(
0
);
request
->
set_slr_height
(
slr
->
height
());
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
0
);
e
.
WriteUint64
(
VarMsg
::
kSlrHeightFieldNumber
,
slr
->
height
());
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
())
,
copy_size
,
gpu_dev_ctx
.
stream
());
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
}
else
{
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
*
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
}
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
break
;
default:
*
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
}
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
request
.
set_profile
(
platform
::
IsProfileEnabled
());
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
break
;
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
}
std
::
string
header
;
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
...
...
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
if
(
framework
::
ToVarType
(
var
->
Type
())
==
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
// NOTE: rows is of type int64_t
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
...
...
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
[](
void
*
backing
)
{
// TODO(typhoonzero): add unref here, same as above.
},
rows_memory_size
,
[](
void
*
backing
)
{},
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
61343fbf
...
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
framework
::
Variable
var
;
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
tensor
->
Resize
(
framework
::
make_ddim
({
512
,
8
,
4
,
2
}));
framework
::
LoD
lod
;
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
int
tensor_numel
=
512
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
...
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
512
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
61343fbf
...
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
length
))
{
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
int
start_pos
=
input
->
CurrentPosition
();
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
return
tag
;
}
lod
->
push_back
(
v
);
}
...
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break
;
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint
64
_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
uint
32
_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
}
...
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break
;
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint
64
_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
uint
32
_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
}
...
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
length
))
{
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
...
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
...
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
...
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
}
break
;
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
}
break
;
...
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
length
))
{
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
}
break
;
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
61343fbf
...
...
@@ -18,7 +18,9 @@ import math
import
distributed_splitter
as
splitter
from
..
import
core
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
Variable
,
Parameter
,
grad_var_name
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
...
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
]
grad_list
=
[
grad
for
grad
in
grad_list
if
grad
.
name
!=
framework
.
grad_var_name
(
self
.
table_name
)
if
grad
.
name
!=
grad_var_name
(
self
.
table_name
)
]
self
.
table_param_grad
=
[
param_grad
for
param_grad
in
params_grads
...
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks.
"""
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
orig_s_prog
=
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
...
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
table_grad_name
=
framework
.
grad_var_name
(
self
.
table_name
)
table_grad_name
=
grad_var_name
(
self
.
table_name
)
for
op
in
all_ops
:
if
table_grad_name
in
op
.
output_arg_names
:
op_index
=
list
(
all_ops
).
index
(
op
)
...
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable
=
True
)
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
self
.
origin_program
.
global_block
().
vars
[
grad_var_name
(
self
.
table_name
)],
persistable
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录