From 835879160782721961cee33511041a517bef562a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 3 Dec 2020 20:31:36 +0800 Subject: [PATCH] fix gpu outofrange (#29238) * fix gpu emb out of range Change-Id: I5794ac73bd634d5ea069a6fbbd914274b6d6b7bf * fix doc Change-Id: I5a3350b2930a9ab2f52116c192b087307faf8fdf --- paddle/fluid/operators/lookup_table_v2_op.cu | 35 +++++++++----------- python/paddle/fluid/input.py | 14 +++----- python/paddle/nn/functional/input.py | 29 ++++++++-------- python/paddle/nn/layer/common.py | 2 +- 4 files changed, 33 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu index bd31d7dd1b8..493966ecda7 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cu +++ b/paddle/fluid/operators/lookup_table_v2_op.cu @@ -31,16 +31,6 @@ __global__ void LookupTableV2(T *output, const T *table, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ENFORCE( - id >= 0, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input value.", - N, id); - PADDLE_ENFORCE( - id < N, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input value.", - N, id); T *out = output + idy * D; const T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { @@ -66,16 +56,6 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ENFORCE( - id >= 0, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. Please check input value.", - N, id); - PADDLE_ENFORCE( - id < N, - "Variable value (input) of OP(fluid.layers.embedding) " - "expected >= 0 and < %ld, but got %ld. 
Please check input value.", - N, id); const T *out = output + idy * D; T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { @@ -127,6 +107,21 @@ class LookupTableV2CUDAKernel : public framework::OpKernel { ids_p = ids_t->data(); } + for (int64_t i = 0; i < K; ++i) { + PADDLE_ENFORCE_GE( + ids[i], 0, + platform::errors::InvalidArgument( + "Variable value (input) of OP(paddle.nn.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids[i])); + PADDLE_ENFORCE_LT( + ids[i], N, + platform::errors::InvalidArgument( + "Variable value (input) of OP(paddle.nn.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids[i])); + } + auto *table = table_t->data(); auto *output = output_t->mutable_data(context.GetPlace()); diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py index 2c4a9272648..b13419ae36c 100644 --- a/python/paddle/fluid/input.py +++ b/python/paddle/fluid/input.py @@ -197,10 +197,7 @@ def embedding(input, indicates the size of the dictionary of embeddings and the size of each embedding vector respectively. is_sparse(bool): The flag indicating whether to use sparse update. This parameter only affects the performance of the backwards gradient update. It is recommended to set - True because sparse update is faster. But some optimizer does not support sparse update, - such as :ref:`api_fluid_optimizer_AdadeltaOptimizer` , :ref:`api_fluid_optimizer_AdamaxOptimizer` , - :ref:`api_fluid_optimizer_DecayedAdagradOptimizer` , :ref:`api_fluid_optimizer_FtrlOptimizer` , - :ref:`api_fluid_optimizer_LambOptimizer` and :ref:`api_fluid_optimizer_LarsMomentumOptimizer` . + True because sparse update is faster. But some optimizers do not support sparse update. In these case, is_sparse must be False. Default: False. is_distributed(bool): Whether to store the embedding matrix in a distributed manner. Only used in multi-machine distributed CPU training. Default: False. 
@@ -210,11 +207,10 @@ def embedding(input, encounters :math:`padding\_idx` in id. And the padding data will not be updated while training. If set None, it makes no effect to output. Default: None. param_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the - default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . In addition, + default weight parameter property is used. In addition, user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter. The local word vector needs to be transformed into numpy format, and the shape of local word - vector should be consistent with :attr:`size` . Then :ref:`api_fluid_initializer_NumpyArrayInitializer` - is used to load custom or pre-trained word vectors. See code example 2 for details. + vector should be consistent with :attr:`size` . dtype(str|core.VarDesc.VarType): It refers to the data type of output Tensor. It must be float32 or float64. Default: float32. @@ -267,9 +263,7 @@ def embedding(input, import paddle import numpy as np - - paddle.disable_static() - + x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) # x is a Tensor. diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index 5cabc4b6755..bf389717518 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -168,28 +168,25 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): .. 
code-block:: python + import numpy as np import paddle import paddle.nn as nn - weight = prog.global_block().create_parameter( - attr=self._param_attr, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0)) + x0 = np.arange(3, 6).reshape((3, 1)).astype(np.int64) + w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32) - prog = paddle.static.Program() + # x.data = [[3], [4], [5]] + # x.shape = [3, 1] + x = paddle.to_tensor(x0, stop_gradient=False) - weight = prog.global_block().create_parameter( - (128, 100), dtype="float32", default_initializer=Constant(1.0)) + # w.data = [[2. 2. 2.] ... [2. 2. 2.]] + # w.shape = [10, 3] + w = paddle.to_tensor(w0, stop_gradient=False) - label = paddle.static.data( - name="label", - shape=[4], - append_batch_size=False, - dtype="int64") - - emb = nn.embedding( - x=label, weight=weight, sparse=True, name="embedding") + # emb.data = [[[2., 2., 2.]], [[2., 2., 2.]], [[2., 2., 2.]]] + # emb.shape = [3, 1, 3] + emb = nn.functional.embedding( + x=x, weight=w, sparse=True, name="embedding") """ padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 88221b7f009..1969b640481 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -1216,7 +1216,7 @@ class Embedding(layers.Layer): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32) - paddle.disable_static(paddle.CPUPlace()) + x = paddle.to_tensor(x_data, stop_gradient=False) y = paddle.to_tensor(y_data, stop_gradient=False) -- GitLab