diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index b9ac54e446811889b647397ae1fbb11c28f46777..a4d1e812a54e8d92750c991d09860ab974e3e56d 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -115,7 +115,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 5826db292b2b9c36a8546041460c2eef13bd4821..b3a8958b22906cf83df5cff9f7b6bb415b0e4a2a 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -487,7 +487,7 @@ class DistributeTranspiler(object): if init_op_num != 1: raise ValueError("table init op num should be 1, now is " + str( init_op_num)) - table_init_op = table_param_init_op[1] + table_init_op = table_param_init_op[0] self.startup_program.global_block().append_op( type="fake_init", inputs={},