Commit cbc6e6eb (unverified)
Authored by tangwei12 on Aug 21, 2018; committed via GitHub on Aug 21, 2018

    Merge pull request #12247 from seiriosPlus/dis_ckpt_fix

    add load slice_vars in io.py

Parents: 72965226, 08152916
Showing 8 changed files with 229 additions and 15 deletions (+229, −15).
Changed files:

  paddle/fluid/API.spec                                          (+1, −1)
  paddle/fluid/operators/distributed/request_handler_impl.cc    (+3, −2)
  paddle/fluid/operators/load_op.cc                              (+1, −0)
  paddle/fluid/operators/save_op.cc                              (+2, −0)
  python/paddle/fluid/framework.py                               (+7, −0)
  python/paddle/fluid/io.py                                      (+123, −2)
  python/paddle/fluid/tests/unittests/test_dist_transpiler.py   (+56, −10)
  python/paddle/fluid/transpiler/distribute_transpiler.py       (+36, −0)
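At a glance, what this merge enables, as a hedged sketch (model construction is elided, and the transpile arguments, endpoints, and save directory are illustrative placeholders; see the diffs below for the actual changes):

```python
import paddle.fluid as fluid

# ... define a model with an optimizer here, then transpile it for
# distributed training (arguments below are placeholders) ...
t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6000,127.0.0.1:6001",
            trainers=2)
pserver_prog = t.get_pserver_program("127.0.0.1:6000")

exe = fluid.Executor(fluid.CPUPlace())
exe.run(t.get_startup_program("127.0.0.1:6000", pserver_prog))

# New in this merge: loading persistables on a pserver also restores the
# parameter slices recorded in pserver_prog._slice_vars_and_attrs.
fluid.io.load_persistables(exe, "./model_dir", main_program=pserver_prog)
```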
paddle/fluid/API.spec (+1, −1)

```diff
@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', '
 paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
-paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None))
+paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
 paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
```
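In Python terms, the widened ArgSpec above corresponds to calls like the following (paths and endpoints are placeholders):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# Single-machine inference: pserver_endpoints keeps its default of None.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./infer_model", executor=exe)

# Distributed-lookup-table inference: pass the live pserver endpoints so
# ops carrying an 'epmap' attribute are re-pointed at them (see
# _endpoints_replacement in the io.py diff below).
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./infer_model", executor=exe,
    pserver_endpoints=["127.0.0.1:6000", "127.0.0.1:6001"])
```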
paddle/fluid/operators/distributed/request_handler_impl.cc (+3, −2)

```diff
@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
       checkpoint_notify_id != -1,
       "when checkpoint_notify_id = -1, there should be no RPC invoke.");
 
-  auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
+  // TODO(tangwei12): find out why scope will be error.
+  auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
   lt_var->clear();
   lt_var->append(out_var_name);
   VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
           << out_var_name;
-  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
+  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
   return true;
 }
```
paddle/fluid/operators/load_op.cc (+1, −0)

```diff
@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& dev_ctx = *pool.Get(place);
     framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
+    selectedRows->SyncIndex();
   }
 };
```
paddle/fluid/operators/save_op.cc (+2, −0)

```diff
@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
     std::string filename = lt_var->data();
+    VLOG(4) << "SaveSelectedRows get File name: " << filename;
+    MkDirRecursively(DirName(filename).c_str());
 
     auto& selectedRows = var->Get<framework::SelectedRows>();
     // get device context from pool
```
python/paddle/fluid/framework.py (+7, −0)

```diff
@@ -1363,6 +1363,13 @@ class Program(object):
         self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
         self._op_role_var = []
 
+        # for distribute
+        self._is_distributed = False
+        self._is_chief = False
+        self._slice_vars_and_attrs = []
+        self._endpoints = []
+        self._distributed_lookup_table = None
+
     @property
     def op_role(self):
         """
```
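The defaults added above can be checked directly; a minimal sketch of the new Program contract (the distributed values themselves are only filled in by DistributeTranspiler, per the diff further below):

```python
import paddle.fluid as fluid

prog = fluid.Program()

# A freshly constructed program is non-distributed by default.
assert prog._is_distributed is False
assert prog._is_chief is False
assert prog._slice_vars_and_attrs == []
assert prog._endpoints == []
assert prog._distributed_lookup_table is None
```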
python/paddle/fluid/io.py (+123, −2)

```diff
@@ -372,6 +372,7 @@ def load_vars(executor,
         load_vars(
             executor,
             dirname=dirname,
+            main_program=main_program,
             vars=list(filter(predicate, main_program.list_vars())),
             filename=filename)
     else:
@@ -403,9 +404,12 @@ def load_vars(executor,
             inputs={},
             outputs={"Out": load_var_list},
             attrs={'file_path': os.path.join(dirname, filename)})
 
         executor.run(load_prog)
 
+        # load slice vars on pserver, if there are any.
+        _load_slice_up_vars(executor, dirname,
+                            main_program._slice_vars_and_attrs)
+
 
 def load_params(executor, dirname, main_program=None, filename=None):
     """
```
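Each `_slice_vars_and_attrs` entry consumed here is a `[orig_var, start, slice_var]` triple (built by `_get_slice_vars_and_attrs`, shown further below). A worked sketch of the offset arithmetic, with made-up shapes:

```python
from functools import reduce  # the diff itself relies on Python 2's builtin reduce

# Made-up example: a 1000 x 1000 parameter split evenly across two
# pservers; this server holds the second half.
slice_shape = (500, 1000)
start = 500 * 1000                 # elements in the preceding slice (skip_numel)
end = start + reduce(lambda x, y: x * y, slice_shape)

# _load_slice_up_vars (below) then emits a 'load' op to materialize the
# full saved tensor and a 'slice' op cutting [start:end] on axis 0.
print(start, end)  # 500000 1000000
```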
```diff
@@ -659,11 +663,19 @@ def save_inference_model(dirname,
     save_persistables(executor, dirname, inference_program, params_filename)
 
+    # if there is a lookup table, trainer 0 will notify all pservers to save it.
+    if main_program._is_distributed and main_program._is_chief and \
+            main_program._distributed_lookup_table:
+        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+        _save_lookup_tables_by_notify(executor, lookup_table_filename,
+                                      main_program._distributed_lookup_table,
+                                      main_program._endpoints)
+
 
 def load_inference_model(dirname, executor, model_filename=None,
-                         params_filename=None):
+                         params_filename=None,
+                         pserver_endpoints=None):
     """
     Load inference model from a directory
@@ -679,6 +691,10 @@ def load_inference_model(dirname,
                                parameters were saved in a single binary
                                file. If parameters were saved in separate
                                files, set it as 'None'.
+        pserver_endpoints(list|None): Only needed for distributed inference.
+                                      If a distributed lookup table was used
+                                      in training, it is also needed at
+                                      inference time; a list of pserver endpoints.
 
     Returns:
         tuple: The return of this function is a tuple with three elements:
@@ -697,12 +713,16 @@ def load_inference_model(dirname,
             exe = fluid.Executor(fluid.CPUPlace())
             path = "./infer_model"
+            endpoints = ["127.0.0.1:2023", "127.0.0.1:2024"]
             [inference_program, feed_target_names, fetch_targets] =
                 fluid.io.load_inference_model(dirname=path, executor=exe)
             results = exe.run(inference_program,
                               feed={feed_target_names[0]: tensor_img},
                               fetch_list=fetch_targets)
 
+            # if we need the lookup table, we use:
+            fluid.io.load_inference_model(dirname=path, executor=exe,
+                                          pserver_endpoints=endpoints)
+
             # In this example, the inference program was saved in
             # "./infer_model/__model__" and the parameters were saved in
             # separate files under "./infer_model".
@@ -729,6 +749,9 @@ def load_inference_model(dirname,
     program = Program.parse_from_string(program_desc_str)
     load_persistables(executor, dirname, program, params_filename)
 
+    if pserver_endpoints:
+        program = _endpoints_replacement(program, pserver_endpoints)
+
     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
     fetch_targets = [
@@ -738,6 +761,61 @@ def load_inference_model(dirname,
     return [program, feed_target_names, fetch_targets]
 
 
+def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
+                                  pserver_endpoints):
+    """
+    This function sends a checkpoint-notify message from trainer 0 to all
+    the pservers. The message contains the lookup table name and the
+    absolute path on each pserver where the lookup table should be saved.
+
+    Args:
+        executor(Executor): The executor that runs the checkpoint-notify op.
+        dirname(str): The folder where to save.
+        lookup_table(string): The lookup table name. When using a
+            distributed lookup table, it can be obtained from
+            DistributeTranspiler.table_name.
+        pserver_endpoints(list): The parameter server ip:port list, as
+            given in the distribute arguments.
+
+    Return:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            table_name = "share_w"
+            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
+
+            _save_lookup_tables_by_notify(executor=exe,
+                    dirname=param_path, lookup_table=table_name,
+                    pserver_endpoints=ps_endpoints)
+    """
+    pserver_notify_program = Program()
+    pserver_notify_block = pserver_notify_program.global_block()
+
+    attrs = {}
+    attrs['epmap'] = pserver_endpoints
+    attrs['dir'] = dirname
+    attrs['lookup_table'] = lookup_table
+
+    pserver_notify_block.append_op(
+        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+    executor.run(pserver_notify_program)
+
+
+def _endpoints_replacement(program, endpoints):
+    ENDPOINT_MAP = "epmap"
+    for op in program.global_block().ops:
+        if op.has_attr(ENDPOINT_MAP):
+            op.set_attr(ENDPOINT_MAP, endpoints)
+    program._sync_with_cpp()
+    return program
+
+
 def get_parameter_value(para, executor):
     """
     Get the LoDTensor value of the given parameter.
```
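`_endpoints_replacement` simply rewrites the `epmap` attribute of every op that carries one. A minimal stand-alone sketch of the same idea, with plain dicts standing in for program ops (the function name here is a hypothetical stand-in, not the fluid API):

```python
def endpoints_replacement(ops, endpoints):
    # Any op that carries an "epmap" attribute (e.g. prefetch ops emitted
    # for a distributed lookup table) is re-pointed at the live servers.
    for op in ops:
        if "epmap" in op:
            op["epmap"] = endpoints
    return ops

ops = [{"type": "prefetch", "epmap": ["stale:0"]}, {"type": "mul"}]
print(endpoints_replacement(ops, ["127.0.0.1:6000", "127.0.0.1:6001"]))
# [{'type': 'prefetch', 'epmap': ['127.0.0.1:6000', '127.0.0.1:6001']}, {'type': 'mul'}]
```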
```diff
@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None):
         program = default_main_program()
     var = program.global_block().var(name)
     return get_parameter_value(var, executor)
+
+
+def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
+    if not slice_vars_and_attrs:
+        return
+
+    load_prog = Program()
+    load_block = load_prog.global_block()
+
+    for var_tuple in slice_vars_and_attrs:
+        orig_var = var_tuple[0]
+        start = var_tuple[1]
+        slice_var = var_tuple[2]
+        end = start + reduce(lambda x, y: x * y, slice_var.shape)
+
+        clone_orig_var = load_block.create_var(
+            name=orig_var.name,
+            type=orig_var.type,
+            shape=orig_var.shape,
+            dtype=orig_var.dtype,
+            persistable=True)
+
+        clone_slice_var = load_block.create_var(
+            name=slice_var.name,
+            type=slice_var.type,
+            shape=slice_var.shape,
+            dtype=slice_var.dtype,
+            persistable=True)
+
+        load_block.append_op(
+            type='load',
+            inputs={},
+            outputs={'Out': [clone_orig_var]},
+            attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
+        load_block.append_op(
+            type="slice",
+            inputs={'Input': clone_orig_var},
+            outputs={'Out': clone_slice_var},
+            attrs={'axes': [0], 'starts': [start], 'ends': [end]})
+
+    executor.run(load_prog)
```
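On the save side, a self-contained toy sketch of where the new notify branch sits (the model, feed names, and directory are invented for illustration):

```python
import paddle.fluid as fluid

# Toy model so the snippet runs stand-alone; a real distributed job would
# transpile this program first, which sets the _is_* attributes.
ids = fluid.layers.data(name="ids", shape=[1], dtype="float32")
prediction = fluid.layers.fc(input=ids, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# For a plain program the new branch is a no-op; on the chief trainer of a
# distributed job with a lookup table, this call additionally runs a
# checkpoint_notify op asking every pserver to write the table to
# ./dist_model/__lookup_table__.
fluid.io.save_inference_model(
    dirname="./dist_model",
    feeded_var_names=["ids"],
    target_vars=[prediction],
    executor=exe,
    main_program=fluid.default_main_program())
```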
python/paddle/fluid/tests/unittests/test_dist_transpiler.py (+56, −10)

```diff
@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase):
         avg_cost = fluid.layers.mean(cost)
         sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
         sgd_optimizer.minimize(avg_cost)
-        return
 
     def get_main_program(self):
         main = fluid.Program()
@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase):
     def test_transpiler(self):
         main = fluid.Program()
         startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            self.transpiler_test_impl()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.transpiler_test_impl()
 
 
 class TestBasicModel(TranspilerTest):
@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest):
                 decay_rate=0.1,
                 staircase=True))
         sgd_optimizer.minimize(avg_cost)
-        return
 
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest):
             learning_rate=fluid.layers.piecewise_decay([10000, 20000],
                                                        [1.0, 0.5, 1.0]))
         sgd_optimizer.minimize(avg_cost)
-        return
 
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest):
         avg_cost = fluid.layers.mean(cost)
         sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
         sgd_optimizer.minimize(avg_cost)
-        return
 
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest):
             momentum=0.9,
             regularization=fluid.regularizer.L2Decay(1e-4))
         sgd_optimizer.minimize(avg_cost)
-        return
 
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest):
     def network_with_table(self, is_sparse, is_distributed):
         self.table_size = 1000
         self.emb_size = 64
+        self.lookup_table_name = 'shared_w'
 
         def emb_pool(ids):
             emb = fluid.layers.embedding(
                 input=ids,
                 size=[self.table_size, self.emb_size],
                 dtype='float32',
-                param_attr='shared_w',  # share parameter
+                param_attr=self.lookup_table_name,  # share parameter
                 is_sparse=is_sparse,
                 is_distributed=is_distributed)
             pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
     def transpiler_test_impl(self):
         config = fluid.DistributeTranspilerConfig()
-        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
+        pserver1, _ = self.get_pserver(self.pserver1_ep, config)
 
         self.assertTrue(self.transpiler.has_distributed_lookup_table)
         lookup_table_var = pserver1.global_block().vars[
@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
         self.assertEqual(row_size, calc_row_size)
 
 
+class TestDistArgsInProgram(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        trainer, _ = self.get_trainer()
+
+        self.assertTrue(trainer._is_distributed)
+        self.assertTrue(trainer._is_chief)
+        self.assertEqual(trainer._distributed_lookup_table,
+                         self.lookup_table_name)
+        self.assertEqual(trainer._endpoints,
+                         [self.pserver1_ep, self.pserver2_ep])
+
+
 class TestRMSPropOptimizer(TranspilerTest):
     def net_conf(self):
         x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest):
         avg_cost = fluid.layers.mean(cost)
         optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
         optimizer.minimize(avg_cost)
-        return
 
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest):
         self.assertEqual(moment_var.shape, (500, 1000))
 
 
+class TestLoadSliceVar(TranspilerTest):
+    def net_conf(self):
+        x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+        y_predict = fluid.layers.fc(input=x,
+                                    size=1000,
+                                    act=None,
+                                    param_attr=fluid.ParamAttr(name='fc_w'),
+                                    bias_attr=fluid.ParamAttr(name='fc_b'))
+        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
+        optimizer.minimize(avg_cost)
+
+    def transpiler_test_impl(self):
+        pserver, _ = self.get_pserver(self.pserver1_ep)
+        pserver2, _ = self.get_pserver(self.pserver2_ep)
+
+        self.assertTrue(pserver._slice_vars_and_attrs)
+        self.assertTrue(pserver2._slice_vars_and_attrs)
+
+        for idx in xrange(len(pserver._slice_vars_and_attrs)):
+            self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
+                             pserver2._slice_vars_and_attrs[idx][0])
+
+            total_numel = reduce(lambda x, y: x * y,
+                                 pserver._slice_vars_and_attrs[idx][0].shape)
+            self.assertEqual(
+                total_numel,
+                reduce(lambda x, y: x * y,
+                       pserver._slice_vars_and_attrs[idx][2].shape) +
+                reduce(lambda x, y: x * y,
+                       pserver2._slice_vars_and_attrs[idx][2].shape))
+
+
 if __name__ == "__main__":
     unittest.main()
```
python/paddle/fluid/transpiler/distribute_transpiler.py (+36, −0)

```diff
@@ -215,6 +215,13 @@ class DistributeTranspiler(object):
         for param_var, grad_var in self.params_grads:
             self.param_name_to_grad_name[param_var.name] = grad_var.name
 
+        # add distributed attrs to program
+        self.origin_program._is_distributed = True
+        self.origin_program._endpoints = self.pserver_endpoints
+        self.origin_program._is_chief = self.trainer_id == 0
+        self.origin_program._distributed_lookup_table = self.table_name if \
+            self.table_name else None
+
-        # split and create vars, then put splited vars in dicts for later use.
+        # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -590,6 +597,8 @@ class DistributeTranspiler(object):
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)
 
+            pserver_program._distributed_lookup_table = self.table_name
+
         # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
         # not be executed, so it's safe to use optimize_block to hold the place
         if self.has_distributed_lookup_table:
@@ -616,6 +625,10 @@ class DistributeTranspiler(object):
             outputs={},
             attrs=attrs)
 
+        # add distributed attrs
+        pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
+            endpoint)
+
         pserver_program._sync_with_cpp()
         return pserver_program
@@ -689,8 +702,31 @@ class DistributeTranspiler(object):
                 inputs=new_inputs,
                 outputs=new_outputs,
                 attrs=op.all_attrs())
+
+        # add slice vars
+        s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
+
         return s_prog
 
+    def _get_slice_vars_and_attrs(self, endpoint):
+        slice_vars_and_attrs = []
+        block_suffix = "block"
+        for param in self.param_grad_ep_mapping[endpoint]["params"]:
+            orig_var_name, block_name, _ = self._get_varname_parts(param.name)
+            if not block_name:
+                continue
+
+            block_idx = int(block_name.split(block_suffix)[1])
+            orig_var = self.origin_program.global_block().vars[orig_var_name]
+
+            skip_numel = 0
+            slice_vars = self.param_var_mapping[orig_var_name]
+            for slice_var in slice_vars[:block_idx]:
+                skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
+            slice_vars_and_attrs.append([orig_var, skip_numel, param])
+
+        return slice_vars_and_attrs
+
     # ====================== private transpiler functions =====================
 
     def _has_distributed_lookup_table(self):
```
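A worked sketch of the `skip_numel` bookkeeping in `_get_slice_vars_and_attrs` (the parameter name, two-way split, and shapes are made up for illustration):

```python
from functools import reduce

# 'fc_w' (1000 x 1000) split into fc_w.block0 and fc_w.block1, following
# the "block" name-suffix convention parsed above.
slice_shapes = [(500, 1000), (500, 1000)]

block_idx = 1                      # this endpoint holds fc_w.block1
skip_numel = 0
for shape in slice_shapes[:block_idx]:
    skip_numel += reduce(lambda x, y: x * y, shape)

# The triple recorded for the loader is [orig_var, skip_numel, param].
print(skip_numel)  # 500000 elements precede this slice in the full tensor
```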