提交 f874e02b 编写于 作者: S sandyhouse

update optimizer

上级 d2c81529
...@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"] ...@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase): class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer): def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer) super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer self.inner_opt = optimizer
...@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None):
"""Implementation of minimize."""
# TODO: (JZ-LIANG) support multiple comm in future # TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num # self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1 self._nrings_sharding = 1
...@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase):
self.user_defined_strategy.sharding_configs["parallelism"]) self.user_defined_strategy.sharding_configs["parallelism"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[ self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"] "use_pipeline"]
self.acc_steps = self.user_defined_strategy.sharding_configs[
"acc_steps"]
if self.inner_opt is None: if self.inner_opt is None:
raise ValueError( raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.") "self.inner_opt of ShardingOptimizer should not be None.")
if self.use_pipeline: if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt) pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
self.acc_steps)
main_program = loss.block.program main_program = loss.block.program
main_program._pipeline_opt = dict() main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index() // ( pp_rank = self.role_maker._worker_index() // (
...@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase):
'global_rank'] = self.role_maker._worker_index() 'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 2 main_program._pipeline_opt['ring_id'] = 2
optimize_ops, params_grads, program_list = pp_optimizer.minimize( optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list) self.pipeline_nodes = len(program_list)
else: else:
...@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# check op dependecy # check op dependecy
check_broadcast(main_block) check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
self.dp_ring_id) # self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id) #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait() self._wait()
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase):
print("pp_group_endpoints:", self.pp_group_endpoints) print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank) print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id) print("pp_ring_id:", self.pp_ring_id)
self._collective_helper._init_communicator( for pair in self.pipeline_pair:
self._startup_program, self.current_endpoint, if self.pp_rank not in pair: continue
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False) pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, start_ring_id, False)
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
......
...@@ -413,6 +413,8 @@ class Section(DeviceWorker): ...@@ -413,6 +413,8 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"] section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"] section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
cfg = section_param.section_config cfg = section_param.section_config
program = pipeline_opt["section_program"] program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc() cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
...@@ -3788,6 +3788,7 @@ class PipelineOptimizer(object): ...@@ -3788,6 +3788,7 @@ class PipelineOptimizer(object):
self._op_role_var_key = op_maker.kOpRoleVarAttrName() self._op_role_var_key = op_maker.kOpRoleVarAttrName()
self._op_device_key = op_maker.kOpDeviceAttrName() self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None self._param_device_map = None
self._pipeline_pair = []
def _create_vars(self, block, ori_block): def _create_vars(self, block, ori_block):
# Create vars for block, copied from ori_block # Create vars for block, copied from ori_block
...@@ -4134,6 +4135,7 @@ class PipelineOptimizer(object): ...@@ -4134,6 +4135,7 @@ class PipelineOptimizer(object):
if not var_name in first_block.vars: if not var_name in first_block.vars:
self._create_var(first_block, main_var, var_name) self._create_var(first_block, main_var, var_name)
dev_index = int(device.split(':')[1]) dev_index = int(device.split(':')[1])
print("dev_index:", dev_index)
first_block._insert_op( first_block._insert_op(
index=insert_index, index=insert_index,
type='send_v2', type='send_v2',
...@@ -4141,9 +4143,11 @@ class PipelineOptimizer(object): ...@@ -4141,9 +4143,11 @@ class PipelineOptimizer(object):
attrs={ attrs={
self._op_device_key: first_dev_spec, self._op_device_key: first_dev_spec,
self._op_role_key: self._op_role.Forward, self._op_role_key: self._op_role.Forward,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': dev_index, 'peer': dev_index,
'ring_id': self.ring_id, #'ring_id': self.ring_id,
'ring_id': self.ring_id
if dev_index > first_dev_index else self.ring_id + 2,
}) })
# Get the device that that data on # Get the device that that data on
assert device in devices assert device in devices
...@@ -4168,7 +4172,21 @@ class PipelineOptimizer(object): ...@@ -4168,7 +4172,21 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Forward, self._op_role_key: self._op_role.Forward,
'peer': first_dev_index, 'peer': first_dev_index,
'use_calc_stream': True, 'use_calc_stream': True,
'ring_id': self.ring_id, #'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index < dev_index else self.ring_id + 2,
})
block._insert_op(
index=index + 1,
type='c_sync_comm_stream',
inputs={'X': [new_var]},
outputs={'Out': [new_var]},
attrs={
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index > dev_index else self.ring_id + 2,
}) })
def _strip_grad_suffix(self, name): def _strip_grad_suffix(self, name):
...@@ -4409,30 +4427,91 @@ class PipelineOptimizer(object): ...@@ -4409,30 +4427,91 @@ class PipelineOptimizer(object):
var = block.vars[var_name] var = block.vars[var_name]
prev_device_index = int(prev_device.split(':')[1]) prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1]) cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
type='send_v2', #type='send_v2',
type='c_broadcast',
inputs={'X': var}, inputs={'X': var},
outputs={'Out': var},
attrs={ attrs={
self._op_device_key: prev_device, self._op_device_key: prev_device,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': cur_device_index, #'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
'ring_id': self.ring_id, 'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
type='recv_v2', #type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var.shape, 'shape': fill_shape,
'dtype': var.dtype, 'dtype': var.dtype,
self._op_device_key: cur_device, self._op_device_key: cur_device,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'value': float(0.0),
'peer': prev_device_index, })
extra_index += 1
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key: op_role,
'ring_id': self.ring_id, 'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
...@@ -4512,6 +4591,15 @@ class PipelineOptimizer(object): ...@@ -4512,6 +4591,15 @@ class PipelineOptimizer(object):
first_optimize_op_index = None first_optimize_op_index = None
for index, op in reversed(tuple(enumerate(list(block.ops)))): for index, op in reversed(tuple(enumerate(list(block.ops)))):
# device = op.attr(self._op_device_key) # device = op.attr(self._op_device_key)
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if not self._is_optimize_op(op) and not first_optimize_op_index: if not self._is_optimize_op(op) and not first_optimize_op_index:
first_optimize_op_index = index + 1 first_optimize_op_index = index + 1
if block.ops[ if block.ops[
...@@ -4553,11 +4641,11 @@ class PipelineOptimizer(object): ...@@ -4553,11 +4641,11 @@ class PipelineOptimizer(object):
# a trick to run this op once per mini-batch # a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched, self._op_role_key: self._op_role.Optimize.LRSched,
}) })
offset += 1 #offset += 1
grad_name = op_role_var[i + 1] # with _0 suffix grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name] # without _0 suffix grad_var = block.vars[grad_name]
real_grad_name = grad_name[0:grad_name.find( real_grad_name = grad_name[0:grad_name.find(
'@GRAD')] + '@GRAD' '@GRAD')] + '@GRAD' # without _0 suffix
real_grad_var = block.vars[ real_grad_var = block.vars[
real_grad_name] # without _0 suffix real_grad_name] # without _0 suffix
# new_grad_var_name = unique_name.generate(grad_name) # new_grad_var_name = unique_name.generate(grad_name)
...@@ -4567,7 +4655,7 @@ class PipelineOptimizer(object): ...@@ -4567,7 +4655,7 @@ class PipelineOptimizer(object):
# self._rename_arg(op, grad_name, new_grad_var_name) # self._rename_arg(op, grad_name, new_grad_var_name)
if not 'cast_fp16' in grad_name: if not 'cast_fp16' in grad_name:
block._insert_op( block._insert_op(
index=first_optimize_op_index + offset, index=index + 1,
type='sum', type='sum',
inputs={'X': [grad_var, real_grad_var]}, inputs={'X': [grad_var, real_grad_var]},
outputs={'Out': real_grad_var}, outputs={'Out': real_grad_var},
...@@ -4576,58 +4664,83 @@ class PipelineOptimizer(object): ...@@ -4576,58 +4664,83 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Backward, self._op_role_key: self._op_role.Backward,
#self._op_role_var_key: op_role_var #self._op_role_var_key: op_role_var
}) })
offset += 1 #offset += 1
else: else:
grad_name = op_role_var[i + 1] # with _0 suffix grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name] # without _0 suffix grad_var = block.vars[grad_name]
fp32_grad_var_name = param_name + core.grad_var_suffix() fp32_grad_var_name = param_name + core.grad_var_suffix(
) # without _0 suffix
fp32_grad_var = block.vars[fp32_grad_var_name] fp32_grad_var = block.vars[fp32_grad_var_name]
fp32_grad_var.persistable = True fp32_grad_var.persistable = True
cast_grad_var_name = unique_name.generate( cast_grad_var_name = unique_name.generate(
fp32_grad_var_name) fp32_grad_var_name)
cast_var = self._create_var(block, grad_var, cast_grad_var = self._create_var(block, fp32_grad_var,
cast_grad_var_name) cast_grad_var_name)
cast_var.persistable = False cast_grad_var.persistable = False
real_grad_name = grad_name[0:grad_name.find(
'@GRAD')] + '@GRAD'
real_grad_var = block.vars[
real_grad_name] # without _0 suffix
block._insert_op( block._insert_op(
index=first_optimize_op_index + offset, index=index + 1,
type='cast', type='cast',
inputs={'X': fp32_grad_var}, inputs={'X': grad_var},
outputs={'Out': cast_var}, outputs={'Out': cast_grad_var},
attrs={ attrs={
'in_dtype': fp32_grad_var.dtype, 'in_dtype': grad_var.dtype,
'out_dtype': cast_var.dtype, 'out_dtype': cast_grad_var.dtype,
# self._op_device_key: device, # self._op_device_key: device,
self._op_role_key: self._op_role.Backward, self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var # self._op_role_var_key: op_role_var
}) })
offset += 1 offset += 1
block._insert_op( block._insert_op(
index=first_optimize_op_index + offset, index=index + 2,
type='sum', type='sum',
inputs={'X': [grad_var, cast_var]}, inputs={'X': [fp32_grad_var, cast_grad_var]},
outputs={'Out': real_grad_var},
attrs={
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var
})
offset += 1
block._insert_op(
index=first_optimize_op_index + offset,
type='cast',
inputs={'X': real_grad_var},
outputs={'Out': fp32_grad_var}, outputs={'Out': fp32_grad_var},
attrs={ attrs={
'in_dtype': real_grad_var.dtype,
'out_dtype': fp32_grad_var.dtype,
# self._op_device_key: device, # self._op_device_key: device,
self._op_role_key: self._op_role.Backward, self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var # self._op_role_var_key: op_role_var
}) })
offset += 1
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD'
#real_grad_var = block.vars[
# real_grad_name] # without _0 suffix
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': real_grad_var},
# attrs={
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': real_grad_var},
# outputs={'Out': fp32_grad_var},
# attrs={
# 'in_dtype': real_grad_var.dtype,
# 'out_dtype': fp32_grad_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
def _add_sub_blocks(self, main_block, program_list): def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program main_program = main_block.program
...@@ -4720,12 +4833,14 @@ class PipelineOptimizer(object): ...@@ -4720,12 +4833,14 @@ class PipelineOptimizer(object):
inputs={'X': write_block.var(var_name), }, inputs={'X': write_block.var(var_name), },
attrs={ attrs={
self._op_device_key: write_device, self._op_device_key: write_device,
'use_calc_stream': True, 'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every # A trick to make the role LRSched to avoid copy every
# microbatch # microbatch
self._op_role_key: self._op_role.LRSched, self._op_role_key: self._op_role.LRSched,
'peer': read_dev_index, 'peer': read_dev_index,
'ring_id': self.ring_id, #'ring_id': self.ring_id,
'ring_id': self.ring_id if
read_dev_index > write_dev_index else self.ring_id + 2,
}) })
read_block._insert_op( read_block._insert_op(
index=0, index=0,
...@@ -4735,12 +4850,28 @@ class PipelineOptimizer(object): ...@@ -4735,12 +4850,28 @@ class PipelineOptimizer(object):
'out_shape': read_block.var(var_name).shape, 'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype, 'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device, self._op_device_key: read_device,
'use_calc_stream': True, 'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every # A trick to make the role LRSched to avoid copy every
# microbatch # microbatch
self._op_role_key: self._op_role.LRSched, self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index, 'peer': write_dev_index,
'ring_id': self.ring_id, #'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index < read_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=1,
type='c_sync_comm_stream',
inputs={'X': [read_block.var(var_name)]},
outputs={'Out': [read_block.var(var_name)]},
attrs={
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index > read_dev_index else self.ring_id + 2,
}) })
def _is_gradient_clip_op(self, op): def _is_gradient_clip_op(self, op):
...@@ -4809,8 +4940,8 @@ class PipelineOptimizer(object): ...@@ -4809,8 +4940,8 @@ class PipelineOptimizer(object):
program_list = self._split_program(main_program, device_list) program_list = self._split_program(main_program, device_list)
for p in program_list: for p in program_list:
self._create_vars(p["program"].block(0), main_block) self._create_vars(p["program"].block(0), main_block)
self._insert_sendrecv_for_data_var(main_block, program_list, #self._insert_sendrecv_for_data_var(main_block, program_list,
startup_program, device_list) # startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in # Step4: Special Case: process persistable vars that exist in
# multiple sections # multiple sections
...@@ -4824,8 +4955,8 @@ class PipelineOptimizer(object): ...@@ -4824,8 +4955,8 @@ class PipelineOptimizer(object):
place_list = [] place_list = []
for dev in device_list: for dev in device_list:
dev_index = int(dev.split(":")[1]) % 8 dev_index = int(dev.split(":")[1])
place_list.append(core.CUDAPlace(dev_index)) place_list.append(core.CUDAPlace(dev_index % 8))
# Step6: Split startup program # Step6: Split startup program
new_startup_program = self._split_startup_program(startup_program, new_startup_program = self._split_startup_program(startup_program,
...@@ -4851,6 +4982,8 @@ class PipelineOptimizer(object): ...@@ -4851,6 +4982,8 @@ class PipelineOptimizer(object):
"trainer": "PipelineTrainer", "trainer": "PipelineTrainer",
"device_worker": "Section", "device_worker": "Section",
"inner_parallelism": len(device_list), "inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank,
"section_program": program_list[local_rank], "section_program": program_list[local_rank],
"place": place_list[local_rank], "place": place_list[local_rank],
"place_id": place_id, "place_id": place_id,
...@@ -4858,7 +4991,7 @@ class PipelineOptimizer(object): ...@@ -4858,7 +4991,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches, "num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id, "start_cpu_core_id": self._start_cpu_core_id,
} }
return optimize_ops, params_grads, program_list return optimize_ops, params_grads, program_list, self._pipeline_pair
class RecomputeOptimizer(Optimizer): class RecomputeOptimizer(Optimizer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册