From e6745be9ea20144ddaf44cbc1030aeb2a848f86f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 19 Apr 2018 12:27:41 +0800 Subject: [PATCH] fix_not_trainable_transpiler --- python/paddle/fluid/distribute_transpiler.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 50be386c51..349427525d 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -18,7 +18,7 @@ import math import distributed_splitter as splitter import framework -from framework import Program, default_main_program, Variable +from framework import Program, default_main_program, Variable, Parameter from . import core LOOKUP_TABLE_TYPE = "lookup_table" @@ -222,8 +222,14 @@ class DistributeTranspiler: # step1: For large parameters and gradients, split them into smaller # blocks. - param_list = [pg[0] for pg in params_grads] - grad_list = [pg[1] for pg in params_grads] + param_list = [] + grad_list = [] + for p, g in params_grads: + # skip parameter marked not trainable + if type(p) == Parameter and p.trainable == False: + continue + param_list.append(p) + grad_list.append(g) if self.has_distributed_lookup_table: param_list = [ -- GitLab