diff --git a/PaddleNLP/neural_machine_translation/transformer/train.py b/PaddleNLP/neural_machine_translation/transformer/train.py index f284c9c6a0b547d5c119232d1dab76de3dbd1064..aa6a99353b4ffc4c46ff5efc46e8e572a5ec33d9 100644 --- a/PaddleNLP/neural_machine_translation/transformer/train.py +++ b/PaddleNLP/neural_machine_translation/transformer/train.py @@ -4,6 +4,10 @@ import copy import logging import multiprocessing import os + +if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None: + os.environ['FLAGS_eager_delete_tensor_gb'] = '0' + import six import sys sys.path.append("../../") @@ -720,9 +724,6 @@ def train(args): optimizer = fluid.optimizer.SGD(0.003) optimizer.minimize(avg_cost) - if args.use_mem_opt: - fluid.memory_optimize(train_prog) - if args.local: logging.info("local start_up:") train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, @@ -806,4 +807,4 @@ if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) args = parse_args() - train(args) \ No newline at end of file + train(args)