提交 a764e320 编写于 作者: Y Yu Yang

fix(Trainer): make mt supports trainer

上级 da161514
......@@ -11,10 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as pd
import os
import sys
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
......@@ -105,15 +116,15 @@ def train(use_cuda):
]
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if isinstance(event, EndStepEvent):
if event.step % 10 == 0:
print('pass_id=' + str(event.epoch) + ' batch=' + str(
event.step))
if isinstance(event, fluid.EndEpochEvent):
if isinstance(event, EndEpochEvent):
trainer.save_params(model_save_dir)
trainer = fluid.Trainer(
trainer = Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_func)
trainer.train(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册