未验证 提交 610ad495 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #7637 from JiayiFeng/dev_global_norm_clip

Gradient clip by global norm
......@@ -14,6 +14,7 @@
import functools
import layers
import framework
from . import core
__all__ = [
......@@ -66,7 +67,7 @@ def error_clip_callback(block, context):
class BaseGradientClipAttr(object):
def process_context(self, context, p_g):
def process_context(self, context, param, grad):
raise NotImplementedError()
def create_operators(self, param, grad):
......@@ -74,7 +75,7 @@ class BaseGradientClipAttr(object):
class NullGradientClipAttr(BaseGradientClipAttr):
def process_context(self, context, p_g):
def process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
......@@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr):
self.max = max
self.min = min
def process_context(self, context, p_g):
def process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
......@@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr):
return param, new_grad
class GradientClipByNorm(BaseGradientClipAttr):
def __init__(self, clip_norm):
self.clip_norm = clip_norm
def process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
return param, new_grad
class GradientClipByGlobalNorm(BaseGradientClipAttr):
def __init__(self, clip_norm, group_name="default_group"):
if not isinstance(group_name, basestring):
raise TypeError("'group_name' must be a basestring.")
self.clip_norm = clip_norm
self.group_name = group_name
def process_context(self, context, param, grad):
if self.group_name not in context:
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = layers.fill_constant(
shape=[1], dtype="float32", value=self.clip_norm)
else:
if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError(
"All parameters' 'clip_norm' of a same group should be the same"
)
local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
context[self.group_name].append(local_norm_var)
self.context = context
def create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
layers.sqrt(x=group_norm_var, out=group_norm_var)
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div(
x=clip_var,
y=layers.elementwise_max(
x=clip_var, y=group_norm_var))
assert group_scale_var.shape == (1L, )
self.context[group_scale_name] = group_scale_var
new_grad = layers.elementwise_mul(
x=grad, y=self.context[group_scale_name])
return param, new_grad
def gradient_clip_by_global_norm(clip_norm,
param_list=None,
group_name="default_group",
program=None):
if program is None:
program = framework.default_main_program()
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, basestring) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError(
"'param_list' should be a list of Parameter or basestring(parameter's name)."
)
for param in param_list:
param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
group_name)
def append_gradient_clip_ops(param_grad):
context = dict()
create_op_callbacks = []
for p, g in param_grad:
clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr())
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
if clip_attr is None:
clip_attr = NullGradientClipAttr()
if not isinstance(clip_attr, BaseGradientClipAttr):
raise TypeError(
"clip attribute should be an instance of BaseGradientClippingAttr"
)
"clip attribute should be an instance of BaseGradientClipAttr")
clip_attr.process_context(context=context, p_g=param_grad)
clip_attr.process_context(context=context, param=p, grad=g)
create_op_callbacks.append(
functools.partial(
clip_attr.create_operators, param=p, grad=g))
......
......@@ -780,7 +780,7 @@ class Block(object):
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
clip_attr=p.clip_attr,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name)
self.vars[new_p.name] = new_p
......@@ -948,7 +948,7 @@ class Parameter(Variable):
self.regularizer = kwargs.get('regularizer', None)
self.clip_attr = kwargs.get('clip_attr', None)
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
# program is a global instance.
......
......@@ -46,20 +46,10 @@ __activations__ = [
]
__all__ = [
'mean',
'mul',
'reshape',
'scale',
'transpose',
'sigmoid_cross_entropy_with_logits',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'clip',
'sequence_softmax',
'mean', 'mul', 'reshape', 'scale', 'transpose',
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min',
'clip', 'clip_by_norm', 'sequence_softmax'
] + __activations__
for _OP in set(__all__):
......
......@@ -25,13 +25,13 @@ class ParamAttr(object):
learning_rate=1.0,
regularizer=None,
trainable=True,
clip=None):
gradient_clip=None):
self.name = name
self.initializer = initializer
self.learning_rate = learning_rate
self.regularizer = regularizer
self.trainable = trainable
self.clip = clip
self.gradient_clip = gradient_clip
def set_default_initializer(self, initializer):
if initializer is None:
......@@ -77,7 +77,7 @@ class ParamAttr(object):
},
'regularizer': self.regularizer,
'trainable': self.trainable,
'clip_attr': self.clip
'gradient_clip_attr': self.gradient_clip
}
if with_initializer:
kwargs['initializer'] = self.initializer
......
......@@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image,
act='relu',
param_attr=fluid.ParamAttr(
regularizer=regularizer,
clip=fluid.clip.ClipByValue(10)))
gradient_clip=fluid.clip.ClipByValue(10)))
hidden2 = fluid.layers.fc(input=hidden1,
size=64,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
BATCH_SIZE = 128
CLIP = 1
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
prog_clip = prog.clone()
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
p_g = fluid.backward.append_backward(loss=avg_cost)
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
with fluid.program_guard(main_program=prog_clip):
fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
grad_list = [elem[1] for elem in p_g]
grad_clip_list = [elem[1] for elem in p_g_clip]
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(fluid.default_startup_program())
count = 0
for data in train_reader():
count += 1
if count > 5:
break
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
global_norm = 0
for v in out[1:]:
global_norm += np.sum(np.power(v, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in out_clip[1:]:
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
if not np.isclose(
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
exit(1)
exit(0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册