Unverified commit 051ba1ce, authored by Qiao Longfei, committed by GitHub

Use force cpu in fill constant op (#8254)

Parent 222155cc
@@ -14,14 +14,37 @@
import framework
import numpy as np
import contextlib

__all__ = [
    'Constant', 'Uniform', 'Normal', 'Xavier', 'force_init_on_cpu',
    'init_on_cpu'
]

_force_init_on_cpu_ = False


def force_init_on_cpu():
    return _force_init_on_cpu_


@contextlib.contextmanager
def init_on_cpu():
    """
    Switch program with `with` statement

    Examples:
        >>> with init_on_cpu():
        >>>     step = layers.create_global_var()
    """
    global _force_init_on_cpu_

    pre_state = force_init_on_cpu()
    _force_init_on_cpu_ = True
    yield
    _force_init_on_cpu_ = pre_state


class Initializer(object):
    """Base class for variable initializers
@@ -80,7 +103,7 @@ class ConstantInitializer(Initializer):
    """Implements the constant initializer
    """

    def __init__(self, value=0.0, force_cpu=False):
        """Constructor for ConstantInitializer

        Args:
@@ -89,6 +112,7 @@ class ConstantInitializer(Initializer):
        assert value is not None
        super(ConstantInitializer, self).__init__()
        self._value = value
        self._force_cpu = force_cpu

    def __call__(self, var, block):
        """Add constant initialization ops for a variable
@@ -110,7 +134,8 @@ class ConstantInitializer(Initializer):
            attrs={
                "shape": var.shape,
                "dtype": int(var.dtype),
                "value": float(self._value),
                'force_cpu': self._force_cpu or force_init_on_cpu()
            })
        var.op = op
        return op
......
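The force_cpu flag can also be handed to the Constant initializer directly. A low-level sketch under the same branch assumption; helper.set_variable_initializer, used by create_global_var further below, does essentially this internally:

import paddle.v2.fluid as fluid
from paddle.v2.fluid.initializer import Constant

block = fluid.default_main_program().global_block()
counter = block.create_var(
    name='counter', shape=[1], dtype='float32', persistable=True)

# Adds a fill_constant op whose force_cpu attribute is True, so the counter
# is materialized in CPU memory regardless of the executor place.
Constant(value=0.0, force_cpu=True)(counter, block)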
@@ -14,6 +14,7 @@
from ..framework import Variable, unique_name
from layer_function_generator import OpProtoHolder
from ..initializer import force_init_on_cpu

__all__ = ['monkey_patch_variable']

@@ -36,9 +37,12 @@ def monkey_patch_variable():
        block.append_op(
            type="fill_constant",
            outputs={'Out': [var]},
            attrs={
                'dtype': var.dtype,
                'shape': shape,
                'value': value,
                'force_cpu': force_init_on_cpu()
            })
        return var

    def create_scalar(block, value, dtype):
......
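For context, a sketch of what this change buys: arithmetic on Variables goes through these monkey-patched operators, so scalar constants created implicitly inside an init_on_cpu() block are now also pinned to the CPU. Assuming this branch of paddle.v2.fluid:

import paddle.v2.fluid as fluid
from paddle.v2.fluid.initializer import init_on_cpu

step = fluid.layers.create_global_var(
    shape=[1], value=0, dtype='float32', force_cpu=True, persistable=True)

with init_on_cpu():
    # The literal 100000 is materialized by the patched division operator
    # through a fill_constant op; with this change that op carries
    # force_cpu=True, so the temporary stays next to the CPU-side counter.
    ratio = step / 100000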
@@ -16,7 +16,7 @@ from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..framework import convert_np_dtype_to_dtype_
from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu
from ..core import DataType
import numpy

@@ -69,12 +69,30 @@ def create_parameter(shape,
        default_initializer)


def create_global_var(shape,
                      value,
                      dtype,
                      persistable=False,
                      force_cpu=False,
                      name=None):
    """
    Create a global variable, such as global_step.

    Args:
        shape(list[int]): shape of the variable
        value(float): the value of the variable
        dtype(string): element type of the parameter
        persistable(bool): if this variable is persistable
        force_cpu(bool): force this variable to be on CPU

    Returns:
        Variable: the created Variable
    """
    helper = LayerHelper("global_var", **locals())
    var = helper.create_global_variable(
        dtype=dtype, shape=shape, persistable=persistable, name=name)
    helper.set_variable_initializer(
        var, initializer=Constant(
            value=float(value), force_cpu=force_cpu))

    return var
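A usage sketch of the extended create_global_var, mirroring how the updated book test at the end of this commit creates its step counter: counters that are read by CPU-side control ops are created with force_cpu=True.

import paddle.v2.fluid as fluid

# The optimizer increments this counter on the CPU when it is passed in as
# global_step, so the variable is pinned there from the start.
global_step = fluid.layers.create_global_var(
    shape=[1], value=0, dtype='float32', persistable=True, force_cpu=True)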
@@ -221,6 +239,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
        dtype(np.dtype|core.DataType|str): Data type of the output tensor.
        value(float): The constant value used to initialize the output tensor.
        out(Variable): The output tensor.
        force_cpu(True|False): data should be on CPU if set true.

    Returns:
        Variable: The tensor variable storing the output.

@@ -242,7 +261,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
            'shape': shape,
            'dtype': out.dtype,
            'value': float(value),
            'force_cpu': force_cpu or force_init_on_cpu()
        })
    out.stop_gradient = True
    return out
......
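Two ways the new behaviour surfaces in fill_constant, sketched under the same branch assumption: an explicit flag, or the ambient init_on_cpu() context picked up through force_init_on_cpu():

import paddle.v2.fluid as fluid
from paddle.v2.fluid.initializer import init_on_cpu

# Explicitly request a CPU tensor.
zero = fluid.layers.fill_constant(
    shape=[1], dtype='float32', value=0.0, force_cpu=True)

# Implicitly: every fill_constant issued inside the context evaluates
# force_cpu or force_init_on_cpu() to True.
with init_on_cpu():
    one = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)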
@@ -14,6 +14,7 @@
import layers
from framework import Variable
from initializer import init_on_cpu

__all__ = [
    'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',

@@ -54,11 +55,14 @@ def exponential_decay(learning_rate,
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for exponential_decay.")

    with init_on_cpu():
        # update learning_rate
        div_res = global_step / decay_steps
        if staircase:
            div_res = layers.floor(x=div_res)
        decayed_lr = learning_rate * (decay_rate**div_res)

    return decayed_lr


def natural_exp_decay(learning_rate,
@@ -88,10 +92,13 @@ def natural_exp_decay(learning_rate,
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for natural_exp_decay.")

    with init_on_cpu():
        div_res = global_step / decay_steps
        if staircase:
            div_res = layers.floor(x=div_res)
        decayed_lr = learning_rate * layers.exp(x=(-1 * decay_rate * div_res))

    return decayed_lr


def inverse_time_decay(learning_rate,
@@ -121,11 +128,14 @@ def inverse_time_decay(learning_rate,
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for inverse_time_decay.")

    with init_on_cpu():
        div_res = global_step / decay_steps
        if staircase:
            div_res = layers.floor(x=div_res)
        decayed_lr = learning_rate / (1 + decay_rate * div_res)

    return decayed_lr


def polynomial_decay(learning_rate,
@@ -160,10 +170,13 @@ def polynomial_decay(learning_rate,
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for inverse_time_decay.")

    with init_on_cpu():
        if cycle:
            div_res = layers.ceil(x=(global_step / decay_steps))
            zero_var = layers.fill_constant(
                shape=[1], dtype='float32', value=0.0)
            one_var = layers.fill_constant(
                shape=[1], dtype='float32', value=1.0)

            with layers.Switch() as switch:
                with switch.case(layers.equal(x=global_step, y=zero_var)):
@@ -172,10 +185,12 @@ def polynomial_decay(learning_rate,
        else:
            decay_steps_var = layers.fill_constant(
                shape=[1], dtype='float32', value=float(decay_steps))
            global_step = layers.elementwise_min(
                x=global_step, y=decay_steps_var)

        decayed_lr = (learning_rate - end_learning_rate) * \
            ((1 - global_step / decay_steps) ** power) + end_learning_rate

    return decayed_lr


def piecewise_decay(global_step, boundaries, values):
@@ -200,6 +215,7 @@ def piecewise_decay(global_step, boundaries, values):
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for piecewise_decay.")

    with init_on_cpu():
        lr = layers.create_global_var(
            shape=[1],
            value=0.0,
@@ -216,7 +232,9 @@ def piecewise_decay(global_step, boundaries, values):
            with switch.case(layers.less_than(global_step, boundary_val)):
                layers.assign(value_var, lr)

            last_value_var = layers.fill_constant(
                shape=[1],
                dtype='float32',
                value=float(values[len(values) - 1]))
            with switch.default():
                layers.assign(last_value_var, lr)
......
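Putting the decay helpers together with a CPU-pinned step counter, roughly as the updated book test below does; the boundary and value numbers here are made up for illustration:

import paddle.v2.fluid as fluid

global_step = fluid.layers.create_global_var(
    shape=[1], value=0, dtype='float32', force_cpu=True, persistable=True)

# piecewise_decay wraps its graph in init_on_cpu(), so the learning-rate
# variable and the constants it compares against are initialized on the
# CPU alongside global_step.
lr = fluid.learning_rate_decay.piecewise_decay(
    global_step=global_step,
    boundaries=[10000, 20000],
    values=[1e-3, 5e-4, 1e-4])

sgd = fluid.optimizer.SGD(learning_rate=lr, global_step=global_step)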
@@ -18,6 +18,7 @@ import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid
from paddle.v2.fluid.initializer import init_on_cpu
import contextlib
import time
import unittest

@@ -167,7 +168,16 @@ def train(use_cuda, save_dirname=None):
    # TODO(qiao)
    # check other optimizers and check why out will be NAN
    global_step = fluid.layers.create_global_var(
        shape=[1], value=0, dtype='float32', force_cpu=True, persistable=True)
    sgd_optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.learning_rate_decay.exponential_decay(
            learning_rate=0.0001,
            global_step=global_step,
            decay_steps=100000,
            decay_rate=0.5,
            staircase=True),
        global_step=global_step)
    sgd_optimizer.minimize(avg_cost)

    # TODO(qiao)
......