diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index efb400ccc6d43df44325dc7ef88c14afe4b704c3..48a46a0ff02526d91b70011127612ee38632aeed 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -306,7 +306,8 @@ def embedding(input, is_distributed=False, padding_idx=None, param_attr=None, - dtype='float32'): + dtype='float32', + remote_prefetch=False): """ **Embedding Layer** @@ -345,7 +346,7 @@ def embedding(input, """ helper = LayerHelper('embedding', **locals()) - remote_prefetch = is_sparse and (not is_distributed) + remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch if remote_prefetch: assert is_sparse is True and is_distributed is False w = helper.create_parameter(