diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d4f025a4af60d3365b0e85e3024a2dd94f5ae842..f6ff83924f2516875f44b8b49ac75e36d7d1d407 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -113,6 +113,7 @@ class GradientClipByNorm(BaseGradientClipAttr): class GradientClipByGlobalNorm(BaseGradientClipAttr): global_norm_var = None + local_norm_var = None clip_norm_var = None scale_var = None @@ -123,12 +124,18 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): cls.global_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=0.0) + cls.local_norm_var = framework.default_main_program().current_block( + ).create_var( + name=framework.unique_name("local_norm"), + dtype="float32", + persistable=False) cls.clip_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=clip_norm) @classmethod def check_init(cls): if not (isinstance(cls.global_norm_var, framework.Variable) and + isinstance(cls.local_norm_var, framework.Variable) and isinstance(cls.clip_norm_var, framework.Variable)): raise ValueError( "Class 'GradientClipByGlobalNorm' has not been properly initialized. \ @@ -138,9 +145,10 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): cls = self.__class__ cls.check_init() - local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + cls.local_norm_var = layers.reduce_sum( + input=layers.pow(x=grad, factor=2.0)) layers.sums( - input=[local_norm_var, cls.global_norm_var], + input=[cls.local_norm_var, cls.global_norm_var], out=[cls.global_norm_var]) def create_operators(self, param, grad): @@ -148,7 +156,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): cls.check_init() if cls.scale_var is None: - cls.global_norm_var = layers.sqrt(x=cls.global_norm_var) + layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var) cls.scale_var = layers.elementwise_div( x=cls.clip_norm_var, y=layers.elementwise_max( diff --git a/python/paddle/v2/fluid/tests/test_clip.py b/python/paddle/v2/fluid/tests/test_error_clip.py similarity index 100% rename from python/paddle/v2/fluid/tests/test_clip.py rename to python/paddle/v2/fluid/tests/test_error_clip.py diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb7f0b2cb3f697c4565053af46a18c000a1c4b9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -0,0 +1,82 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + + +def _get_global_param_norm_(params_grads): + res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0) + for _, grad in params_grads: + norm_var = fluid.layers.reduce_sum( + input=fluid.layers.pow(x=grad, factor=2.0)) + fluid.layers.sums(input=[norm_var, res], out=[res]) + fluid.layers.sqrt(x=res, out=res) + return res + + +BATCH_SIZE = 128 +CLIP = 0.5 +prog = fluid.framework.Program() + +with fluid.program_guard(main_program=prog): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + +prog_clip = prog.clone() + +avg_cost_clip = prog_clip.block(0).var(avg_cost.name) + +p_g = fluid.backward.append_backward(loss=avg_cost) +p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) + +with fluid.program_guard(main_program=prog): + gloabl_norm = _get_global_param_norm_(p_g) + +with fluid.program_guard(main_program=prog_clip): + fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) + gloabl_norm_clip = _get_global_param_norm_(p_g_clip) + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +feeder = fluid.DataFeeder(feed_list=[image, label], place=place) +exe.run(fluid.default_startup_program()) + +count = 0 +for data in train_reader(): + count += 1 + if count > 5: + break + out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm]) + out_clip, = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=[gloabl_norm_clip]) + + if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))): + exit(1) +exit(0)