diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index c450a3a3684eb44cdc758a2b27783b5a81945c38..dd8544e2e7d39be33a9096cad16c4d58eb58bcad 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -43,12 +43,15 @@ class Momentum(object):
         self.grad_clip = grad_clip
 
     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
 
 
@@ -76,6 +79,9 @@ class Adam(object):
         self.lazy_mode = lazy_mode
 
     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -85,7 +91,7 @@
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
 
 
@@ -118,6 +124,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip
 
     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
        opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -125,7 +134,7 @@
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
 
 
@@ -149,6 +158,9 @@ class Adadelta(object):
         self.name = name
 
     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Adadelta(
             learning_rate=self.learning_rate,
             epsilon=self.epsilon,
@@ -156,7 +168,7 @@
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             name=self.name,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
 
 
@@ -190,17 +202,20 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
 
     def __call__(self, model):
-        parameters = model.parameters()
+        parameters = [
+            param for param in model.parameters() if param.trainable is True
+        ]
 
         self.no_weight_decay_param_name_list = [
-            p.name for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list)
+            p.name for n, p in model.named_parameters()
+            if any(nd in n for nd in self.no_weight_decay_name_list)
         ]
 
         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
-                p.name for n, p in model.named_parameters()  if len(p.shape) == 1
+                p.name for n, p in model.named_parameters() if len(p.shape) == 1
             ]
-        
+
         opt = optim.AdamW(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -216,4 +231,4 @@ class AdamW(object):
         return opt
 
     def _apply_decay_param_fun(self, name):
-        return name not in self.no_weight_decay_param_name_list
\ No newline at end of file
+        return name not in self.no_weight_decay_param_name_list
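
Below is a minimal usage sketch (illustration only, not part of the patch) showing the effect of the change: parameters frozen with trainable = False no longer reach the underlying Paddle optimizer. It assumes the Momentum wrapper in this file keeps the constructor arguments visible in the hunk context (learning_rate, momentum, weight_decay, grad_clip), that a plain float learning rate is acceptable, and it reads Paddle's private _parameter_list attribute purely for inspection.

# Usage sketch, not part of the patch. Assumptions: paddle is installed,
# the repo is on PYTHONPATH, and the Momentum wrapper above forwards its
# arguments to paddle.optimizer.Momentum as shown in the diff context.
import paddle

from ppocr.optimizer.optimizer import Momentum

backbone = paddle.nn.Linear(8, 8)   # layer we want to freeze
head = paddle.nn.Linear(8, 2)       # layer that should keep training
model = paddle.nn.Sequential(backbone, head)

# Freeze the backbone: trainable=False marks its parameters as non-trainable.
for param in backbone.parameters():
    param.trainable = False

# With this patch the wrapper filters on param.trainable, so only the head's
# weight and bias are handed to paddle.optimizer.Momentum; previously the full
# model.parameters() list was passed regardless of the trainable flag.
opt = Momentum(learning_rate=0.001, momentum=0.9)(model)
print([p.name for p in opt._parameter_list])  # only the head parameters remain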