Commit e6745be9 authored by typhoonzero

fix_not_trainable_transpiler

Parent 82b192a3
@@ -18,7 +18,7 @@ import math
 import distributed_splitter as splitter
 import framework
-from framework import Program, default_main_program, Variable
+from framework import Program, default_main_program, Variable, Parameter
 from . import core
 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -222,8 +222,14 @@ class DistributeTranspiler:
         # step1: For large parameters and gradients, split them into smaller
         # blocks.
-        param_list = [pg[0] for pg in params_grads]
-        grad_list = [pg[1] for pg in params_grads]
+        param_list = []
+        grad_list = []
+        for p, g in params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            param_list.append(p)
+            grad_list.append(g)
         if self.has_distributed_lookup_table:
             param_list = [
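The hunk above replaces the two list comprehensions with a loop that drops any parameter marked `trainable=False` before the transpiler splits parameters and gradients into blocks. Below is a minimal standalone sketch of that filtering logic; the `Param` class and the sample data are hypothetical stand-ins for illustration only, not the Paddle `framework.Parameter` API.

```python
# Hypothetical stand-in for framework.Parameter, used only to illustrate
# the filtering behavior introduced by this commit.
class Param(object):
    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


def split_trainable(params_grads):
    """Return (param_list, grad_list), skipping non-trainable parameters."""
    param_list = []
    grad_list = []
    for p, g in params_grads:
        # skip parameter marked not trainable, mirroring the transpiler change
        if isinstance(p, Param) and not p.trainable:
            continue
        param_list.append(p)
        grad_list.append(g)
    return param_list, grad_list


if __name__ == "__main__":
    pairs = [(Param("fc_w"), "fc_w@GRAD"),
             (Param("emb", trainable=False), "emb@GRAD")]
    params, grads = split_trainable(pairs)
    print(params[0].name, grads)  # only "fc_w" and its gradient survive the filter
```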