未验证 提交 6d5e582d 编写于 作者: H Helin Wang 提交者: GitHub

Merge pull request #10343 from reyoung/feature/new_api_train_impl

A naive implement trainer.train by executor
...@@ -21,8 +21,7 @@ import executor ...@@ -21,8 +21,7 @@ import executor
from executor import * from executor import *
import trainer import trainer
from trainer import Trainer from trainer import *
from trainer import Event
import inferencer import inferencer
from inferencer import Inferencer from inferencer import Inferencer
......
...@@ -50,8 +50,6 @@ def data(name, ...@@ -50,8 +50,6 @@ def data(name,
dtype(int|float): The type of data : float32, float_16, int etc dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR. type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence. lod_level(int): The LoD Level. 0 means the input data is not a sequence.
main_program(Program): Name of the main program that calls this
startup_program(Program): Name of the startup program
stop_gradient(bool): A boolean that mentions whether gradient should flow. stop_gradient(bool): A boolean that mentions whether gradient should flow.
Returns: Returns:
...@@ -74,13 +72,15 @@ def data(name, ...@@ -74,13 +72,15 @@ def data(name,
if append_batch_size: if append_batch_size:
shape = [-1] + shape # append batch size as -1 shape = [-1] + shape # append batch size as -1
return helper.create_global_variable( data_var = helper.create_global_variable(
name=name, name=name,
shape=shape, shape=shape,
dtype=dtype, dtype=dtype,
type=type, type=type,
stop_gradient=stop_gradient, stop_gradient=stop_gradient,
lod_level=lod_level) lod_level=lod_level)
data_var.is_data = True
return data_var
class BlockGuardServ(BlockGuard): class BlockGuardServ(BlockGuard):
......
...@@ -28,7 +28,8 @@ from contextlib import contextmanager ...@@ -28,7 +28,8 @@ from contextlib import contextmanager
__all__ = [ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'Adadelta', 'ModelAverage' 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'Adadelta', 'ModelAverage',
'Optimizer'
] ]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import math
import sys
from functools import partial
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 32
def create_random_lodtensor(lod, place, low, high):
# The range of data elements is [low, high]
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
def inference_network(is_sparse):
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
embed_first = fluid.layers.embedding(
input=first_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_second = fluid.layers.embedding(
input=second_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_third = fluid.layers.embedding(
input=third_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_forth = fluid.layers.embedding(
input=forth_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
concat_embed = fluid.layers.concat(
input=[embed_first, embed_second, embed_third, embed_forth], axis=1)
hidden1 = fluid.layers.fc(input=concat_embed,
size=HIDDEN_SIZE,
act='sigmoid')
predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax')
return predict_word
def train_network(is_sparse):
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
predict_word = inference_network(is_sparse)
cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def train(use_cuda, is_sparse, save_path):
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
def event_handler(event):
print type(event)
if isinstance(event, fluid.EndEpochEvent):
avg_cost = trainer.test(reader=paddle.dataset.imikolov.test(
word_dict, N))
if avg_cost < 5.0:
trainer.params.save(save_path)
return
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")
trainer = fluid.Trainer(
partial(train_network, is_sparse),
fluid.optimizer.SGD(learning_rate=0.001),
place=place)
trainer.train(
reader=train_reader, num_epochs=100, event_handler=event_handler)
def infer(use_cuda, save_path):
params = fluid.Params(save_path)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(inference_network, params, place=place)
lod = [0, 1]
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
result = inferencer.infer({
'firstw': first_word,
'secondw': second_word,
'thirdw': third_word,
'forthw': fourth_word
})
print(result)
def main(use_cuda, is_sparse):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
save_path = "word2vec.inference.model"
train(use_cuda, is_sparse, save_path)
infer(use_cuda, save_path)
if __name__ == '__main__':
for use_cuda in (False, True):
for is_sparse in (False, True):
main(use_cuda=use_cuda, is_sparse=is_sparse)
...@@ -12,44 +12,200 @@ ...@@ -12,44 +12,200 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import core
import framework
import executor
import data_feeder
import contextlib
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
__all__ = [ __all__ = [
'Event',
'Trainer', 'Trainer',
'BeginEpochEvent',
'EndEpochEvent',
'BeginStepEvent',
'EndStepEvent',
] ]
class Event(object): class BeginEpochEvent(object):
BEGIN_EPOCH = 0 def __init__(self, epoch_id):
END_EPOCH = 1 self.epoch = epoch_id
BEGIN_STEP = 2
END_STEP = 3
class EndEpochEvent(object):
def __init__(self, epoch_id):
self.epoch = epoch_id
def __init__(self):
self.step = 0 class BeginStepEvent(object):
self.epoch = 0 def __init__(self, epoch_id, step_id):
self.type = Event.BEGIN_EPOCH self.epoch = epoch_id
self.step = step_id
class EndStepEvent(object):
def __init__(self, epoch_id, step_id):
self.epoch = epoch_id
self.step = step_id
class Trainer(object): class Trainer(object):
"""
Args:
network_func(callable): A function which will return loss. The loss must be a scaler.
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
params:
place: The device place of this trainer.
"""
def __init__(self, network_func, optimizer, params=None, place=None): def __init__(self, network_func, optimizer, params=None, place=None):
# 1. we need to generate a framework.Program by calling # 1. we need to generate a framework.Program by calling
# network_func. Reference: fluid.program_guard in # network_func. Reference: fluid.program_guard in
# test_word2vec.py # test_word2vec.py
self.scope = self._get_scope_from_params(params)
self.startup_program = framework.Program()
self.train_program = framework.Program()
with framework.program_guard(self.train_program, self.startup_program):
loss = network_func()
if not isinstance(optimizer, opt_module.Optimizer):
raise TypeError(
"The optimizer should be an instance of Optimizer")
optimizer.minimize(loss)
self.place = Trainer._check_and_get_place(place)
# 2. move the default_main_program to self.program and run the # 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope() # default_startup program on an empty core.Scope()
# Run startup program
if params is None:
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
# 3. call self.params.add_vars with the initialized scope, it # 3. call self.params.add_vars with the initialized scope, it
# will add the new vars of the initialized scope into # will add the new vars of the initialized scope into
# self.params. # self.params.
self.network_func = network_func # TODO(yuyang): This depends on parameters implementation.
self.optimizer = optimizer
self.params = params
self.place = place
# TODO(helin): support distributed training # TODO(helin): support distributed training
def train(self, reader, num_epochs, event_handler): def train(self,
pass num_epochs,
event_handler,
reader=None,
parallel=False,
feed_order=None):
"""
Train the model.
Args:
num_epochs: The number of epoch. An epoch will process all data in reader
event_handler: The event handler. A function with type (ev:Event)->void
reader:
parallel: True if use multi-CPUs or multi-GPUs
feed_order: Feeding order of reader. None will following the defining
order in program
Returns:
"""
if parallel:
raise NotImplementedError(
"Parallel Executor version of trainer is not implemented")
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
def test(self, reader): def test(self, reader):
pass pass
def _get_scope_from_params(self, params):
"""
Get Scope from parameter object.
Args:
params(Parameter|None): The parameter object instance. Could be None.
Returns: New scope if params is None. Or params.scope()
NOTE: This method is WIP. Not fully implemented.
"""
if params is None:
return core.Scope() # new scope when params is None
else:
raise NotImplementedError("Not implemented right now.")
@staticmethod
def _check_and_get_place(place):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if place is None:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
else:
return core.CPUPlace()
else:
if not isinstance(place, core.CUDAPlace) and not isinstance(
place, core.CPUPlace):
raise TypeError("Place should be either CUDAPlace or CPUPlace")
return place
@contextlib.contextmanager
def _prog_and_scope_guard(self):
with framework.program_guard(
main_program=self.train_program,
startup_program=self.startup_program):
with executor.scope_guard(self.scope):
yield
def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
"""
Train by Executor and single device.
Args:
num_epochs:
event_handler:
reader:
feed_order:
Returns:
"""
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
if feed_order is None:
feed_var_list = [
var
for var in self.train_program.global_block(
).vars.itervalues()
if hasattr(var, 'is_data') and var.is_data
]
else:
feed_var_list = [
self.train_program.global_block().var(var_name)
for var_name in feed_order
]
feeder = data_feeder.DataFeeder(
feed_list=feed_var_list, place=self.place)
for epoch_id in range(num_epochs):
event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()):
event_handler(BeginStepEvent(epoch_id, step_id))
exe.run(feed=feeder.feed(data), fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id))
event_handler(EndEpochEvent(epoch_id))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册