diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index 42623337de933ddcf2e3d4a036c3e79907ce6c21..ae4befa004c9e587a4a58d7f8df3f248a6fc277f 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -314,7 +314,8 @@ class LocalSGD(Collective): name=self.snapshot_name(param.name), shape=param.shape, persistable=True, - stop_gradient=True) + stop_gradient=True, + dtype=param.dtype) block._insert_op( idx + 1,