Unverified · Commit bb6bd223 authored by zhaoyingli and committed by GitHub

[AutoParallel] support ClipGradByGlobalNorm (#45205)

* add clip_grad

* fix comments

* add unittest

* update logger
Parent d257acc6
......@@ -19,7 +19,7 @@ import time
from paddle.fluid import core
from paddle.fluid import framework
from .utils import print_program_with_dist_attr
from .utils import print_program_with_dist_attr, _is_gradient_clip_op
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
......@@ -1319,26 +1319,70 @@ class Completer:
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":
param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping
out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
self._dist_context.set_tensor_dist_attr_for_program(
out, out_dist_attr)
# TODO:
# 1. move `generate_optimizer` before `partitioner`
# 2. implement grad_clip completion by `dist_op`
                # 3. allreduce dist_global_norm (mp-group) and no_dist_global_norm (pp-group, sharding-group)
if _is_gradient_clip_op(op):
if op.type in [
"sum", "sqrt", "fill_constant", "elementwise_max",
"elementwise_div"
]:
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = world_ranks
for in_name in op.input_arg_names:
in_var = vars[in_name]
in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(
in_name, in_dist_attr)
for out_name in op.output_arg_names:
out_var = vars[out_name]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = world_ranks
out_dist_attr.dims_mapping = [
-1 for _ in range(len(out_var.shape))
]
self._dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr)
remove_no_need_in_op(op, self._dist_context)
else:
in_var = vars[op.input("X")[0]]
in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
in_var)
assert in_dist_attr is not None
ref_process_mesh = in_dist_attr.process_mesh
ref_dims_mapping = in_dist_attr.dims_mapping
if op.type == "cast" and ops[
idx + 1].type == "elementwise_mul":
ref_var = vars[ops[idx + 1].input("X")[0]]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var)
assert ref_dist_attr is not None
ref_process_mesh = ref_dist_attr.process_mesh
out_var = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
if out_var.shape == in_var.shape:
out_dist_attr.dims_mapping = ref_dims_mapping
else:
assert len(
out_var.shape) == 1 and out_var.shape[0] == 1
out_dist_attr.dims_mapping = [-1]
self._dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(
in_var.name, in_dist_attr)
op_dist_attr.set_output_dist_attr(
out_var.name, out_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
op, op_dist_attr)
......@@ -1383,11 +1427,17 @@ class Completer:
for input_name in op.desc.input_names():
if input_name in [
'Param', 'Grad', 'LearningRate', "SkipUpdate",
"Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
"MasterParam"
'Param',
'Grad',
'LearningRate',
"SkipUpdate",
"Beta1Tensor",
"Beta2Tensor",
"EpsilonTensor",
]:
continue
if len(op.desc.input(input_name)) == 0:
continue
assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
......@@ -1400,7 +1450,6 @@ class Completer:
op_dist_attr.set_output_dims_mapping(
input_var.name, [-1])
else:
assert "Moment" in input_name or "Velocity" in input_name
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(
input_var.name, ref_dims_mapping)
......@@ -1481,3 +1530,20 @@ class Completer:
break
else:
dist_op.dist_attr = backup_op_dist_attr
def remove_no_need_in_op(op, dist_context):
if op.type == "fill_constant":
return
filter_vars = []
main_block = op.block
rank_id = dist_context.dist_op_context.rank_id
for varname in op.input("X"):
if rank_id in dist_context.get_tensor_dist_attr_for_program(
main_block.var(varname)).process_mesh.processes:
filter_vars.append(varname)
if not filter_vars:
return
op.desc.set_input('X', filter_vars)
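
For context, a minimal NumPy sketch (not part of the diff; names are illustrative) of what ClipGradByGlobalNorm computes: every gradient contributes to one global norm, and all gradients are rescaled by clip_norm / max(global_norm, clip_norm).

import numpy as np

def clip_by_global_norm(grads, clip_norm=1.0):
    # global norm over all gradients: sqrt(sum_i ||g_i||_2^2)
    global_norm = np.sqrt(sum(float(np.square(g).sum()) for g in grads))
    # rescaling only takes effect when the global norm exceeds clip_norm
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

The completion logic above annotates the reduction ops of this computation (sum, sqrt, fill_constant, elementwise_max, elementwise_div) with the world process mesh and a dims_mapping of all -1, i.e. their outputs are treated as replicated on every rank.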
......@@ -68,7 +68,6 @@ class DistributedContext:
self._original_serial_loss = serial_loss
self._original_serial_feed_vars = feed_vars
self._original_serial_fetch_vars = fetch_vars
self._original_serial_optimizer = serial_optimizer
# Data members related to programs (changed)
self._serial_main_program = None
......@@ -77,6 +76,7 @@ class DistributedContext:
self._serial_optimizer = None
self._serial_feed_vars = {}
self._serial_fetch_vars = {}
        self._lr_optimizer = None  # record the optimizer holding lr_scheduler
# Data members related to the program
self._dist_tensors_for_program = {}
......@@ -126,7 +126,7 @@ class DistributedContext:
self._data_parallel = False
# flag whether using `to_static`
self._dygraph_mode = True
self._dygraph_mode = False
@property
def serial_main_program(self):
......@@ -235,31 +235,20 @@ class DistributedContext:
if dist:
self._backup_dist_info(dist_mode)
def _restore_serial_info(self, mode="to_backup"):
if mode == "to_backup":
self._serial_main_program = self._backup_serial_main_program_stack.pop(
)
self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
)
elif mode == "to_original":
assert self._original_serial_main_program is not None
assert self._original_serial_startup_program is not None
self._serial_main_program = self._original_serial_main_program.clone(
)
self._serial_startup_program = self._original_serial_startup_program.clone(
)
self._serial_optimizer = self._original_serial_optimizer
def _restore_serial_loss(self):
if self._original_serial_loss:
if isinstance(self._original_serial_loss, list):
assert len(self._original_serial_loss) == 1
loss = self._original_serial_loss[0]
block_idx = loss.block.idx
var_name = loss.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
self._serial_loss = var
if len(self._original_serial_loss) == 1:
loss = self._original_serial_loss[0]
block_idx = loss.block.idx
var_name = loss.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
self._serial_loss = var
elif len(self._original_serial_loss) == 0:
self._serial_loss = []
else:
raise ValueError("multi loss vars are not supported.")
else:
block_idx = self._original_serial_loss.block.idx
var_name = self._original_serial_loss.name
......@@ -267,6 +256,7 @@ class DistributedContext:
block_idx]._var_recursive(var_name)
self._serial_loss = var
def _restore_serial_feed_vars(self):
for key, var_list in self._original_serial_feed_vars.items():
new_var_list = []
for var in var_list:
......@@ -277,6 +267,7 @@ class DistributedContext:
new_var_list.append(var)
self._serial_feed_vars[key] = new_var_list
def _restore_serial_fetch_vars(self):
for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = []
for var in var_list:
......@@ -287,6 +278,24 @@ class DistributedContext:
new_var_list.append(var)
self._serial_fetch_vars[key] = new_var_list
def _restore_serial_info(self, mode="to_backup"):
if mode == "to_backup":
self._serial_main_program = self._backup_serial_main_program_stack.pop(
)
self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
)
elif mode == "to_original":
assert self._original_serial_main_program is not None
assert self._original_serial_startup_program is not None
self._serial_main_program = self._original_serial_main_program.clone(
)
self._serial_startup_program = self._original_serial_startup_program.clone(
)
self._restore_serial_loss()
self._restore_serial_feed_vars()
self._restore_serial_fetch_vars()
self._serial_optimizer = self._original_serial_optimizer
self._pass_context = self._backup_pass_context_stack.pop()
self._block_state = self._backup_block_state_stack.pop()
......@@ -353,25 +362,21 @@ class DistributedContext:
def initialize(self, with_graph=True):
if not self._is_initialized:
if not self._serial_main_program:
self._serial_main_program = self._original_serial_main_program
if self._original_serial_main_program:
self._serial_main_program = self._original_serial_main_program.clone(
)
if not self._serial_startup_program:
self._serial_startup_program = self._original_serial_startup_program
if self._original_serial_startup_program:
self._serial_startup_program = self._original_serial_startup_program.clone(
)
if not self._serial_loss:
if isinstance(self._original_serial_loss, list):
if len(self._original_serial_loss) == 1:
self._serial_loss = self._original_serial_loss[0]
elif len(self._original_serial_loss) == 0:
self._serial_loss = self._original_serial_loss
else:
raise ValueError("multi loss vars are not supported.")
else:
self._serial_loss = self._original_serial_loss
self._restore_serial_loss()
if not self._serial_optimizer:
self._serial_optimizer = self._original_serial_optimizer
if not self._serial_feed_vars:
self._serial_feed_vars = self._original_serial_feed_vars
self._restore_serial_feed_vars()
if not self._serial_fetch_vars:
self._serial_fetch_vars = self._original_serial_fetch_vars
self._restore_serial_fetch_vars()
self._init_dist_attr_for_program()
# Backup the original distributed information for later restore
......@@ -856,7 +861,11 @@ class DistributedContext:
"_serial_main_program", "_serial_startup_program", "_serial_graph", \
"_dist_main_programs", "_dist_startup_programs", \
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes"]:
"_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"]:
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
......@@ -16,7 +16,6 @@ import time
import copy
import logging
from collections import defaultdict
import socket
import paddle
import paddle.utils as utils
......@@ -35,7 +34,6 @@ from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .hepler import ProgramHelper
......@@ -76,7 +74,18 @@ class Engine:
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
# TODO: add logger module
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program()
......@@ -307,7 +316,7 @@ class Engine:
mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode].serial_optimizer
self._lr_optimizer = self._dist_contexts[mode]._lr_optimizer
if self._nranks > 1:
# Traverse different rank programs and traverse each op of them,
......@@ -429,25 +438,27 @@ class Engine:
lr_scheduler = self.get_lr_scheduler(self.main_program)
for epoch in range(epochs):
train_logs = {"epoch": epoch}
train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader):
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
train_logs["step: {:d} "] = step
if lr_scheduler is not None:
lr_scheduler.step()
train_logs["lr"] = self._optimizer.get_lr()
train_logs["step"] = step
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
# inner fetches
if fetch_loss:
train_logs["train_loss"] = outs[0][0]
train_logs["loss: {:9f} "] = outs[0][0]
# user fetches
user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):]
for i, out in enumerate(user_outs):
train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
self._logger.info(train_logs)
train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
# logger
string = '[train] ' + ''.join(list(train_logs.keys()))
self._logger.info(string.format(*list(train_logs.values())))
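
For reference, a standalone sketch (not from the diff; the values are made up) of how these format-string keys produce the final log line: each key carries its own format spec and the dict values fill them in order.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

train_logs = {"epoch: {:d} ": 2, "step: {:d} ": 10,
              "lr: {:5e} ": 1e-05, "loss: {:9f} ": 0.123456}
string = '[train] ' + ''.join(list(train_logs.keys()))
logger.info(string.format(*list(train_logs.values())))
# -> [train] epoch: 2 step: 10 lr: 1.000000e-05 loss:  0.123456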
def evaluate(self,
eval_data,
......@@ -473,14 +484,14 @@ class Engine:
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
for step, _ in enumerate(eval_dataloader):
eval_logs = {"step": step}
eval_logs = {"step: {:d} ": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
# inner fetches
if fetch_loss:
eval_logs["eval_loss"] = outs[0][0]
eval_logs["loss: {:9f} "] = outs[0][0]
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
......@@ -488,14 +499,15 @@ class Engine:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
eval_logs[metric.name()[i] + ": {:9f} "] = res
# usr fetches
usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_outs):
eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out
# logger
self._logger.info(eval_logs)
string = '[eval] ' + ''.join(list(eval_logs.keys()))
self._logger.info(string.format(*list(eval_logs.values())))
def predict(self,
test_data,
......@@ -520,15 +532,17 @@ class Engine:
outputs = []
for step, _ in enumerate(test_dataloader):
predict_logs = {"step": step}
predict_logs = {"step: {:d} ": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
self._logger.info(predict_logs)
predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
# logger
string = '[pred] ' + ''.join(list(predict_logs.keys()))
self._logger.info(string.format(*list(predict_logs.values())))
return outputs
......
......@@ -20,7 +20,7 @@ from collections import defaultdict
import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
......@@ -143,15 +143,18 @@ class Parallelizer:
def _generate_optimizer(self, main_program, startup_program, optimizer,
params_grads):
        # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
        # but the optimizer is called repeatedly on re-launch, so the optimizer needs to be copied.
if self._dist_context._dygraph_mode:
paddle.disable_static()
optimizer = copy.deepcopy(optimizer)
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer)
self._dist_context._serial_optimizer = optimizer
self._dist_context._lr_optimizer = optimizer
with program_guard(main_program, startup_program):
optimizer_ops = optimizer.apply_gradients(params_grads)
with unique_name.guard("opt_"):
optimizer_ops = optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops
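
A small sketch (illustrative, not part of the diff) of what the unique_name.guard("opt_") scope does: names generated inside the guard come from a fresh generator with an "opt_" prefix, so the variables the optimizer creates are both prefixed and numbered from zero each time the optimizer ops are regenerated.

from paddle.fluid.framework import unique_name

with unique_name.guard("opt_"):
    name = unique_name.generate("learning_rate")
print(name)  # e.g. "opt_learning_rate_0"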
......
......@@ -30,10 +30,13 @@ from .cost import build_comm_desc, CommContext
from .cost import AllgatherOpCost, SendOpCost
from .cost import SliceOpCost, SplitOpCost, ConcatOpCost
from .cluster import Cluster
from .utils import print_program_with_dist_attr
from .utils import print_program_with_dist_attr, _is_gradient_clip_op
# NOTE: If op in _g_special_ops, it will not be resharded.
# NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
_g_gradient_clip_ops = [
"sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
]
def get_var_with_recursion(var_name, block, program):
......@@ -1076,9 +1079,11 @@ class Resharder:
return True
def is_special_op(self, op):
global _g_special_ops
global _g_special_ops, _g_gradient_clip_ops
if op.type in _g_special_ops:
return True
if _is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops:
return True
return False
def is_condition_replicative(self, op):
......
......@@ -1131,6 +1131,11 @@ def is_loss_grad_op(op):
return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)
def _is_gradient_clip_op(op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
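
A minimal usage sketch (hypothetical helper, not part of the diff): the ops emitted by gradient clipping carry the "/gradient_clip" name scope, so the clip-related ops of a static Program can be collected with this check.

def collect_gradient_clip_ops(program):
    # hypothetical helper: gather the ops emitted under the gradient-clip
    # name scope across every block of a static Program
    clip_ops = []
    for block in program.blocks:
        clip_ops.extend(op for op in block.ops if _is_gradient_clip_op(op))
    return clip_ops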
def is_prim_op(op):
return op.type.endswith("_p")
......
......@@ -64,4 +64,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
endif()
......@@ -108,9 +108,7 @@ def train(fetch):
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
T_max=10)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
......
......@@ -15,6 +15,7 @@
import unittest
import os
import json
import copy
import paddle
import numpy as np
......@@ -194,6 +195,32 @@ class TestDistributedContext(unittest.TestCase):
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")
def test_deepcopy(self):
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program(
)
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars)
dist_context.initialize()
copy_dist_context = copy.deepcopy(dist_context)
copy_list = [
"_original_serial_main_program", "_original_serial_startup_program", \
"_serial_main_program", "_serial_startup_program", "_serial_graph", \
"_dist_main_programs", "_dist_startup_programs", \
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"]
for i in range(len(copy_list)):
copy_obj = "copy_dist_context." + copy_list[i]
obj = "dist_context." + copy_list[i]
assert id(eval(copy_obj)) == id(eval(obj))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.hepler import ProgramHelper
from test_to_static import MLPLayer, MyDataset
paddle.enable_static()
class TestEngineBase(unittest.TestCase):
def setUp(self):
self.batch_size = 4
self.batch_num = 5
self.hidden_size = 1024
self.init_model()
self.init_optimizer()
self.init_dataset()
self.init_engine()
def init_model(self):
self.mlp = MLPLayer(hidden_size=self.hidden_size,
intermediate_size=4 * self.hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
self.loss = paddle.nn.CrossEntropyLoss()
def init_optimizer(self):
self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
parameters=self.mlp.parameters())
def init_dataset(self):
self.dataset = MyDataset(self.batch_num * self.batch_size)
def init_engine(self):
inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x')
labels = InputSpec([self.batch_size], 'int64', 'label')
self.engine = Engine(model=self.mlp,
inputs_spec=inputs,
labels_spec=labels)
self.engine.prepare(optimizer=self.optimizer,
loss=self.loss,
metrics=paddle.metric.Accuracy())
class TestLRScheduler(TestEngineBase):
def init_optimizer(self):
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=0.00001, T_max=10)
self.optimizer = paddle.optimizer.SGD(learning_rate=scheduler)
def test_lr_scheduler(self):
self.init_engine()
lr = self.engine._optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
self.engine.fit(self.dataset, batch_size=self.batch_size)
class TestGradClip(TestEngineBase):
def init_optimizer(self):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
self.optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
grad_clip=clip)
def test_grad_clip(self):
clip = self.engine._optimizer._grad_clip
assert isinstance(clip, paddle.nn.ClipGradByGlobalNorm)
self.engine.fit(self.dataset, batch_size=self.batch_size)
self.check_program()
def check_program(self):
ops = self.engine.main_program.global_block().ops
has_grad_clip = False
for op in ops:
if op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip"):
has_grad_clip = True
break
assert has_grad_clip is True
if __name__ == "__main__":
unittest.main()