Unverified commit cbc6e6eb, authored by tangwei12, committed by GitHub

Merge pull request #12247 from seiriosPlus/dis_ckpt_fix

add load slice_vars in io.py
@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', '
 paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
-paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None))
+paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
 paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
...
@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
       checkpoint_notify_id != -1,
       "when checkpoint_notify_id = -1, there should be no RPC invoke.");
-  auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
+  // TODO(tangwei12): find out why using the passed-in scope causes an error.
+  auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
   lt_var->clear();
   lt_var->append(out_var_name);
   VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
           << out_var_name;
-  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
+  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
   return true;
 }
...
@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
     framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
+    selectedRows->SyncIndex();
   }
 };
...
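Note: `DeserializeFromStream` restores the rows and values of the `SelectedRows`, and the added `SyncIndex()` call presumably rebuilds the id-to-index lookup from the restored rows so that subsequent lookups can find them. A toy Python sketch of such a rebuild (an illustration under that assumption, not Paddle's implementation):

```python
# Toy illustration: after row ids are deserialized, the id -> position
# index must be rebuilt before lookups against the table work again.
rows = [7, 2, 9]  # row ids restored from the checkpoint file
id_to_index = {row_id: pos for pos, row_id in enumerate(rows)}
assert id_to_index[9] == 2  # row id 9 is stored at position 2
```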
@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
     std::string filename = lt_var->data();
     VLOG(4) << "SaveSelectedRows get File name: " << filename;
+
+    MkDirRecursively(DirName(filename).c_str());
     auto &selectedRows = var->Get<framework::SelectedRows>();
     // get device context from pool
...
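Note: the added `MkDirRecursively(DirName(filename).c_str())` ensures the target directory exists before the lookup-table file is written, so saving to a fresh path no longer fails on a missing directory. A rough Python analogue of that behavior (an assumption about the semantics, not Paddle's code):

```python
import errno
import os

def mkdir_recursively(path):
    # Create the directory and any missing parents; tolerate a directory
    # that already exists, like `mkdir -p`.
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
```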
@@ -1363,6 +1363,13 @@ class Program(object):
         self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
         self._op_role_var = []

+        # for distribute
+        self._is_distributed = False
+        self._is_chief = False
+        self._slice_vars_and_attrs = []
+        self._endpoints = []
+        self._distributed_lookup_table = None
+
     @property
     def op_role(self):
         """
...
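Note: these five private attributes form the contract between `DistributeTranspiler` (which fills them in; see the `distribute_transpiler.py` hunks below) and `fluid.io` (which reads them in `save_inference_model` and `load_vars`). A minimal standalone sketch of that handshake, with assumed values:

```python
# Minimal sketch (not Paddle's Program class) of the new bookkeeping.
class FakeProgram(object):
    def __init__(self):
        self._is_distributed = False           # set True by transpile()
        self._is_chief = False                 # True only when trainer_id == 0
        self._slice_vars_and_attrs = []        # [orig_var, start, slice_var] triples
        self._endpoints = []                   # pserver "ip:port" strings
        self._distributed_lookup_table = None  # lookup table name, or None

prog = FakeProgram()
# What the transpiler does for trainer 0 of a two-pserver job (values assumed):
prog._is_distributed = True
prog._is_chief = True
prog._endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
prog._distributed_lookup_table = "shared_w"
```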
@@ -372,6 +372,7 @@ def load_vars(executor,
         load_vars(
             executor,
             dirname=dirname,
+            main_program=main_program,
             vars=list(filter(predicate, main_program.list_vars())),
             filename=filename)
     else:
@@ -403,9 +404,12 @@ def load_vars(executor,
                 inputs={},
                 outputs={"Out": load_var_list},
                 attrs={'file_path': os.path.join(dirname, filename)})
         executor.run(load_prog)

+        # load slice vars on pserver, if there are any.
+        _load_slice_up_vars(executor, dirname,
+                            main_program._slice_vars_and_attrs)
+

 def load_params(executor, dirname, main_program=None, filename=None):
     """
@@ -659,11 +663,19 @@ def save_inference_model(dirname,
     save_persistables(executor, dirname, inference_program, params_filename)

+    # if there is a lookup table, trainer 0 will notify all pservers to save it.
+    if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
+        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+        _save_lookup_tables_by_notify(executor, lookup_table_filename,
+                                      main_program._distributed_lookup_table,
+                                      main_program._endpoints)
+

 def load_inference_model(dirname,
                          executor,
                          model_filename=None,
-                         params_filename=None):
+                         params_filename=None,
+                         pserver_endpoints=None):
     """
     Load inference model from a directory
@@ -679,6 +691,10 @@ def load_inference_model(dirname,
                                parameters were saved in a single binary
                                file. If parameters were saved in separate
                                files, set it as 'None'.
+        pserver_endpoints(list|None): Only needed for distributed inference.
+                               When a distributed lookup table is used in
+                               training, it is also needed at inference time.
+                               The parameter is a list of pserver endpoints.

     Returns:
         tuple: The return of this function is a tuple with three elements:
@@ -697,12 +713,16 @@ def load_inference_model(dirname,

             exe = fluid.Executor(fluid.CPUPlace())
             path = "./infer_model"
+            endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
             [inference_program, feed_target_names, fetch_targets] =
                 fluid.io.load_inference_model(dirname=path, executor=exe)
             results = exe.run(inference_program,
                               feed={feed_target_names[0]: tensor_img},
                               fetch_list=fetch_targets)

+            # if a distributed lookup table is used, load the model with:
+            fluid.io.load_inference_model(dirname=path, executor=exe, pserver_endpoints=endpoints)
+
             # In this example, the inference program was saved in the
             # "./infer_model/__model__" and parameters were saved in
             # separate files in "./infer_model".
@@ -729,6 +749,9 @@ def load_inference_model(dirname,
     program = Program.parse_from_string(program_desc_str)
     load_persistables(executor, dirname, program, params_filename)

+    if pserver_endpoints:
+        program = _endpoints_replacement(program, pserver_endpoints)
+
     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
     fetch_targets = [
@@ -738,6 +761,61 @@ def load_inference_model(dirname,
     return [program, feed_target_names, fetch_targets]
+
+def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
+                                  pserver_endpoints):
+    """
+    This function sends a checkpoint notify message from trainer 0
+    to all the pservers.
+    The checkpoint notify message contains the lookup table name and
+    the absolute path on the pserver where the lookup_table is to be saved.
+
+    Args:
+        executor(Executor): The executor that runs the checkpoint notify.
+        dirname(str): The folder where to save.
+        lookup_table(string): the lookup table name. When a distributed
+            lookup table is used, the name is available as
+            DistributeTranspiler.table_name.
+        pserver_endpoints(list): the parameter server ip:port list. When a
+            distributed lookup table is used, it is available from the
+            distribute arguments.
+
+    Return:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            table_name = "share_w"
+            ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
+
+            _save_lookup_tables_by_notify(executor=exe,
+                    dirname=param_path, lookup_table=table_name,
+                    pserver_endpoints=ps_endpoints)
+    """
+    pserver_notify_program = Program()
+    pserver_notify_block = pserver_notify_program.global_block()
+
+    attrs = {}
+    attrs['epmap'] = pserver_endpoints
+    attrs['dir'] = dirname
+    attrs['lookup_table'] = lookup_table
+
+    pserver_notify_block.append_op(
+        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+    executor.run(pserver_notify_program)
+
+
+def _endpoints_replacement(program, endpoints):
+    ENDPOINT_MAP = "epmap"
+    for op in program.global_block().ops:
+        if op.has_attr(ENDPOINT_MAP):
+            op.set_attr(ENDPOINT_MAP, endpoints)
+    program._sync_with_cpp()
+    return program
+
 def get_parameter_value(para, executor):
     """
     Get the LoDTensor value of the given parameter.
@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None):
         program = default_main_program()
     var = program.global_block().var(name)
     return get_parameter_value(var, executor)
+
+
+def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
+    if not slice_vars_and_attrs:
+        return
+
+    load_prog = Program()
+    load_block = load_prog.global_block()
+
+    for var_tuple in slice_vars_and_attrs:
+        orig_var = var_tuple[0]
+        start = var_tuple[1]
+        slice_var = var_tuple[2]
+        end = start + reduce(lambda x, y: x * y, slice_var.shape)
+
+        clone_orig_var = load_block.create_var(
+            name=orig_var.name,
+            type=orig_var.type,
+            shape=orig_var.shape,
+            dtype=orig_var.dtype,
+            persistable=True)
+
+        clone_slice_var = load_block.create_var(
+            name=slice_var.name,
+            type=slice_var.type,
+            shape=slice_var.shape,
+            dtype=slice_var.dtype,
+            persistable=True)
+
+        load_block.append_op(
+            type='load',
+            inputs={},
+            outputs={'Out': [clone_orig_var]},
+            attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
+        load_block.append_op(
+            type="slice",
+            inputs={'Input': clone_orig_var},
+            outputs={'Out': clone_slice_var},
+            attrs={'axes': [0],
+                   'starts': [start],
+                   'ends': [end]})
+
+    executor.run(load_prog)
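Note: `_endpoints_replacement` above walks every op in the global block and swaps the training-time `epmap` attribute for the endpoints passed to `load_inference_model`, so inference-time RPC ops talk to the right pservers. A toy stand-in (fake op class and assumed endpoint values, not Paddle's API) that mimics the loop:

```python
# Fake op class standing in for framework ops that carry an 'epmap' attr.
class FakeOp(object):
    def __init__(self, attrs):
        self.attrs = attrs

    def has_attr(self, name):
        return name in self.attrs

    def set_attr(self, name, value):
        self.attrs[name] = value

ops = [FakeOp({"epmap": ["127.0.0.1:6000"]}), FakeOp({})]
for op in ops:
    if op.has_attr("epmap"):  # same check _endpoints_replacement performs
        op.set_attr("epmap", ["127.0.0.1:2023", "127.0.0.1:2024"])
assert ops[0].attrs["epmap"] == ["127.0.0.1:2023", "127.0.0.1:2024"]
```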
@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase):
         avg_cost = fluid.layers.mean(cost)
         sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
         sgd_optimizer.minimize(avg_cost)
-        return

     def get_main_program(self):
         main = fluid.Program()
@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase):
     def test_transpiler(self):
         main = fluid.Program()
         startup = fluid.Program()
-        with fluid.program_guard(main, startup):
-            self.transpiler_test_impl()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.transpiler_test_impl()


 class TestBasicModel(TranspilerTest):
@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest):
                 decay_rate=0.1,
                 staircase=True))
         sgd_optimizer.minimize(avg_cost)
-        return

     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest):
             learning_rate=fluid.layers.piecewise_decay([10000, 20000],
                                                        [1.0, 0.5, 1.0]))
         sgd_optimizer.minimize(avg_cost)
-        return

     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest):
         avg_cost = fluid.layers.mean(cost)
         sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
         sgd_optimizer.minimize(avg_cost)
-        return

     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest):
             momentum=0.9,
             regularization=fluid.regularizer.L2Decay(1e-4))
         sgd_optimizer.minimize(avg_cost)
-        return

     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest):
     def network_with_table(self, is_sparse, is_distributed):
         self.table_size = 1000
         self.emb_size = 64
+        self.lookup_table_name = 'shared_w'

         def emb_pool(ids):
             emb = fluid.layers.embedding(
                 input=ids,
                 size=[self.table_size, self.emb_size],
                 dtype='float32',
-                param_attr='shared_w',  # share parameter
+                param_attr=self.lookup_table_name,  # share parameter
                 is_sparse=is_sparse,
                 is_distributed=is_distributed)
             pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
     def transpiler_test_impl(self):
         config = fluid.DistributeTranspilerConfig()
-        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
+        pserver1, _ = self.get_pserver(self.pserver1_ep, config)

         self.assertTrue(self.transpiler.has_distributed_lookup_table)
         lookup_table_var = pserver1.global_block().vars[
@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
         self.assertEqual(row_size, calc_row_size)
+
+class TestDistArgsInProgram(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        trainer, _ = self.get_trainer()
+
+        self.assertTrue(trainer._is_distributed)
+        self.assertTrue(trainer._is_chief)
+        self.assertEqual(trainer._distributed_lookup_table,
+                         self.lookup_table_name)
+        self.assertEqual(trainer._endpoints,
+                         [self.pserver1_ep, self.pserver2_ep])
+

 class TestRMSPropOptimizer(TranspilerTest):
     def net_conf(self):
         x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest):
         avg_cost = fluid.layers.mean(cost)
         optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
         optimizer.minimize(avg_cost)
-        return

     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest):
         self.assertEqual(moment_var.shape, (500, 1000))
+
+class TestLoadSliceVar(TranspilerTest):
+    def net_conf(self):
+        x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+        y_predict = fluid.layers.fc(input=x,
+                                    size=1000,
+                                    act=None,
+                                    param_attr=fluid.ParamAttr(name='fc_w'),
+                                    bias_attr=fluid.ParamAttr(name='fc_b'))
+        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
+        optimizer.minimize(avg_cost)
+
+    def transpiler_test_impl(self):
+        pserver, _ = self.get_pserver(self.pserver1_ep)
+        pserver2, _ = self.get_pserver(self.pserver2_ep)
+
+        self.assertTrue(pserver._slice_vars_and_attrs)
+        self.assertTrue(pserver2._slice_vars_and_attrs)
+
+        for idx in xrange(len(pserver._slice_vars_and_attrs)):
+            self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
+                             pserver2._slice_vars_and_attrs[idx][0])
+
+            total_numel = reduce(lambda x, y: x * y,
+                                 pserver._slice_vars_and_attrs[idx][0].shape)
+            self.assertEqual(
+                total_numel,
+                reduce(lambda x, y: x * y,
+                       pserver._slice_vars_and_attrs[idx][2].shape) + reduce(
+                           lambda x, y: x * y,
+                           pserver2._slice_vars_and_attrs[idx][2].shape))
+

 if __name__ == "__main__":
     unittest.main()
@@ -215,6 +215,13 @@ class DistributeTranspiler(object):
         for param_var, grad_var in self.params_grads:
             self.param_name_to_grad_name[param_var.name] = grad_var.name

+        # add distributed attrs to program
+        self.origin_program._is_distributed = True
+        self.origin_program._endpoints = self.pserver_endpoints
+        self.origin_program._is_chief = self.trainer_id == 0
+        self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
+
+        # split and create vars, then put splited vars in dicts for later use.
         # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -590,6 +597,8 @@ class DistributeTranspiler(object):
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)

+            pserver_program._distributed_lookup_table = self.table_name
+
         # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
         # not be executed, so it's safe to use optimize_block to hold the place
         if self.has_distributed_lookup_table:
@@ -616,6 +625,10 @@ class DistributeTranspiler(object):
                 outputs={},
                 attrs=attrs)

+        # add distributed attrs
+        pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
+            endpoint)
+
         pserver_program._sync_with_cpp()
         return pserver_program
@@ -689,8 +702,31 @@ class DistributeTranspiler(object):
                 inputs=new_inputs,
                 outputs=new_outputs,
                 attrs=op.all_attrs())
+
+        # add slice vars
+        s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
+
         return s_prog
+    def _get_slice_vars_and_attrs(self, endpoint):
+        slice_vars_and_attrs = []
+        block_suffix = "block"
+        for param in self.param_grad_ep_mapping[endpoint]["params"]:
+            orig_var_name, block_name, _ = self._get_varname_parts(param.name)
+            if not block_name:
+                continue
+
+            block_idx = int(block_name.split(block_suffix)[1])
+            orig_var = self.origin_program.global_block().vars[orig_var_name]
+
+            skip_numel = 0
+            slice_vars = self.param_var_mapping[orig_var_name]
+            for slice_var in slice_vars[:block_idx]:
+                skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
+            slice_vars_and_attrs.append([orig_var, skip_numel, param])
+
+        return slice_vars_and_attrs
+
     # ====================== private transpiler functions =====================
     def _has_distributed_lookup_table(self):
...
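Note: `_get_slice_vars_and_attrs` records, for each parameter shard placed on a pserver, the original variable, the number of elements that precede the shard, and the shard variable itself; `_load_slice_up_vars` in io.py then replays a load plus slice from those triples. A worked example of the offset arithmetic, with shapes assumed from TestLoadSliceVar:

```python
from functools import reduce

# Assumed split, mirroring TestLoadSliceVar: 'fc_w' of shape (1000, 1000)
# is divided row-wise into two (500, 1000) blocks, one per pserver.
slice_shapes = [(500, 1000), (500, 1000)]
block_idx = 1  # this pserver holds fc_w.block1

skip_numel = 0
for shape in slice_shapes[:block_idx]:
    skip_numel += reduce(lambda x, y: x * y, shape)

print(skip_numel)  # 500000 elements of fc_w precede this shard
```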