未验证 提交 da051350 编写于 作者: C caozhou 提交者: GitHub

[Auto Parallel] Add cost interface (#47043)

* add cost interface

* update inferface and add unittest

* update unittest

* update inferface
上级 30dae6db
...@@ -44,6 +44,8 @@ class CostEstimator: ...@@ -44,6 +44,8 @@ class CostEstimator:
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} ) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {} self._bubble_time_mapping = {}
self._ordered_ops = [] self._ordered_ops = []
self.max_memories = {}
self.max_memory = None
@property @property
def loop_count(self): def loop_count(self):
...@@ -122,7 +124,7 @@ class CostEstimator: ...@@ -122,7 +124,7 @@ class CostEstimator:
for i in range(loop_count): for i in range(loop_count):
for op in ops: for op in ops:
self._detailed_cost[op.desc.id()] = OrderedDict() self._detailed_cost[op.desc.id()] = OrderedDict()
# if in the while sub block, the detail of cost is the last cost # If in the while sub block, the detail of cost is the last cost
detail = self._detailed_cost[op.desc.id()] detail = self._detailed_cost[op.desc.id()]
detail["reshard_cost"] = OrderedDict() # detail["reshard_cost"] = OrderedDict() #
detail["dist_op_cost"] = [] detail["dist_op_cost"] = []
...@@ -146,15 +148,15 @@ class CostEstimator: ...@@ -146,15 +148,15 @@ class CostEstimator:
var = get_var_with_recursion(var_name, block, self.program) var = get_var_with_recursion(var_name, block, self.program)
reshard_cost = resharder.get_cost(op, var, self.cluster) reshard_cost = resharder.get_cost(op, var, self.cluster)
# calc reshard cost # Calc reshard cost
if reshard_cost is not None: if reshard_cost is not None:
detail["reshard_cost"][var_name] = reshard_cost detail["reshard_cost"][var_name] = reshard_cost
comm_costs = reshard_cost[0] comm_costs = reshard_cost[0]
local_comp_cost = reshard_cost[1] local_comp_cost = reshard_cost[1]
for comm_cost in comm_costs: for comm_cost in comm_costs:
# time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost. # Time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost.
# comm sync # Comm sync
for item in comm_cost: for item in comm_cost:
group_ranks, cost = item group_ranks, cost = item
max_time = None max_time = None
...@@ -182,7 +184,7 @@ class CostEstimator: ...@@ -182,7 +184,7 @@ class CostEstimator:
for comp_cost in local_comp_cost[rank]: for comp_cost in local_comp_cost[rank]:
self.local_cost(rank).time += comp_cost.time self.local_cost(rank).time += comp_cost.time
# calc dist op cost # Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.processes
...@@ -200,7 +202,7 @@ class CostEstimator: ...@@ -200,7 +202,7 @@ class CostEstimator:
continue continue
for item in dist_op_cost: for item in dist_op_cost:
if isinstance(item, list): if isinstance(item, list):
# comm sync # Comm sync
for comm_op_cost in item: for comm_op_cost in item:
max_time = None max_time = None
cost_time = {} cost_time = {}
...@@ -221,9 +223,9 @@ class CostEstimator: ...@@ -221,9 +223,9 @@ class CostEstimator:
self._bubble_time_mapping[rank] += ( self._bubble_time_mapping[rank] += (
max_time - cost_time[rank]) max_time - cost_time[rank])
elif isinstance(item, dict): elif isinstance(item, dict):
# op just one # Op just one
for rank in processes: for rank in processes:
# dp+pp+mp # DP+PP+MP
if rank not in item: if rank not in item:
continue continue
self.local_cost(rank).time += item[rank].time self.local_cost(rank).time += item[rank].time
...@@ -266,7 +268,7 @@ class CostEstimator: ...@@ -266,7 +268,7 @@ class CostEstimator:
return result return result
memories = {} memories = {}
max_memories = {} self.max_memories = {}
var_info = { var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} } # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
...@@ -276,6 +278,10 @@ class CostEstimator: ...@@ -276,6 +278,10 @@ class CostEstimator:
self._ordered_ops.sort(key=lambda x: x[0]) self._ordered_ops.sort(key=lambda x: x[0])
for op_id, op in self._ordered_ops: for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
]:
continue
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
...@@ -287,7 +293,7 @@ class CostEstimator: ...@@ -287,7 +293,7 @@ class CostEstimator:
input_dims_mapping) input_dims_mapping)
if key not in var_info[var_name]: if key not in var_info[var_name]:
var_info[var_name][key] = {} var_info[var_name][key] = {}
# it is even partition now # It is even partition now
if "memory" not in var_info[var_name][key]: if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name) var = dist_op.get_serial_input(var_name)
global_sizes = var.shape global_sizes = var.shape
...@@ -325,6 +331,10 @@ class CostEstimator: ...@@ -325,6 +331,10 @@ class CostEstimator:
has_used_vars = set() has_used_vars = set()
for op_id, op in self._ordered_ops: for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
]:
continue
can_free_memories = {} can_free_memories = {}
can_free_vars = set() can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
...@@ -336,14 +346,14 @@ class CostEstimator: ...@@ -336,14 +346,14 @@ class CostEstimator:
input_dims_mapping) input_dims_mapping)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_input(var_name) var = dist_op.get_serial_input(var_name)
# not used # Not used
if var_name + key not in has_used_vars: if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var) has_used_vars.add(has_used_var)
for process in process_mesh.processes: for process in process_mesh.processes:
if process not in memories: if process not in memories:
memories[process] = 0 memories[process] = 0
memories[process] += var_info[var_name][key]["memory"] memories[process] += var_info[var_name][key]["memory"]
# used # Used
else: else:
if op_id == var_info[var_name][key]["position"][-1]: if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars: if has_used_var not in can_free_vars:
...@@ -362,14 +372,14 @@ class CostEstimator: ...@@ -362,14 +372,14 @@ class CostEstimator:
output_dims_mapping) output_dims_mapping)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_output(var_name) var = dist_op.get_serial_output(var_name)
# not used # Not used
if var_name + key not in has_used_vars: if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var) has_used_vars.add(has_used_var)
for process in process_mesh.processes: for process in process_mesh.processes:
if process not in memories: if process not in memories:
memories[process] = 0 memories[process] = 0
memories[process] += var_info[var_name][key]["memory"] memories[process] += var_info[var_name][key]["memory"]
# used # Used
else: else:
if op_id == var_info[var_name][key]["position"][-1]: if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars: if has_used_var not in can_free_vars:
...@@ -381,21 +391,22 @@ class CostEstimator: ...@@ -381,21 +391,22 @@ class CostEstimator:
can_free_memories[process] += var_info[ can_free_memories[process] += var_info[
var_name][key]["memory"] var_name][key]["memory"]
# calc peak memory # Calc peak memory
for process in memories: for process in memories:
if process not in max_memories: if process not in self.max_memories:
max_memories[process] = memories[process] self.max_memories[process] = memories[process]
else: else:
if memories[process] > max_memories[process]: if memories[process] > self.max_memories[process]:
max_memories[process] = memories[process] self.max_memories[process] = memories[process]
# free memory # Free memory
for process in can_free_memories: for process in can_free_memories:
if process in memories: if process in memories:
memories[process] -= can_free_memories[process] memories[process] -= can_free_memories[process]
# Calculate the max memory in all ranks # Calculate the max memory in all ranks
max_memory = max(max_memories.values()) max_memory = max(self.max_memories.values())
self.max_memory = max_memory
return max_memory return max_memory
...@@ -409,3 +420,143 @@ class CostEstimator: ...@@ -409,3 +420,143 @@ class CostEstimator:
self._estimate_core(dist_context, resharder, block) self._estimate_core(dist_context, resharder, block)
return self.global_cost return self.global_cost
def _print_tag(self, max_len, length):
tag = "+" + "-" * max_len
for i in range(length):
print(tag, end="")
if i == length - 1:
print("+")
def _print_vals(self, vals, max_len):
for idx, val in enumerate(vals):
s = "|" + str(val).center(max_len)
print(s, end="")
if idx == len(vals) - 1:
print("|")
def _pretty_print_memory_cost(self):
"""Print memory of every rank prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate memory cost before print.")
# Padding automatically
max_len = 0
header = ["Rank", "Memory(MiB)"]
memories = [
int(item // 1e6) for item in list(self.max_memories.values())
]
for memory in (memories + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print rank and its memory
for i in range(len(self.max_memories)):
memory = memories[i]
vals = [i, memory]
self._print_vals(vals, max_len)
self._print_tag(max_len, len(header))
def _pretty_print_global(self):
"""Print global execution time and max memory prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate cost before print.")
# Padding automatically
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print exec time and max memory
self._print_vals(vals, max_len)
# Print tag
self._print_tag(max_len, len(header))
def pretty_print_cost(self):
"""Print cost prettily."""
print("The global execution time and max memory are as follows:")
self._pretty_print_global()
print("The memory of every rank is as follows:")
self._pretty_print_memory_cost()
def get_cost_from_engine(engine, mode):
from ..utils import to_list
# Construct cost estimator by original main program
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
serial_startup_prog = engine._serial_startup_progs[mode].clone(
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
)
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
else:
from ..completion import Completer
from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {},
{"loss": losses}, engine._cluster,
engine._strategy)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program)
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog,
serial_startup_prog,
serial_loss)
# Generate optimizer
optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer,
params_grads)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
# Print the cost
cost_estimator.pretty_print_cost()
return global_cost, max_memory
...@@ -48,6 +48,8 @@ from .dist_context import DistributedContext, get_default_distributed_context ...@@ -48,6 +48,8 @@ from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy from .strategy import Strategy
from .interface import CollectionNames, get_collection from .interface import CollectionNames, get_collection
from ..utils.log_utils import get_logger from ..utils.log_utils import get_logger
from .utils import initialize_pg_in_full_mode
from .cost.estimate_cost import get_cost_from_engine
class Engine: class Engine:
...@@ -127,12 +129,6 @@ class Engine: ...@@ -127,12 +129,6 @@ class Engine:
"'model must be sub classes of `paddle.nn.Layer` or any callable function." "'model must be sub classes of `paddle.nn.Layer` or any callable function."
) )
self._model = model self._model = model
# if loss and not isinstance(loss,
# paddle.nn.Layer) and not callable(loss):
# raise TypeError(
# "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
# )
self._loss = loss self._loss = loss
if optimizer and not isinstance( if optimizer and not isinstance(
...@@ -201,6 +197,7 @@ class Engine: ...@@ -201,6 +197,7 @@ class Engine:
self._planned_mode = None self._planned_mode = None
self._dygraph_mode = False self._dygraph_mode = False
self._tuning = self._strategy.tuning self._tuning = self._strategy.tuning
self._losses = None
self.history = None self.history = None
...@@ -487,6 +484,7 @@ class Engine: ...@@ -487,6 +484,7 @@ class Engine:
outputs = self.program_helper.output_vars outputs = self.program_helper.output_vars
labels = self.program_helper.label_vars labels = self.program_helper.label_vars
losses = self.program_helper.loss_vars losses = self.program_helper.loss_vars
self._losses = losses
metrics = self.program_helper.metric_vars metrics = self.program_helper.metric_vars
self._inputs = inputs self._inputs = inputs
...@@ -512,6 +510,7 @@ class Engine: ...@@ -512,6 +510,7 @@ class Engine:
outputs = to_list(self._model(*inputs)) outputs = to_list(self._model(*inputs))
if mode != "predict" and self._loss: if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels))) losses = to_list(self._loss(*(outputs + labels)))
self._losses = losses
if mode != "predict" and (outputs or labels): if mode != "predict" and (outputs or labels):
for metric in self._metrics: for metric in self._metrics:
...@@ -519,6 +518,7 @@ class Engine: ...@@ -519,6 +518,7 @@ class Engine:
to_list(metric.compute(*(outputs + labels)))) to_list(metric.compute(*(outputs + labels))))
else: else:
losses = to_list(self._loss) losses = to_list(self._loss)
self.losses = losses
default_ctx = get_default_distributed_context() default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation: if not default_ctx.has_annotation:
...@@ -648,7 +648,9 @@ class Engine: ...@@ -648,7 +648,9 @@ class Engine:
# instantiate communication by process_mapping. # instantiate communication by process_mapping.
all_process_groups = get_all_process_groups() all_process_groups = get_all_process_groups()
# NOTE: add the comm init control in the future for auto search if self._strategy.auto_mode == "full":
initialize_pg_in_full_mode(all_process_groups, cur_rank)
else:
for process_group in all_process_groups: for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks: if self._cur_rank not in process_group.ranks:
continue continue
...@@ -1022,16 +1024,9 @@ class Engine: ...@@ -1022,16 +1024,9 @@ class Engine:
test_dataloader = self._prepare_dataloader_from_generator( test_dataloader = self._prepare_dataloader_from_generator(
dataset=test_data, dataset=test_data,
# feed_list=feed_list,
capacity=70, capacity=70,
# use_double_buffer=use_double_buffer,
iterable=False, iterable=False,
# return_list=return_list,
# use_multiprocess=use_multiprocess,
# drop_last=drop_last,
# places=places,
batch_size=batch_size, batch_size=batch_size,
# epochs=epochs,
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn)
...@@ -1100,13 +1095,11 @@ class Engine: ...@@ -1100,13 +1095,11 @@ class Engine:
steps_per_epoch=steps_per_epoch) steps_per_epoch=steps_per_epoch)
return dataloader return dataloader
def dataloader_from_generator( def dataloader_from_generator(self,
self,
dataset, dataset,
capacity=70, capacity=70,
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
# return_list=False,
use_multiprocess=False, use_multiprocess=False,
drop_last=True, drop_last=True,
batch_size=1, batch_size=1,
...@@ -1127,14 +1120,12 @@ class Engine: ...@@ -1127,14 +1120,12 @@ class Engine:
self._switch_mode(self._mode) self._switch_mode(self._mode)
dataloader = self._prepare_dataloader_from_generator( dataloader = self._prepare_dataloader_from_generator(
dataset=dataset, dataset=dataset,
# feed_list=feed_list,
capacity=capacity, capacity=capacity,
use_double_buffer=use_double_buffer, use_double_buffer=use_double_buffer,
iterable=iterable, iterable=iterable,
return_list=False, return_list=False,
use_multiprocess=use_multiprocess, use_multiprocess=use_multiprocess,
drop_last=drop_last, drop_last=drop_last,
# places=places,
batch_size=batch_size, batch_size=batch_size,
epochs=epochs, epochs=epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
...@@ -1187,20 +1178,7 @@ class Engine: ...@@ -1187,20 +1178,7 @@ class Engine:
assert self._inputs_spec and self._labels_spec, \ assert self._inputs_spec and self._labels_spec, \
"Please call the dataloader(...) before calling prepare(...)" "Please call the dataloader(...) before calling prepare(...)"
def run( def run(self, data=None, feed=None, fetch_list=None, mode=None):
self,
data=None,
# program=None,
feed=None,
fetch_list=None,
# feed_var_name='feed',
# fetch_var_name='fetch',
# scope=None,
# return_numpy=True,
# use_program_cache=False,
# return_merged=True,
# use_prune=False,
mode=None):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
feed_dict = self._prepare_feed(data, feed, self._mode) feed_dict = self._prepare_feed(data, feed, self._mode)
...@@ -1571,6 +1549,54 @@ class Engine: ...@@ -1571,6 +1549,54 @@ class Engine:
path, load_optimizer) path, load_optimizer)
return self._state_dict, self._dist_attr return self._state_dict, self._dist_attr
def cost(self, inputs_spec=None, labels_spec=None, mode="train"):
"""
Get and Print cost, including memory of every rank,
max memory among all ranks, and the global cost of one step based on
communication cost(computation cost is 0 by default).
In the future, the flops information of every rank and global cost including
computation cost will be added.
Args:
inputs_spec(InputSpec): The specification of inputs. Default: None.
labels_spec(InputSpec): The specification of labels. Default: None.
mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: "train".
Returns:
Return the global execution time (ms) and max memory (B).
"""
# Check parallel mode
if self._strategy.auto_mode == "full":
print(
"The cost will be calcudated in the search process when the auto mode is full."
)
return
# Check mode
accepted_modes = ["train", "predict", "eval"]
if mode not in accepted_modes:
raise ValueError("The mode {} is not in accepted modes {}".format(
mode, accepted_modes))
self.to_mode(mode)
if inputs_spec is not None:
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
self._build(mode)
self._plan(mode)
else:
if _non_static_mode() or self._dygraph_mode:
raise ValueError(
"Please call `engine._prepare_program('mode')` firstly when in the static graph mode."
)
# Estimate the exec cost and max memory
global_cost, max_memory = get_cost_from_engine(self, mode)
return global_cost.time, max_memory
@property @property
def main_program(self): def main_program(self):
return self._dist_main_progs[self._mode][self._cur_rank] return self._dist_main_progs[self._mode][self._cur_rank]
......
...@@ -1602,3 +1602,65 @@ def get_lr(optimizer): ...@@ -1602,3 +1602,65 @@ def get_lr(optimizer):
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
) )
def initialize_pg_in_full_mode(all_process_groups, cur_rank):
import socket
from ..collective import _get_global_env
has_recv_by_socket = []
# This is a magic number
magic_num = 500
genv = _get_global_env()
cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
cur_rank_recv_port = int(cur_rank_port) + magic_num
server_socket = None
# Large enough for recv rank
buff_size = 1024
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((cur_rank_ip, cur_rank_recv_port))
# The 10 is an empirical value
server_socket.listen(10)
client_sockets = {}
for process_group in all_process_groups:
if cur_rank not in process_group.ranks:
continue
if len(process_group.ranks) == 2:
index = process_group.ranks.index(cur_rank)
is_send = True if index == 0 else False
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
else:
print("It is able to instantiate {} as sender now.".format(
process_group.ranks))
client_socket.close()
else:
send_rank = process_group.ranks[0]
while True:
if send_rank not in has_recv_by_socket:
client_socket, recv_addr = server_socket.accept()
rank = int(client_socket.recv(buff_size).decode())
client_sockets[rank] = client_socket
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(cur_rank).encode("utf-8"))
client_sockets[send_rank].close()
print("It is able to instantiate {} as recver now.".
format(process_group.ranks))
break
process_group.instantiate()
server_socket.close()
...@@ -141,6 +141,7 @@ def train_high_level(fetch): ...@@ -141,6 +141,7 @@ def train_high_level(fetch):
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
eval_dataset1 = MyDataset(5 * batch_size) eval_dataset1 = MyDataset(5 * batch_size)
history = engine.fit(train_data=train_dataset, history = engine.fit(train_data=train_dataset,
epochs=2, epochs=2,
batch_size=batch_size, batch_size=batch_size,
...@@ -354,9 +355,74 @@ def train_non_builtin_data_vars(): ...@@ -354,9 +355,74 @@ def train_non_builtin_data_vars():
) # call DataLoader.reset() after catching EOFException ) # call DataLoader.reset() after catching EOFException
def get_cost():
main_program = static.default_main_program()
startup_program = static.default_startup_program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
engine.cost()
def get_cost_by_spec():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
engine.cost(mode="eval", inputs_spec=[input_spec], labels_spec=[label_spec])
if __name__ == "__main__": if __name__ == "__main__":
train_high_level(fetch=True) train_high_level(fetch=True)
train_high_level(fetch=False) train_high_level(fetch=False)
train_low_level() train_low_level()
train_builtin_data_vars() train_builtin_data_vars()
train_non_builtin_data_vars() train_non_builtin_data_vars()
get_cost()
get_cost_by_spec()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册