Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
890f626b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
890f626b
编写于
5月 12, 2021
作者:
T
tangwei12
提交者:
GitHub
5月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize/fleet save (#32817)
* fix cpp lint * fix save/load with unexpected value * fix save and user interface
上级
e1a4c83c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
130 addition
and
64 deletion
+130
-64
paddle/fluid/distributed/table/common_sparse_table.cc
paddle/fluid/distributed/table/common_sparse_table.cc
+33
-16
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+49
-0
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+34
-36
python/paddle/fluid/tests/unittests/test_fleet_base_2.py
python/paddle/fluid/tests/unittests/test_fleet_base_2.py
+14
-12
未找到文件。
paddle/fluid/distributed/table/common_sparse_table.cc
浏览文件 @
890f626b
...
@@ -13,9 +13,9 @@
...
@@ -13,9 +13,9 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include <sstream>
#include <sstream>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -25,7 +25,8 @@ class ValueBlock;
...
@@ -25,7 +25,8 @@ class ValueBlock;
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
#define PSERVER_SAVE_SUFFIX "_txt"
#define PSERVER_SAVE_SUFFIX ".shard"
using
boost
::
lexical_cast
;
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -100,7 +101,7 @@ struct Meta {
...
@@ -100,7 +101,7 @@ struct Meta {
};
};
void
ProcessALine
(
const
std
::
vector
<
std
::
string
>&
columns
,
const
Meta
&
meta
,
void
ProcessALine
(
const
std
::
vector
<
std
::
string
>&
columns
,
const
Meta
&
meta
,
std
::
vector
<
std
::
vector
<
float
>>*
values
)
{
const
int64_t
id
,
std
::
vector
<
std
::
vector
<
float
>>*
values
)
{
auto
colunmn_size
=
columns
.
size
();
auto
colunmn_size
=
columns
.
size
();
auto
load_values
=
auto
load_values
=
paddle
::
string
::
split_string
<
std
::
string
>
(
columns
[
colunmn_size
-
1
],
","
);
paddle
::
string
::
split_string
<
std
::
string
>
(
columns
[
colunmn_size
-
1
],
","
);
...
@@ -116,8 +117,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
...
@@ -116,8 +117,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
"The data format in txt does not meet the field "
"The data format in txt does not meet the field "
"requirements defined in meta"
));
"requirements defined in meta"
));
std
::
transform
(
start
,
end
,
std
::
back_inserter
(
val
),
std
::
transform
(
start
,
end
,
std
::
back_inserter
(
val
),
[
id
](
std
::
string
va
)
{
[](
std
::
string
va
)
{
return
std
::
stof
(
va
);
});
float
v
=
0.0
;
try
{
v
=
lexical_cast
<
float
>
(
va
);
}
catch
(
boost
::
bad_lexical_cast
&
e
)
{
VLOG
(
0
)
<<
"id: "
<<
id
<<
" get unexpected value: "
<<
va
<<
" and be reset to: 0.0"
;
}
return
v
;
});
values
->
push_back
(
val
);
values
->
push_back
(
val
);
offset
+=
meta
.
dims
[
x
];
offset
+=
meta
.
dims
[
x
];
}
}
...
@@ -126,25 +137,29 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
...
@@ -126,25 +137,29 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
int64_t
SaveToText
(
std
::
ostream
*
os
,
std
::
shared_ptr
<
ValueBlock
>
block
,
int64_t
SaveToText
(
std
::
ostream
*
os
,
std
::
shared_ptr
<
ValueBlock
>
block
,
const
int
mode
)
{
const
int
mode
)
{
int64_t
save_num
=
0
;
int64_t
save_num
=
0
;
for
(
auto
&
table
:
block
->
values_
)
{
for
(
auto
&
table
:
block
->
values_
)
{
for
(
auto
&
value
:
table
)
{
for
(
auto
&
value
:
table
)
{
if
(
mode
==
SaveMode
::
delta
&&
!
value
.
second
->
need_save_
)
{
if
(
mode
==
SaveMode
::
delta
&&
!
value
.
second
->
need_save_
)
{
continue
;
continue
;
}
}
save_num
+=
1
;
auto
*
vs
=
value
.
second
->
data_
.
data
();
++
save_num
;
std
::
stringstream
ss
;
std
::
stringstream
ss
;
auto
*
vs
=
value
.
second
->
data_
.
data
();
auto
id
=
value
.
first
;
auto
id
=
value
.
first
;
ss
<<
id
<<
"
\t
"
<<
value
.
second
->
count_
<<
"
\t
"
ss
<<
id
<<
"
\t
"
<<
value
.
second
->
count_
<<
"
\t
"
<<
value
.
second
->
unseen_days_
<<
"
\t
"
<<
value
.
second
->
is_entry_
<<
value
.
second
->
unseen_days_
<<
"
\t
"
<<
value
.
second
->
is_entry_
<<
"
\t
"
;
<<
"
\t
"
;
for
(
int
i
=
0
;
i
<
block
->
value_length_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
block
->
value_length_
-
1
;
i
++
)
{
ss
<<
vs
[
i
];
ss
<<
std
::
to_string
(
vs
[
i
])
<<
","
;
ss
<<
","
;
}
}
ss
<<
std
::
to_string
(
vs
[
block
->
value_length_
-
1
]);
ss
<<
"
\n
"
;
ss
<<
"
\n
"
;
os
->
write
(
ss
.
str
().
c_str
(),
sizeof
(
char
)
*
ss
.
str
().
size
());
os
->
write
(
ss
.
str
().
c_str
(),
sizeof
(
char
)
*
ss
.
str
().
size
());
...
@@ -170,7 +185,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
...
@@ -170,7 +185,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
while
(
std
::
getline
(
file
,
line
))
{
while
(
std
::
getline
(
file
,
line
))
{
auto
values
=
paddle
::
string
::
split_string
<
std
::
string
>
(
line
,
"
\t
"
);
auto
values
=
paddle
::
string
::
split_string
<
std
::
string
>
(
line
,
"
\t
"
);
auto
id
=
std
::
stoull
(
values
[
0
]);
auto
id
=
lexical_cast
<
int64_t
>
(
values
[
0
]);
if
(
id
%
pserver_num
!=
pserver_id
)
{
if
(
id
%
pserver_num
!=
pserver_id
)
{
VLOG
(
3
)
<<
"will not load "
<<
values
[
0
]
<<
" from "
<<
valuepath
VLOG
(
3
)
<<
"will not load "
<<
values
[
0
]
<<
" from "
<<
valuepath
...
@@ -182,15 +197,17 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
...
@@ -182,15 +197,17 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
auto
block
=
blocks
->
at
(
shard_id
);
auto
block
=
blocks
->
at
(
shard_id
);
std
::
vector
<
std
::
vector
<
float
>>
kvalues
;
std
::
vector
<
std
::
vector
<
float
>>
kvalues
;
ProcessALine
(
values
,
meta
,
&
kvalues
);
ProcessALine
(
values
,
meta
,
id
,
&
kvalues
);
block
->
Init
(
id
,
false
);
block
->
Init
(
id
,
false
);
VALUE
*
value_instant
=
block
->
GetValue
(
id
);
VALUE
*
value_instant
=
block
->
GetValue
(
id
);
if
(
values
.
size
()
==
5
)
{
if
(
values
.
size
()
==
5
)
{
value_instant
->
count_
=
std
::
stoi
(
values
[
1
]);
value_instant
->
count_
=
lexical_cast
<
int
>
(
values
[
1
]);
value_instant
->
unseen_days_
=
std
::
stoi
(
values
[
2
]);
value_instant
->
unseen_days_
=
lexical_cast
<
int
>
(
values
[
2
]);
value_instant
->
is_entry_
=
static_cast
<
bool
>
(
std
::
stoi
(
values
[
3
]));
value_instant
->
is_entry_
=
static_cast
<
bool
>
(
lexical_cast
<
int
>
(
values
[
3
]));
}
}
std
::
vector
<
float
*>
block_values
=
block
->
Get
(
id
,
meta
.
names
,
meta
.
dims
);
std
::
vector
<
float
*>
block_values
=
block
->
Get
(
id
,
meta
.
names
,
meta
.
dims
);
...
@@ -475,7 +492,7 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
...
@@ -475,7 +492,7 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
auto
*
value
=
block
->
InitGet
(
id
);
auto
*
value
=
block
->
InitGet
(
id
);
// std::copy_n(value + param_offset_, param_dim_,
// std::copy_n(value + param_offset_, param_dim_,
// pull_values + param_dim_ * offset);
// pull_values + param_dim_ * offset);
pull_values
[
offset
]
=
(
char
*
)
value
;
pull_values
[
offset
]
=
reinterpret_cast
<
char
*>
(
value
)
;
}
}
return
0
;
return
0
;
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
890f626b
...
@@ -580,6 +580,49 @@ class Fleet(object):
...
@@ -580,6 +580,49 @@ class Fleet(object):
"""
"""
self
.
_runtime_handle
.
_stop_worker
()
self
.
_runtime_handle
.
_stop_worker
()
def
save
(
self
,
dirname
,
feed
=
[],
fetch
=
[],
**
configs
):
inference
=
True
if
not
feed
and
not
fetch
:
inference
=
False
place
=
paddle
.
CPUPlace
()
executor
=
paddle
.
static
.
Executor
(
place
)
if
inference
:
feeded_var_names
=
[]
fetch_var_names
=
[]
for
var
in
feed
:
if
isinstance
(
var
,
str
):
feeded_var_names
.
append
(
var
)
elif
isinstance
(
var
,
paddle
.
static
.
Variable
):
feeded_var_names
.
append
(
var
.
name
)
else
:
raise
ValueError
(
"feed must be [str|Variable]"
)
for
var
in
fetch
:
if
isinstance
(
var
,
str
):
fetch_var_names
.
append
(
var
)
elif
isinstance
(
var
,
paddle
.
static
.
Variable
):
fetch_var_names
.
append
(
var
.
name
)
else
:
raise
ValueError
(
"feed must be [str|Variable]"
)
fetch_vars
=
[
paddle
.
static
.
default_main_program
().
global_block
().
var
(
name
)
for
name
in
fetch_var_names
]
self
.
_runtime_handle
.
_save_inference_model
(
executor
,
dirname
,
feeded_var_names
,
fetch_vars
,
None
,
True
,
0
)
else
:
increment_mode
=
0
if
"mode"
in
configs
:
increment_mode
=
int
(
configs
[
"mode"
])
self
.
_runtime_handle
.
_save_persistables
(
executor
,
dirname
,
main_program
=
None
,
mode
=
increment_mode
)
def
save_inference_model
(
self
,
def
save_inference_model
(
self
,
executor
,
executor
,
dirname
,
dirname
,
...
@@ -607,6 +650,9 @@ class Fleet(object):
...
@@ -607,6 +650,9 @@ class Fleet(object):
fleet.init_server()
fleet.init_server()
"""
"""
# warnings.warn(
# "'save_inference_model' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
# )
self
.
_runtime_handle
.
_save_inference_model
(
self
.
_runtime_handle
.
_save_inference_model
(
executor
,
dirname
,
feeded_var_names
,
target_vars
,
main_program
,
executor
,
dirname
,
feeded_var_names
,
target_vars
,
main_program
,
...
@@ -653,6 +699,9 @@ class Fleet(object):
...
@@ -653,6 +699,9 @@ class Fleet(object):
fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
"""
"""
# warnings.warn(
# "'save_persistables' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
# )
self
.
_runtime_handle
.
_save_persistables
(
executor
,
dirname
,
main_program
,
self
.
_runtime_handle
.
_save_persistables
(
executor
,
dirname
,
main_program
,
mode
)
mode
)
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
890f626b
...
@@ -32,7 +32,7 @@ def conv_indent(indent):
...
@@ -32,7 +32,7 @@ def conv_indent(indent):
return
""
.
join
([
" "
]
*
indent
)
return
""
.
join
([
" "
]
*
indent
)
PSERVER_SAVE_SUFFIX
=
"
_txt
"
PSERVER_SAVE_SUFFIX
=
"
.shard
"
class
Accessor
:
class
Accessor
:
...
@@ -916,7 +916,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -916,7 +916,7 @@ class TheOnePSRuntime(RuntimeBase):
self
.
compiled_strategy
.
origin_main_program
,
True
)
self
.
compiled_strategy
.
origin_main_program
,
True
)
values
=
[]
values
=
[]
for
id
,
names
in
context
.
items
():
for
id
,
names
in
context
.
items
():
if
names
not
in
distributed_varnames
:
if
names
[
0
]
not
in
distributed_varnames
:
# only save sparse param to local
# only save sparse param to local
self
.
_worker
.
recv_and_save_model
(
id
,
dirname
)
self
.
_worker
.
recv_and_save_model
(
id
,
dirname
)
# save sparse & distributed param on server
# save sparse & distributed param on server
...
@@ -953,11 +953,11 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -953,11 +953,11 @@ class TheOnePSRuntime(RuntimeBase):
TheOnePSRuntime
.
__exclude_vars
(
saved_varnames
),
TheOnePSRuntime
.
__exclude_vars
(
saved_varnames
),
main_program
.
list_vars
()))
main_program
.
list_vars
()))
fluid
.
io
.
save_vars
(
import
paddle
executor
,
for
var
in
remaining_vars
:
main_program
=
main_program
,
tensor
=
var
.
get_value
()
dirname
=
dirname
,
paddle
.
save
(
vars
=
remaining_vars
)
tensor
,
os
.
path
.
join
(
dirname
,
var
.
name
),
use_binary_format
=
True
)
def
_ps_inference_save_persistables
(
self
,
def
_ps_inference_save_persistables
(
self
,
executor
,
executor
,
...
@@ -978,20 +978,19 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -978,20 +978,19 @@ class TheOnePSRuntime(RuntimeBase):
if
isinstance
(
executor
,
ParallelExecutor
):
if
isinstance
(
executor
,
ParallelExecutor
):
raise
TypeError
(
raise
TypeError
(
"in fleet.save
_persistables
() function, executor must be as Executor type, ParallelExecutor is not allowed"
"in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
)
)
if
not
isinstance
(
executor
,
Executor
):
if
not
isinstance
(
executor
,
Executor
):
raise
TypeError
(
raise
TypeError
(
"in fleet.save_persistables() function, executor must be as Executor type"
"in fleet.save() function, executor must be as Executor type"
)
)
if
main_program
is
None
:
if
main_program
is
None
:
main_program
=
self
.
compiled_strategy
.
get_origin_ps_main_program
()
main_program
=
self
.
compiled_strategy
.
get_origin_ps_main_program
()
if
isinstance
(
main_program
,
CompiledProgram
):
if
isinstance
(
main_program
,
CompiledProgram
):
raise
TypeError
(
raise
TypeError
(
"in fleet.save
_persistables
() function, main_program must be as Program type, CompiledProgram is not allowed"
"in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
)
)
# Todo(MrChengmo): Save optimizer status
# Todo(MrChengmo): Save optimizer status
...
@@ -1013,37 +1012,36 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1013,37 +1012,36 @@ class TheOnePSRuntime(RuntimeBase):
if
isinstance
(
executor
,
ParallelExecutor
):
if
isinstance
(
executor
,
ParallelExecutor
):
raise
TypeError
(
raise
TypeError
(
"in fleet.save
_inference_model
() function, executor must be as Executor type, ParallelExecutor is not allowed"
"in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
)
)
if
not
isinstance
(
executor
,
Executor
):
if
not
isinstance
(
executor
,
Executor
):
raise
TypeError
(
raise
TypeError
(
"in fleet.save_inference_model() function, executor must be as Executor type"
"in fleet.save() function, executor must be as Executor type"
)
import
paddle
program
=
self
.
origin_main_program
if
main_program
is
None
else
main_program
if
isinstance
(
program
,
CompiledProgram
):
raise
TypeError
(
"in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
)
)
if
main_program
is
not
None
:
feed_vars
=
[
if
isinstance
(
main_program
,
CompiledProgram
):
program
.
global_block
().
var
(
name
)
for
name
in
feeded_var_names
raise
TypeError
(
]
"in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
)
infer_program
=
paddle
.
static
.
normalize_program
(
program
,
feed_vars
,
fluid
.
io
.
save_inference_model
(
dirname
,
feeded_var_names
,
target_vars
)
target_vars
,
executor
,
main_program
,
None
,
None
,
export_for_deployment
)
infer_program
.
_copy_dist_param_info_from
(
program
)
else
:
fluid
.
io
.
save_inference_model
(
dirname
,
feeded_var_names
,
model_basename
=
"__model__"
target_vars
,
executor
,
model_basename
=
os
.
path
.
join
(
dirname
,
model_basename
)
self
.
origin_main_program
,
None
,
None
,
paddle
.
save
(
infer_program
,
model_basename
)
export_for_deployment
,
True
)
model_basename
=
"__model__"
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
infer_program
,
model_filename
=
os
.
path
.
join
(
dirname
,
model_basename
)
mode
)
with
open
(
model_filename
,
"rb"
)
as
f
:
program_desc_str
=
f
.
read
()
program
=
Program
.
parse_from_string
(
program_desc_str
)
program
.
_copy_dist_param_info_from
(
fluid
.
default_main_program
())
self
.
_ps_inference_save_persistables
(
executor
,
dirname
,
program
,
mode
)
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
def
_save_inference_model
(
self
,
*
args
,
**
kwargs
):
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
self
.
_ps_inference_save_inference_model
(
*
args
,
**
kwargs
)
...
...
python/paddle/fluid/tests/unittests/test_fleet_base_2.py
浏览文件 @
890f626b
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
unittest
import
unittest
import
paddle
import
paddle
paddle
.
enable_static
()
import
os
import
os
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
@@ -21,18 +23,16 @@ import paddle.fluid as fluid
...
@@ -21,18 +23,16 @@ import paddle.fluid as fluid
class
TestFleetBase
(
unittest
.
TestCase
):
class
TestFleetBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
os
.
environ
[
"POD_IP"
]
=
"127.0.0.1"
os
.
environ
[
"POD_IP"
]
=
"127.0.0.1"
os
.
environ
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36001"
os
.
environ
[
"PADDLE_TRAINERS_NUM"
]
=
"2"
os
.
environ
[
"PADDLE_TRAINERS_NUM"
]
=
"2"
os
.
environ
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
\
os
.
environ
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
\
"127.0.0.1:36001,127.0.0.2:36001"
"127.0.0.1:36001,127.0.0.2:36001"
def
test_ps_minimize
(
self
):
def
test_ps_minimize
(
self
):
import
paddle
import
paddle
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet
as
fleet
os
.
environ
[
"TRAINING_ROLE"
]
=
"PSERVER"
os
.
environ
[
"TRAINING_ROLE"
]
=
"TRAINER"
os
.
environ
[
"POD_IP"
]
=
"127.0.0.1"
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"1"
os
.
environ
[
"PADDLE_PORT"
]
=
"36001"
input_x
=
paddle
.
fluid
.
layers
.
data
(
input_x
=
paddle
.
fluid
.
layers
.
data
(
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
...
@@ -47,24 +47,26 @@ class TestFleetBase(unittest.TestCase):
...
@@ -47,24 +47,26 @@ class TestFleetBase(unittest.TestCase):
role
=
fleet
.
PaddleCloudRoleMaker
(
is_collective
=
False
)
role
=
fleet
.
PaddleCloudRoleMaker
(
is_collective
=
False
)
fleet
.
init
(
role
)
fleet
.
init
(
role
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
strategy
.
a_sync
=
False
strategy
.
a_sync
=
False
strategy
.
a_sync_configs
=
{
"launch_barrier"
:
False
}
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
optimizer
.
minimize
(
avg_cost
)
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
False
,
loss_name
=
avg_cost
.
name
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
False
,
loss_name
=
avg_cost
.
name
)
compiled_prog
=
fluid
.
compiler
.
CompiledProgram
(
compiled_prog
=
fluid
.
compiler
.
CompiledProgram
(
fluid
.
default_main_program
())
fluid
.
default_main_program
())
self
.
assertRaises
(
Exception
,
fleet
.
fleet
.
save
(
dirname
=
"/tmp"
,
feed
=
[
'x'
,
'y'
],
fetch
=
[
avg_cost
])
fleet
.
save_inference_model
,
fleet
.
fleet
.
save
(
dirname
=
'/tmp/'
,
dirname
=
"/tmp"
,
feed
=
[
input_x
,
input_y
],
fetch
=
[
avg_cost
])
feeded_var_names
=
[
'x'
,
'y'
],
fleet
.
fleet
.
save
(
dirname
=
"/tmp"
)
target_vars
=
[
avg_cost
],
executor
=
pe
)
self
.
assertRaises
(
self
.
assertRaises
(
Exception
,
Exception
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录