未验证 提交 f69a0b5e 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add an argument to enable the use of experimental feature, fusion_group. (#4252)

test=develop
上级 1fbb0875
...@@ -59,6 +59,12 @@ def parse_args(): ...@@ -59,6 +59,12 @@ def parse_args():
type=str2bool, type=str2bool,
default=False, default=False,
help='Whether profiling the trainning [True|False]') help='Whether profiling the trainning [True|False]')
parser.add_argument(
'--enable_auto_fusion',
type=str2bool,
default=False,
help='Whether enable fusion_group [True|False]. It is a experimental feature.'
)
parser.add_argument( parser.add_argument(
'--use_dataloader', '--use_dataloader',
type=str2bool, type=str2bool,
...@@ -80,8 +86,12 @@ def parse_args(): ...@@ -80,8 +86,12 @@ def parse_args():
parser.add_argument('--enable_ce', action='store_true') parser.add_argument('--enable_ce', action='store_true')
parser.add_argument('--batch_size', type=int, default=0, help='batch size') parser.add_argument('--batch_size', type=int, default=0, help='batch size')
parser.add_argument('--max_epoch', type=int, default=0, help='max epoch') parser.add_argument('--max_epoch', type=int, default=0, help='max epoch')
# NOTE: args for profiler, used for benchmark # NOTE: args for profiler, used for benchmark
parser.add_argument('--profiler_path', type=str, default='/tmp/paddingrnn.profile', help='the profiler output file path. used for benchmark') parser.add_argument(
'--profiler_path',
type=str,
default='/tmp/paddingrnn.profile',
help='the profiler output file path. used for benchmark')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -191,6 +191,12 @@ def main(): ...@@ -191,6 +191,12 @@ def main():
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = True build_strategy.fuse_all_optimizer_ops = True
try:
fluid.require_version(min_version='1.7.0')
build_strategy.enable_auto_fusion = args.enable_auto_fusion
except Exception as e:
logger.info("PaddlePaddle version 1.7.0 or higher is "
"required when you want to enable fusion_group.")
if args.parallel: if args.parallel:
train_program = fluid.compiler.CompiledProgram( train_program = fluid.compiler.CompiledProgram(
...@@ -438,32 +444,35 @@ def main(): ...@@ -438,32 +444,35 @@ def main():
print("ptblm\tlstm_language_model_%s_loss_card%d\t%s" % print("ptblm\tlstm_language_model_%s_loss_card%d\t%s" %
(args.rnn_model, device_count, train_ppl[0])) (args.rnn_model, device_count, train_ppl[0]))
# NOTE(zjl): sometimes we have not enough data for eval if batch_size is large, i.e., 2100 if not args.profile:
# Just skip to avoid error # NOTE(zjl): sometimes we have not enough data for eval if batch_size is large, i.e., 2100
def is_valid_data(data, batch_size, num_steps): # Just skip to avoid error
data_len = len(data) def is_valid_data(data, batch_size, num_steps):
batch_len = data_len // batch_size data_len = len(data)
epoch_size = (batch_len - 1) // num_steps batch_len = data_len // batch_size
return epoch_size >= 1 epoch_size = (batch_len - 1) // num_steps
return epoch_size >= 1
valid_data_valid = is_valid_data(valid_data, config.batch_size,
config.num_steps) valid_data_valid = is_valid_data(valid_data, config.batch_size,
if valid_data_valid: config.num_steps)
valid_ppl = eval(valid_data) if valid_data_valid:
print("Valid ppl: %.5f" % valid_ppl[0]) valid_ppl = eval(valid_data)
else: print("Valid ppl: %.5f" % valid_ppl[0])
print( else:
'WARNING: length of valid_data is {}, which is not enough for batch_size {} and num_steps {}'. print(
format( 'WARNING: length of valid_data is {}, which is not enough for batch_size {} and num_steps {}'.
len(valid_data), config.batch_size, config.num_steps)) format(
len(valid_data), config.batch_size,
save_model_dir = os.path.join(args.save_model_dir, str(epoch_id)) config.num_steps))
if not os.path.exists(save_model_dir):
mkpath(save_model_dir) save_model_dir = os.path.join(args.save_model_dir,
save_model_dir = os.path.join(save_model_dir, 'params') str(epoch_id))
if not os.path.exists(save_model_dir):
fluid.save(main_program, save_model_dir) mkpath(save_model_dir)
print("Saved model to: %s.\n" % save_model_dir) save_model_dir = os.path.join(save_model_dir, 'params')
fluid.save(main_program, save_model_dir)
print("Saved model to: %s.\n" % save_model_dir)
with profile_context(args.profile, args.profiler_path): with profile_context(args.profile, args.profiler_path):
train() train()
......
...@@ -190,38 +190,9 @@ def lm_model(hidden_size, ...@@ -190,38 +190,9 @@ def lm_model(hidden_size,
gate_input = layers.elementwise_add(gate_input, bias) gate_input = layers.elementwise_add(gate_input, bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1) i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
try: c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
from paddle.fluid.contrib.layers import fused_elemwise_activation i) * layers.tanh(j)
# fluid.contrib.layers.fused_elemwise_activation can do a fused m = layers.tanh(c) * layers.sigmoid(o)
# operation, like:
# 1) x + sigmoid(y); x + tanh(y)
# 2) tanh(x + y)
# Now the unary operation supported in this fused op is limit, and
# we will extent this operation to support more unary operations and
# do this kind of fusion automitically in future version of paddle.fluid.
# layers.sigmoid(i) * layers.tanh(j)
tmp0 = fused_elemwise_activation(
x=layers.tanh(j),
y=i,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
# pre_cell * layers.sigmoid(f)
tmp1 = fused_elemwise_activation(
x=pre_cell,
y=f,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
c = tmp0 + tmp1
# layers.tanh(c) * layers.sigmoid(o)
m = fused_elemwise_activation(
x=layers.tanh(c),
y=o,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
except ImportError:
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
hidden_array[k] = m hidden_array[k] = m
cell_array[k] = c cell_array[k] = c
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册