Unverified commit 304d7815 authored by ceci3, committed by GitHub

optimize auto compress (#1550)

* support sub block

* support post-process

* update

* fix

* add unittest

* revert prune

* fix unittest

* add unittest
Parent e72ce197
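The "support sub block" items above target static programs whose main block owns sub-blocks, for example the body of a while op. A minimal sketch, assuming only stock PaddlePaddle 2.x static-graph APIs (the program and variable names are placeholders, not part of this commit), of how such a program ends up with more than one block, which is why the helpers in this commit stop assuming every variable lives in global_block():

# Minimal sketch (not part of this commit): a while_loop keeps its body in a
# sub-block, so the program has more than one block and variables created
# inside the loop are invisible to global_block().var(...).
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    i = paddle.full(shape=[1], fill_value=0, dtype='int64')
    ten = paddle.full(shape=[1], fill_value=10, dtype='int64')

    def cond(i):
        return paddle.less_than(i, ten)

    def body(i):
        return [paddle.increment(i)]

    out = paddle.static.nn.while_loop(cond, body, loop_vars=[i])

print(len(main_prog.blocks))  # > 1: the while op owns a sub-block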
......@@ -550,7 +550,7 @@ class AutoCompression:
train_program_info = self._compiled_program(train_program_info,
strategy)
test_program_info = self._compiled_program(test_program_info,
self._strategy)
strategy)
return train_program_info, test_program_info
def _compiled_program(self, program_info, strategy):
......
......@@ -84,6 +84,13 @@ def _create_optimizer(train_config):
return opt, lr
def _find_var_from_program(program, var_name):
for block in program.blocks:
if block.has_var(var_name):
return block.var(var_name)
raise ValueError("var {} not in this program".format(var_name))
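A hypothetical usage of the helper above (the node name is a placeholder, and student_program is an assumed paddle.static.Program): unlike global_block().var(name), it also finds variables that only exist in a sub-block, and raises ValueError when the name is absent from every block.

# 'fc_0.tmp_1' is a placeholder node name; student_program may contain
# sub-blocks (e.g. the body of a while op).
try:
    target_var = _find_var_from_program(student_program, 'fc_0.tmp_1')
except ValueError:
    target_var = None  # the node is not defined in any block of the program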
def _get_distill_node(student_program, config):
node = config.get('node')
if len(node) == 0:
......@@ -95,7 +102,7 @@ def _get_distill_node(student_program, config):
else:
test_node = node[0]
try:
test_var = student_program.global_block().var(test_node)
test_var = _find_var_from_program(student_program, test_node)
distill_node_pair = []
if isinstance(node[0], list):
for n_list in node:
......@@ -113,6 +120,14 @@ def _get_distill_node(student_program, config):
return node
def _get_target_node(distill_node):
targets = []
for idx, node in enumerate(distill_node):
if idx % 2 != 0:
targets.append(node)
return targets
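Illustrative only, assuming the pair list alternates teacher-prefixed and plain node names as built by _get_distill_node (the names below are placeholders): _get_target_node keeps the odd-indexed entries of the flattened pair list, which the surrounding code then passes to Program._prune as prune targets.

# Placeholder names; the real pairs come from the Distillation 'node' config.
pairs = ['teacher_fc_0.tmp_1', 'fc_0.tmp_1', 'teacher_fc_1.tmp_1', 'fc_1.tmp_1']
assert _get_target_node(pairs) == ['fc_0.tmp_1', 'fc_1.tmp_1']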
def _parse_distill_loss(distill_node_pair,
distill_loss='l2',
distill_lambda=1.0):
......@@ -149,6 +164,7 @@ def _load_program_and_merge(executor,
model_dir,
model_filename,
params_filename,
distill_node_pair,
teacher_idx=None,
feed_target_names=None):
scope = paddle.static.global_scope()
......@@ -171,8 +187,8 @@ def _load_program_and_merge(executor,
_remove_fetch_node(teacher_program)
if teacher_idx == None or teacher_idx == 1:
test_program = train_program.clone(for_test=True)
target_nodes = _get_target_node(distill_node_pair)
teacher_program = teacher_program._prune(target_nodes)
data_name_map = {}
......@@ -196,9 +212,9 @@ def _load_program_and_merge(executor,
name_prefix=teacher_name_prefix,
merge_feed=merge_feed)
if teacher_idx == None or teacher_idx == 1:
return train_program, test_program, data_name_map
return train_program, data_name_map
else:
return train_program, None, data_name_map
return train_program, data_name_map
def build_distill_program(executor,
......@@ -224,6 +240,38 @@ def build_distill_program(executor,
distill_node_pair = _get_distill_node(train_program,
config) or default_distill_node_pair
test_program = train_program.clone(for_test=True)
target_nodes = _get_target_node(distill_node_pair)
def _prepend_feed(block, feed_idx, feed_target_names):
for idx in feed_idx[::-1]:
block._remove_op(idx)
feed_var = block.create_var(
name='feed',
type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True, )
for i, name in enumerate(feed_target_names):
out = block.var(name)
block._prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
judge_feed_pos = False
if train_program.desc.block(0).op(0).type() != 'feed':
judge_feed_pos = True
if judge_feed_pos:
feed_idx = []
for op in train_program.global_block().ops:
if op.type == 'feed':
feed_idx.append(op.idx)
_prepend_feed(train_program.global_block(), feed_idx, feed_target_names)
train_program = train_program._prune(target_nodes)
teacher_model_dir = config[
"teacher_model_dir"] if "teacher_model_dir" in config else config[
"teacher_model_path_prefix"]
......@@ -234,7 +282,7 @@ def build_distill_program(executor,
params_filename = config["teacher_params_filename"][
tea_idx] if "teacher_params_filename" in config else None
if tea_idx == 0:
train_program, test_program, data_name_map = _load_program_and_merge(
train_program, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
......@@ -242,10 +290,11 @@ def build_distill_program(executor,
teacher_model_dir[tea_idx],
model_filename,
params_filename,
distill_node_pair,
teacher_idx=(tea_idx + 1),
feed_target_names=feed_target_names)
else:
train_program, _, data_name_map = _load_program_and_merge(
train_program, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
......@@ -253,6 +302,7 @@ def build_distill_program(executor,
teacher_model_dir[tea_idx],
model_filename,
params_filename,
distill_node_pair,
teacher_idx=(tea_idx + 1),
feed_target_names=feed_target_names)
......@@ -261,7 +311,7 @@ def build_distill_program(executor,
"teacher_model_filename"] if "teacher_model_filename" in config else None
params_filename = config[
"teacher_params_filename"] if "teacher_params_filename" in config else None
train_program, test_program, data_name_map = _load_program_and_merge(
train_program, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
......@@ -269,6 +319,7 @@ def build_distill_program(executor,
teacher_model_dir,
model_filename,
params_filename,
distill_node_pair,
teacher_idx=None,
feed_target_names=feed_target_names)
    # all feed nodes should set stop_gradient to False, for using the PACT quant algo.
......@@ -479,7 +530,7 @@ def build_prune_program(executor,
place=place)
_logger.info(
"####################channel pruning##########################")
for param in pruned_program.global_block().all_parameters():
for param in pruned_program.all_parameters():
if param.name in original_shapes:
_logger.info("{}, from {} to {}".format(
param.name, original_shapes[param.name], param.shape))
......
......@@ -19,7 +19,7 @@ from ..core import GraphWrapper
from ..common import get_logger
from ..common.recover_program import recover_inference_program
from ..common.transformer_pattern import preprocess_transformer_patterns
from ..common.patterns_common import is_dynamic_weight_op
from ..common.patterns_common import has_trainable_var
_logger = get_logger(__name__, level=logging.INFO)
......@@ -297,7 +297,7 @@ class TransformerPruner:
tmp_mha_ops = patterns['MHA$0']
for op in tmp_mha_ops:
if op.type() in ['matmul', 'matmul_v2'] and (
not is_dynamic_weight_op(op)) and head_num == -1:
not has_trainable_var(op)) and head_num == -1:
inp_var = op.inputs("X")
head_num = inp_var[0].shape()[1]
......
......@@ -30,7 +30,7 @@ def find_final_nodes(program):
final_nodes = []
graph = GraphWrapper(program)
for op in sorted(graph.ops()):
if op.type() in ALL_WEIGHT_OP and is_output_weight_ops(op, graph):
if has_trainable_var(op) and is_final_op_with_trainable_var(op, graph):
n_op = has_bias(op, graph)
if n_op is not None:
final_nodes.extend(n_op.all_outputs())
......@@ -52,7 +52,7 @@ def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
matmul_num = 0
for op in pattern_ops:
if op.type() in ['matmul', 'matmul_v2']:
if not is_dynamic_weight_op(op):
if not has_trainable_var(op):
matmul_num += 1
if matmul_num == 2:
return True
......@@ -68,7 +68,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
act_num = 0
for op in pattern_ops:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
if is_dynamic_weight_op(op):
if has_trainable_var(op):
linear_num += 1
if op.type() in ['relu', 'gelu']:
act_num += 1
......
......@@ -39,7 +39,7 @@ def find_weight_op(op, graph):
""" Find operators with weight."""
next_ops = sorted(graph.next_ops(op))
for next_op in next_ops:
if is_dynamic_weight_op(next_op):
if has_trainable_var(next_op):
return next_op
else:
return find_weight_op(next_op, graph)
......@@ -56,25 +56,24 @@ def get_weight(op, return_name=True):
return None
def is_dynamic_weight_op(op):
def has_trainable_var(op):
""" Judge whether the operator with trainable variable """
weight_ops = ALL_WEIGHT_OP
if op.type() in weight_ops:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
for inp in sorted(op.all_inputs()):
if inp._var.persistable == True:
return True
return False
return True
for inp in sorted(op.all_inputs()):
if inp._var.persistable == True:
return True
return False
return False
def is_output_weight_ops(op, graph):
def is_final_op_with_trainable_var(op, graph):
""" Judge whether is the final op with weights in the graph """
next_ops = sorted(graph.next_ops(op))
for next_op in next_ops:
if is_dynamic_weight_op(next_op):
if has_trainable_var(next_op):
return False
return is_output_weight_ops(next_op, graph)
return is_final_op_with_trainable_var(next_op, graph)
return True
......
......@@ -31,7 +31,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
continue
next_op = _find_gemm_op(next_op, graph)
if next_op.type() in ['mul', 'matmul', 'matmul_v2'
] and is_dynamic_weight_op(next_op):
] and has_trainable_var(next_op):
if block_num not in params_dict:
params_dict[block_num] = {}
params_dict[block_num]['P1'] = [get_weight(next_op)]
......
......@@ -17,6 +17,30 @@ import paddle
from paddleslim.core import GraphWrapper
def _find_var_from_program(program, var_name):
for block in program.blocks:
if block.has_var(var_name):
return block.var(var_name)
raise ValueError("var {} not in this program".format(var_name))
def _except_feed_fetch(var_name, merge_feed):
if var_name != 'fetch' and (not merge_feed or var_name != 'feed'):
return True
return False
def _is_same_block(block1, block2):
if len(block1.ops) != len(block2.ops):
return False
for op1, op2 in zip(block1.ops, block2.ops):
if op1.type != op2.type:
return False
return True
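For illustration only (the program construction below is an assumption, not part of the commit): _is_same_block treats two blocks as equal when they contain the same op types in the same order, which is how merge() later decides whether the teacher and student programs share the same block structure.

# Two programs built the same way yield blocks that _is_same_block accepts,
# because only the sequence of op types is compared, not variable names.
import paddle

paddle.enable_static()

def _build_program():
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data(name='x', shape=[-1, 16], dtype='float32')
        paddle.static.nn.fc(x, size=4)
    return prog

prog_a, prog_b = _build_program(), _build_program()
assert _is_same_block(prog_a.global_block(), prog_b.global_block())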
def merge(teacher_program,
student_program,
data_name_map,
......@@ -52,55 +76,127 @@ def merge(teacher_program,
if teacher_scope == None:
teacher_scope = scope
teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars():
skip_rename = False
if teacher_var.name != 'fetch' and (not merge_feed or
teacher_var.name != 'feed'):
if teacher_var.name in data_name_map.keys():
new_name = data_name_map[teacher_var.name]
if new_name == teacher_var.name:
skip_rename = True
else:
new_name = name_prefix + teacher_var.name
if not skip_rename:
# scope var rename
old_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place)
# program var rename
renamed_var = teacher_program.global_block()._rename_var(
teacher_var.name, new_name)
for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and (not merge_feed or
teacher_var.name != 'feed'):
# student program add var
new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False)
new_var.stop_gradient = True
is_same_model = True
if len(student_program.blocks) == len(teacher_program.blocks):
for block in teacher_program.blocks:
if not _is_same_block(block, student_program.block(block.idx)):
is_same_model = False
break
else:
is_same_model = False
if is_same_model:
for block in student_program.blocks:
for op in block.ops:
if op.type == 'while':
tmp_var = []
for _var_name in op.input('X'):
tmp_var.append('teacher_' + _var_name)
tmp_var.extend(op.input('X'))
op.desc.set_input("X", tmp_var)
for block in teacher_program.blocks:
for teacher_var in list(block.vars.values()):
skip_rename = False
if _except_feed_fetch(teacher_var.name, merge_feed):
if teacher_var.name in data_name_map.keys():
new_name = data_name_map[teacher_var.name]
if new_name == teacher_var.name:
skip_rename = True
else:
new_name = name_prefix + teacher_var.name
if not skip_rename:
# scope var rename
old_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place)
# program var rename
renamed_var = block._rename_var(teacher_var.name, new_name)
                ### inputs and outputs of the sub_block need to be renamed specially.
for op in block.ops:
for iname in op.input_names:
for in_var_name in op.input(iname):
if _except_feed_fetch(
in_var_name,
merge_feed) and not block.has_var(in_var_name):
if in_var_name in data_name_map.keys():
new_name = data_name_map[in_var_name]
if new_name != in_var_name:
op._rename_input(in_var_name,
name_prefix + in_var_name)
else:
op._rename_input(in_var_name,
name_prefix + in_var_name)
for oname in op.output_names:
for out_var_name in op.output(oname):
if _except_feed_fetch(
out_var_name,
merge_feed) and not block.has_var(out_var_name):
if out_var_name in data_name_map.keys():
new_name = data_name_map[out_var_name]
if new_name != out_var_name:
op._rename_output(out_var_name,
name_prefix + out_var_name)
else:
op._rename_output(out_var_name,
name_prefix + out_var_name)
for block in teacher_program.blocks:
for teacher_var in list(block.vars.values()):
if teacher_var.name != 'fetch' and (not merge_feed or
teacher_var.name != 'feed'):
# student program add var
if len(student_program.blocks) > 1 and is_same_model:
new_var = student_program.block(block.idx)._clone_variable(
teacher_var, force_persistable=False)
else:
new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False)
new_var.stop_gradient = True
for block in reversed(teacher_program.blocks):
for op_idx, op in enumerate(block.ops):
if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
inputs = {}
outputs = {}
attrs = {}
for input_name in op.input_names:
inputs[input_name] = [
block.var(in_var_name)
for in_var_name in op.input(input_name)
]
inputs[input_name] = []
for in_var_name in op.input(input_name):
inputs[input_name].append(
block._find_var_recursive(in_var_name))
for output_name in op.output_names:
outputs[output_name] = [
block.var(out_var_name)
for out_var_name in op.output(output_name)
]
outputs[output_name] = []
for out_var_name in op.output(output_name):
outputs[output_name].append(
block._find_var_recursive(out_var_name))
for attr_name in op.attr_names:
attrs[attr_name] = op.attr(attr_name)
student_program.global_block().append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
if attr_name == 'sub_block':
attrs[attr_name] = student_program.block(
op._block_attr("sub_block").idx)
else:
attrs[attr_name] = op.attr(attr_name)
if len(student_program.blocks) > 1 and is_same_model:
student_program.block(op.block.idx)._insert_op(
2 * op_idx,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
else:
student_program.global_block().append_op(
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
student_program._sync_with_cpp()
student_graph = GraphWrapper(student_program)
for op in student_graph.ops():
......@@ -137,10 +233,10 @@ def fsp(teacher_var1_name,
"""
if program == None:
program = paddle.static.default_main_program()
teacher_var1 = program.global_block().var(teacher_var1_name)
teacher_var2 = program.global_block().var(teacher_var2_name)
student_var1 = program.global_block().var(student_var1_name)
student_var2 = program.global_block().var(student_var2_name)
teacher_var1 = _find_var_from_program(program, teacher_var1_name)
teacher_var2 = _find_var_from_program(program, teacher_var2_name)
student_var1 = _find_var_from_program(program, student_var1_name)
student_var2 = _find_var_from_program(program, student_var2_name)
teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
teacher_var2)
student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
......@@ -165,8 +261,8 @@ def l2(teacher_var_name, student_var_name, program=None):
"""
if program == None:
program = paddle.static.default_main_program()
student_var = program.global_block().var(student_var_name)
teacher_var = program.global_block().var(teacher_var_name)
student_var = _find_var_from_program(program, student_var_name)
teacher_var = _find_var_from_program(program, teacher_var_name)
l2_loss = paddle.mean(
paddle.nn.functional.square_error_cost(student_var, teacher_var))
return l2_loss
......@@ -194,8 +290,8 @@ def soft_label(teacher_var_name,
"""
if program == None:
program = paddle.static.default_main_program()
student_var = program.global_block().var(student_var_name)
teacher_var = program.global_block().var(teacher_var_name)
student_var = _find_var_from_program(program, student_var_name)
teacher_var = _find_var_from_program(program, teacher_var_name)
teacher_var.stop_gradient = True
student_var = paddle.nn.functional.softmax(student_var /
......@@ -225,7 +321,7 @@ def loss(loss_func, program=None, **kwargs):
for item in kwargs.items():
if isinstance(item[1], str):
func_parameters.setdefault(item[0],
program.global_block().var(item[1]))
_find_var_from_program(program, item[1]))
else:
func_parameters.setdefault(item[0], item[1])
loss = loss_func(**func_parameters)
......@@ -297,8 +393,8 @@ def dkd(teacher_var_name,
"""
if program == None:
program = paddle.static.default_main_program()
student_var = program.global_block().var(student_var_name)
teacher_var = program.global_block().var(teacher_var_name)
student_var = _find_var_from_program(program, student_var_name)
teacher_var = _find_var_from_program(program, teacher_var_name)
return _dkd_loss(
student_var,
teacher_var,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -34,7 +32,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from ..common import get_logger
from ..common.patterns import get_patterns
from ..common.patterns_common import is_dynamic_weight_op, get_weight
from ..common.patterns_common import has_trainable_var, get_weight
from ..core.graph_wrapper import GraphWrapper
_logger = get_logger(__name__, level=logging.INFO)
......@@ -341,7 +339,7 @@ def quant_aware(program,
for op in ops:
if op._op.type in ['matmul', 'matmul_v2'] and (
not is_dynamic_weight_op(op)):
not has_trainable_var(op)):
input_names = op._op.input_arg_names
for input_name in input_names:
pre_op = find_pre_ops(program, input_name)[0]
......@@ -387,6 +385,7 @@ def quant_aware(program,
scale_dict = post_training_quantization._scale_dict
else:
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()]
transform_pass_ops = []
quant_dequant_ops = []
for op_type in config['quantize_op_types']:
......@@ -416,7 +415,8 @@ def quant_aware(program,
executor=executor,
is_test=is_test)
transform_pass.apply(main_graph)
for sub_graph in sub_graphs:
transform_pass.apply(sub_graph)
if len(quant_dequant_ops) > 0:
qdq_func = 'AddQuantDequantPassV2' if config[
......@@ -430,7 +430,8 @@ def quant_aware(program,
quantizable_op_type=quant_dequant_ops,
is_test=is_test)
quant_dequant_pass.apply(main_graph)
for sub_graph in sub_graphs:
quant_dequant_pass.apply(sub_graph)
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope,
......@@ -439,16 +440,18 @@ def quant_aware(program,
is_test=is_test,
scale_dict=scale_dict)
out_scale_training_pass.apply(main_graph)
for sub_graph in sub_graphs:
out_scale_training_pass.apply(sub_graph)
if (weight_preprocess_func is not None or
act_preprocess_func is not None) and not for_test:
if (weight_preprocess_func is not None or act_preprocess_func is not None
) and not for_test and not config['onnx_format']:
_logger.info(
"When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase."
)
_logger.info("The mapping table is saved as '{}'.".format(
VARS_MAPPING_TABLE))
save_dict(main_graph.out_node_mapping_table)
for sub_graph in sub_graphs:
save_dict(sub_graph.out_node_mapping_table)
    # TODO: remove it.
if draw_graph:
......@@ -683,18 +686,21 @@ def convert(program,
if config['onnx_format']:
quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(test_graph)
for sub_graph in test_graph.all_sub_graphs():
quant_weight_pass.apply(sub_graph)
try:
out_scale_infer_pass = AddQuantDequantForInferencePass(
scope=scope, place=place, quant_bits=config['activation_bits'])
out_scale_infer_pass.apply(test_graph)
for sub_graph in test_graph.all_sub_graphs():
out_scale_infer_pass.apply(sub_graph)
except:
_logger.warning(
"Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0"
)
else:
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
for sub_graph in test_graph.all_sub_graphs():
out_scale_infer_pass.apply(sub_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
......@@ -705,13 +711,31 @@ def convert(program,
weight_quantize_type=config['weight_quantize_type'])
if os.path.exists(VARS_MAPPING_TABLE):
test_graph.out_node_mapping_table = load_dict()
freeze_pass.apply(test_graph)
for sub_graph in test_graph.all_sub_graphs():
freeze_pass.apply(sub_graph)
freezed_program = test_graph.to_program()
    # Move sub-blocks' persistable vars to the global block
global_block = freezed_program.global_block()
for _op in global_block.ops:
if _op.type == "while":
_block_id = _op.attr("sub_block").id
_block = freezed_program.block(_block_id)
persistables = []
for _name, _var in _block.vars.items():
if _var.persistable:
global_block._clone_variable(_var)
persistables.append(_name)
for _name in persistables:
_block._remove_var(_name)
persistables.extend(_op.input('X'))
_op.desc.set_input("X", persistables)
if save_int8:
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
for sub_graph in test_graph.all_sub_graphs():
convert_int8_pass.apply(sub_graph)
freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8
else:
......
import os
import sys
import unittest
sys.path.append("../../")
import numpy as np
import paddle
from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression
paddle.enable_static()
class RandomEvalDataset(Dataset):
def __init__(self, num_samples, image_shape=[1, 28, 28], class_num=10):
self.num_samples = num_samples
self.image_shape = image_shape
self.class_num = class_num
def __getitem__(self, idx):
image = np.random.random(self.image_shape).astype('float32')
return image
def __len__(self):
return self.num_samples
class ACTQATWhileOP(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(ACTQATWhileOP, self).__init__(*args, **kwargs)
if not os.path.exists('mnist_while'):
os.system(
"wget -q http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
)
os.system('tar -xzvf mnist_while.tar.gz')
self.create_dataloader()
self.get_config()
def create_dataloader(self):
# define a random dataset
self.eval_dataset = RandomEvalDataset(32)
def get_config(self):
self.config = {
'QuantAware': {},
'Distillation': {},
'TrainConfig': {
'epochs': 1,
'eval_iter': 100,
'learning_rate': 5.0e-03,
'optimizer_builder': {
'optimizer': {
'type': 'SGD'
},
"weight_decay": 0.0005,
}
}
}
def test_demo(self):
image = paddle.static.data(
name='x', shape=[-1, 1, 28, 28], dtype='float32')
train_loader = paddle.io.DataLoader(
self.eval_dataset, feed_list=[image], batch_size=4)
ac = AutoCompression(
model_dir="./mnist_while",
model_filename="model.pdmodel",
params_filename="model.pdiparams",
config=self.config,
save_dir="qat_while_output",
train_dataloader=train_loader)
ac.compress()
os.system('rm -rf qat_while_output')
class ACTQATWhileOPCase2(ACTQATWhileOP):
def get_config(self):
self.config = {
'QuantAware': {
'quantize_op_types': ['conv2d', 'mul', 'relu']
},
'Distillation': {},
'TrainConfig': {
'epochs': 1,
'eval_iter': 100,
'learning_rate': 5.0e-03,
'optimizer_builder': {
'optimizer': {
'type': 'SGD'
},
"weight_decay": 0.0005,
}
}
}
if __name__ == '__main__':
unittest.main()