Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cbc6e6eb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
cbc6e6eb
编写于
8月 21, 2018
作者:
T
tangwei12
提交者:
GitHub
8月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12247 from seiriosPlus/dis_ckpt_fix
add load slice_vars in io.py
上级
72965226
08152916
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
229 addition
and
15 deletion
+229
-15
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+3
-2
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+1
-0
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+2
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+7
-0
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+123
-2
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+56
-10
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+36
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
cbc6e6eb
...
...
@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', '
paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'
], varargs=None, keywords=None, defaults=(
None, None))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'
, 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None,
None, None))
paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
cbc6e6eb
...
...
@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
auto
*
lt_var
=
scope
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
// TODO(tangwei12): find out why scope will be error.
auto
*
lt_var
=
scope_
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update var kLookupTablePath to: "
<<
out_var_name
;
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
_
);
return
true
;
}
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
cbc6e6eb
...
...
@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
DeserializeFromStream
(
fin
,
selectedRows
,
dev_ctx
);
selectedRows
->
SyncIndex
();
}
};
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
cbc6e6eb
...
...
@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
std
::
string
filename
=
lt_var
->
data
();
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
MkDirRecursively
(
DirName
(
filename
).
c_str
());
auto
&
selectedRows
=
var
->
Get
<
framework
::
SelectedRows
>
();
// get device context from pool
...
...
python/paddle/fluid/framework.py
浏览文件 @
cbc6e6eb
...
...
@@ -1363,6 +1363,13 @@ class Program(object):
self
.
_current_role
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
self
.
_op_role_var
=
[]
# for distribute
self
.
_is_distributed
=
False
self
.
_is_chief
=
False
self
.
_slice_vars_and_attrs
=
[]
self
.
_endpoints
=
[]
self
.
_distributed_lookup_table
=
None
@
property
def
op_role
(
self
):
"""
...
...
python/paddle/fluid/io.py
浏览文件 @
cbc6e6eb
...
...
@@ -372,6 +372,7 @@ def load_vars(executor,
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
,
vars
=
list
(
filter
(
predicate
,
main_program
.
list_vars
())),
filename
=
filename
)
else
:
...
...
@@ -403,9 +404,12 @@ def load_vars(executor,
inputs
=
{},
outputs
=
{
"Out"
:
load_var_list
},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
filename
)})
executor
.
run
(
load_prog
)
# load slice vars on pserver, if have it.
_load_slice_up_vars
(
executor
,
dirname
,
main_program
.
_slice_vars_and_attrs
)
def
load_params
(
executor
,
dirname
,
main_program
=
None
,
filename
=
None
):
"""
...
...
@@ -659,11 +663,19 @@ def save_inference_model(dirname,
save_persistables
(
executor
,
dirname
,
inference_program
,
params_filename
)
# if there is lookup table, the trainer 0 will notify all pserver to save.
if
main_program
.
_is_distributed
and
main_program
.
_is_chief
and
main_program
.
_distributed_lookup_table
:
lookup_table_filename
=
os
.
path
.
join
(
dirname
,
"__lookup_table__"
)
_save_lookup_tables_by_notify
(
executor
,
lookup_table_filename
,
main_program
.
_distributed_lookup_table
,
main_program
.
_endpoints
)
def
load_inference_model
(
dirname
,
executor
,
model_filename
=
None
,
params_filename
=
None
):
params_filename
=
None
,
pserver_endpoints
=
None
):
"""
Load inference model from a directory
...
...
@@ -679,6 +691,10 @@ def load_inference_model(dirname,
parameters were saved in a single binary
file. If parameters were saved in separate
files, set it as 'None'.
pserver_endpoints(list|None): This only need by distributed inference.
When use distributed look up table in training,
We also need it in inference.The parameter is
a list of pserver endpoints.
Returns:
tuple: The return of this function is a tuple with three elements:
...
...
@@ -697,12 +713,16 @@ def load_inference_model(dirname,
exe = fluid.Executor(fluid.CPUPlace())
path = "./infer_model"
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
[inference_program, feed_target_names, fetch_targets] =
fluid.io.load_inference_model(dirname=path, executor=exe)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# if we need lookup table, we will use:
fluid.io.load_inference_model(dirname=path, executor=exe, pserver_endpoints=endpoints)
# In this exsample, the inference program was saved in the
# "./infer_model/__model__" and parameters were saved in
# separate files in ""./infer_model".
...
...
@@ -729,6 +749,9 @@ def load_inference_model(dirname,
program
=
Program
.
parse_from_string
(
program_desc_str
)
load_persistables
(
executor
,
dirname
,
program
,
params_filename
)
if
pserver_endpoints
:
program
=
_endpoints_replacement
(
program
,
pserver_endpoints
)
feed_target_names
=
program
.
desc
.
get_feed_target_names
()
fetch_target_names
=
program
.
desc
.
get_fetch_target_names
()
fetch_targets
=
[
...
...
@@ -738,6 +761,61 @@ def load_inference_model(dirname,
return
[
program
,
feed_target_names
,
fetch_targets
]
def
_save_lookup_tables_by_notify
(
executor
,
dirname
,
lookup_table
,
pserver_endpoints
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program
=
Program
()
pserver_notify_block
=
pserver_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
pserver_endpoints
attrs
[
'dir'
]
=
dirname
attrs
[
'lookup_table'
]
=
lookup_table
pserver_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
pserver_notify_program
)
def
_endpoints_replacement
(
program
,
endpoints
):
ENDPOINT_MAP
=
"epmap"
for
op
in
program
.
global_block
().
ops
:
if
op
.
has_attr
(
ENDPOINT_MAP
):
op
.
set_attr
(
ENDPOINT_MAP
,
endpoints
)
program
.
_sync_with_cpp
()
return
program
def
get_parameter_value
(
para
,
executor
):
"""
Get the LoDTensor value of the given parameter.
...
...
@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None):
program
=
default_main_program
()
var
=
program
.
global_block
().
var
(
name
)
return
get_parameter_value
(
var
,
executor
)
def
_load_slice_up_vars
(
executor
,
dirname
,
slice_vars_and_attrs
):
if
not
slice_vars_and_attrs
:
return
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
for
var_tuple
in
slice_vars_and_attrs
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
type
=
orig_var
.
type
,
shape
=
orig_var
.
shape
,
dtype
=
orig_var
.
dtype
,
persistable
=
True
)
clone_slice_var
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
clone_orig_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
clone_orig_var
.
name
)})
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
clone_orig_var
},
outputs
=
{
'Out'
:
clone_slice_var
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
executor
.
run
(
load_prog
)
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
cbc6e6eb
...
...
@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase):
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
sgd_optimizer
.
minimize
(
avg_cost
)
return
def
get_main_program
(
self
):
main
=
fluid
.
Program
()
...
...
@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase):
def
test_transpiler
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
self
.
transpiler_test_impl
()
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
self
.
transpiler_test_impl
()
class
TestBasicModel
(
TranspilerTest
):
...
...
@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest):
decay_rate
=
0.1
,
staircase
=
True
))
sgd_optimizer
.
minimize
(
avg_cost
)
return
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
)
...
...
@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest):
learning_rate
=
fluid
.
layers
.
piecewise_decay
([
10000
,
20000
],
[
1.0
,
0.5
,
1.0
]))
sgd_optimizer
.
minimize
(
avg_cost
)
return
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
)
...
...
@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest):
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
sgd_optimizer
.
minimize
(
avg_cost
)
return
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
)
...
...
@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest):
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
sgd_optimizer
.
minimize
(
avg_cost
)
return
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
)
...
...
@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest):
def
network_with_table
(
self
,
is_sparse
,
is_distributed
):
self
.
table_size
=
1000
self
.
emb_size
=
64
self
.
lookup_table_name
=
'shared_w'
def
emb_pool
(
ids
):
emb
=
fluid
.
layers
.
embedding
(
input
=
ids
,
size
=
[
self
.
table_size
,
self
.
emb_size
],
dtype
=
'float32'
,
param_attr
=
'shared_w'
,
# share parameter
param_attr
=
self
.
lookup_table_name
,
# share parameter
is_sparse
=
is_sparse
,
is_distributed
=
is_distributed
)
pool
=
fluid
.
layers
.
sequence_pool
(
input
=
emb
,
pool_type
=
'average'
)
...
...
@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
)
pserver1
,
_
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
)
self
.
assertTrue
(
self
.
transpiler
.
has_distributed_lookup_table
)
lookup_table_var
=
pserver1
.
global_block
().
vars
[
...
...
@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
self
.
assertEqual
(
row_size
,
calc_row_size
)
class
TestDistArgsInProgram
(
TestDistLookupTableBase
):
def
net_conf
(
self
):
self
.
network_with_table
(
is_sparse
=
True
,
is_distributed
=
True
)
def
transpiler_test_impl
(
self
):
trainer
,
_
=
self
.
get_trainer
()
self
.
assertTrue
(
trainer
.
_is_distributed
)
self
.
assertTrue
(
trainer
.
_is_chief
)
self
.
assertEqual
(
trainer
.
_distributed_lookup_table
,
self
.
lookup_table_name
)
self
.
assertEqual
(
trainer
.
_endpoints
,
[
self
.
pserver1_ep
,
self
.
pserver2_ep
])
class
TestRMSPropOptimizer
(
TranspilerTest
):
def
net_conf
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1000
],
dtype
=
'float32'
)
...
...
@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest):
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
optimizer
=
fluid
.
optimizer
.
RMSProp
(
learning_rate
=
0.1
)
optimizer
.
minimize
(
avg_cost
)
return
def
transpiler_test_impl
(
self
):
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
)
...
...
@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest):
self
.
assertEqual
(
moment_var
.
shape
,
(
500
,
1000
))
class
TestLoadSliceVar
(
TranspilerTest
):
def
net_conf
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1000
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1000
,
act
=
None
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'fc_w'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
'fc_b'
))
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
optimizer
=
fluid
.
optimizer
.
RMSProp
(
learning_rate
=
0.1
)
optimizer
.
minimize
(
avg_cost
)
def
transpiler_test_impl
(
self
):
pserver
,
_
=
self
.
get_pserver
(
self
.
pserver1_ep
)
pserver2
,
_
=
self
.
get_pserver
(
self
.
pserver2_ep
)
self
.
assertTrue
(
pserver
.
_slice_vars_and_attrs
)
self
.
assertTrue
(
pserver2
.
_slice_vars_and_attrs
)
for
idx
in
xrange
(
len
(
pserver
.
_slice_vars_and_attrs
)):
self
.
assertEqual
(
pserver
.
_slice_vars_and_attrs
[
idx
][
0
],
pserver2
.
_slice_vars_and_attrs
[
idx
][
0
])
total_numel
=
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver
.
_slice_vars_and_attrs
[
idx
][
0
].
shape
)
self
.
assertEqual
(
total_numel
,
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver
.
_slice_vars_and_attrs
[
idx
][
2
].
shape
)
+
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver2
.
_slice_vars_and_attrs
[
idx
][
2
].
shape
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
cbc6e6eb
...
...
@@ -215,6 +215,13 @@ class DistributeTranspiler(object):
for
param_var
,
grad_var
in
self
.
params_grads
:
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
# add distributed attrs to program
self
.
origin_program
.
_is_distributed
=
True
self
.
origin_program
.
_endpoints
=
self
.
pserver_endpoints
self
.
origin_program
.
_is_chief
=
self
.
trainer_id
==
0
self
.
origin_program
.
_distributed_lookup_table
=
self
.
table_name
if
self
.
table_name
else
None
# split and create vars, then put splited vars in dicts for later use.
# step 1: split and create vars, then put splited vars in dicts for later use.
self
.
_init_splited_vars
()
...
...
@@ -590,6 +597,8 @@ class DistributeTranspiler(object):
checkpoint_block_id
=
self
.
_create_checkpoint_save_block
(
pserver_program
,
table_opt_block
.
idx
)
pserver_program
.
_distributed_lookup_table
=
self
.
table_name
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if
self
.
has_distributed_lookup_table
:
...
...
@@ -616,6 +625,10 @@ class DistributeTranspiler(object):
outputs
=
{},
attrs
=
attrs
)
# add distributed attrs
pserver_program
.
_slice_vars_and_attrs
=
self
.
_get_slice_vars_and_attrs
(
endpoint
)
pserver_program
.
_sync_with_cpp
()
return
pserver_program
...
...
@@ -689,8 +702,31 @@ class DistributeTranspiler(object):
inputs
=
new_inputs
,
outputs
=
new_outputs
,
attrs
=
op
.
all_attrs
())
# add slice vars
s_prog
.
_slice_vars_and_attrs
=
self
.
_get_slice_vars_and_attrs
(
endpoint
)
return
s_prog
def
_get_slice_vars_and_attrs
(
self
,
endpoint
):
slice_vars_and_attrs
=
[]
block_suffix
=
"block"
for
param
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
orig_var_name
,
block_name
,
_
=
self
.
_get_varname_parts
(
param
.
name
)
if
not
block_name
:
continue
block_idx
=
int
(
block_name
.
split
(
block_suffix
)[
1
])
orig_var
=
self
.
origin_program
.
global_block
().
vars
[
orig_var_name
]
skip_numel
=
0
slice_vars
=
self
.
param_var_mapping
[
orig_var_name
]
for
slice_var
in
slice_vars
[:
block_idx
]:
skip_numel
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
slice_vars_and_attrs
.
append
([
orig_var
,
skip_numel
,
param
])
return
slice_vars_and_attrs
# ====================== private transpiler functions =====================
def
_has_distributed_lookup_table
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录