From a764e3202f6f16a8e70b4fcd9142243d60490c85 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 29 Sep 2018 15:12:06 +0800 Subject: [PATCH] fix(Trainer): make mt supports trainer --- 08.machine_translation/train.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/08.machine_translation/train.py b/08.machine_translation/train.py index c2dd5b5..589e4dc 100644 --- a/08.machine_translation/train.py +++ b/08.machine_translation/train.py @@ -11,10 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import paddle import paddle.fluid as fluid import paddle.fluid.layers as pd import os +import sys +try: + from paddle.fluid.contrib.trainer import * + from paddle.fluid.contrib.inferencer import * +except ImportError: + print( + "In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib", + file=sys.stderr) + from paddle.fluid.trainer import * + from paddle.fluid.inferencer import * dict_size = 30000 source_dict_dim = target_dict_dim = dict_size @@ -105,15 +116,15 @@ def train(use_cuda): ] def event_handler(event): - if isinstance(event, fluid.EndStepEvent): + if isinstance(event, EndStepEvent): if event.step % 10 == 0: print('pass_id=' + str(event.epoch) + ' batch=' + str( event.step)) - if isinstance(event, fluid.EndEpochEvent): + if isinstance(event, EndEpochEvent): trainer.save_params(model_save_dir) - trainer = fluid.Trainer( + trainer = Trainer( train_func=train_program, place=place, optimizer_func=optimizer_func) trainer.train( -- GitLab