diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 12c5f8f1eb6be340811fe7eca5cade27744d70db..c1a1ea87a0ee7e6438a680a531c2a91f1f421395 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -49,14 +49,14 @@ class LookupTableKernel : public framework::OpKernel { auto id_name = context.Inputs("Ids").front(); auto out_name = context.Outputs("Out").front(); - auto table_name = context.Inputs("W").front(); + + // for remote prefetch auto epmap = context.Attr>("epmap"); - auto remote_prefetch = context.Attr("remote_prefetch"); auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - if (remote_prefetch) { + if (!height_sections.empty()) { // if emap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a1ccb704b2d43434beaceb3305783b09dd1f02e2..d1633574011d3a4ea5fa1ef520b8394d0fbab765 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -369,7 +369,7 @@ class DistributeTranspiler(object): program.global_block(), splited_grad_varname, reverse=True) if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS: sparse_param_name = self.grad_name_to_param_name[ - splited_grad_varname] + grad_varname] if self._is_input_of_remote_sparse_update_op( sparse_param_name): self.sparse_param_to_height_sections[