From 87e4edd2eaebb8b0dc3259ff82797b9551448a34 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 27 Nov 2018 10:40:46 +0800 Subject: [PATCH] fix grad_varname in remote prefetch --- paddle/fluid/operators/lookup_table_op.h | 6 +++--- python/paddle/fluid/transpiler/distribute_transpiler.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 12c5f8f1eb..c1a1ea87a0 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -49,14 +49,14 @@ class LookupTableKernel : public framework::OpKernel { auto id_name = context.Inputs("Ids").front(); auto out_name = context.Outputs("Out").front(); - auto table_name = context.Inputs("W").front(); + + // for remote prefetch auto epmap = context.Attr>("epmap"); - auto remote_prefetch = context.Attr("remote_prefetch"); auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - if (remote_prefetch) { + if (!height_sections.empty()) { // if emap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a1ccb704b2..d163357401 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -369,7 +369,7 @@ class DistributeTranspiler(object): program.global_block(), splited_grad_varname, reverse=True) if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS: sparse_param_name = self.grad_name_to_param_name[ - splited_grad_varname] + grad_varname] if self._is_input_of_remote_sparse_update_op( sparse_param_name): self.sparse_param_to_height_sections[ -- GitLab