提交 a764e320 编写于 作者: Y Yu Yang

fix(Trainer): make mt supports trainer

上级 da161514
...@@ -11,10 +11,21 @@ ...@@ -11,10 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as pd import paddle.fluid.layers as pd
import os import os
import sys
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
dict_size = 30000 dict_size = 30000
source_dict_dim = target_dict_dim = dict_size source_dict_dim = target_dict_dim = dict_size
...@@ -105,15 +116,15 @@ def train(use_cuda): ...@@ -105,15 +116,15 @@ def train(use_cuda):
] ]
def event_handler(event): def event_handler(event):
if isinstance(event, fluid.EndStepEvent): if isinstance(event, EndStepEvent):
if event.step % 10 == 0: if event.step % 10 == 0:
print('pass_id=' + str(event.epoch) + ' batch=' + str( print('pass_id=' + str(event.epoch) + ' batch=' + str(
event.step)) event.step))
if isinstance(event, fluid.EndEpochEvent): if isinstance(event, EndEpochEvent):
trainer.save_params(model_save_dir) trainer.save_params(model_save_dir)
trainer = fluid.Trainer( trainer = Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_func) train_func=train_program, place=place, optimizer_func=optimizer_func)
trainer.train( trainer.train(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册