From ae676a609f95dfcc5a54437925037cace76386ea Mon Sep 17 00:00:00 2001
From: guosheng
Date: Mon, 22 Jan 2018 10:44:36 +0800
Subject: [PATCH] Enhance lookup_table_op to support padding_idx

---
 paddle/framework/attribute.cc                      |  3 ++
 paddle/framework/attribute.h                       | 26 ++++++++++++++
 paddle/framework/framework.proto                   |  2 ++
 paddle/framework/op_desc.cc                        |  1 +
 paddle/framework/type_defs.h                       |  2 +-
 paddle/operators/lookup_table_op.cc                | 10 +++---
 paddle/operators/lookup_table_op.cu                | 33 ++++++++++++++----
 paddle/operators/lookup_table_op.h                 | 27 ++++++++-------
 python/paddle/v2/fluid/layers/nn.py                | 34 ++++++++++++++-----
 .../v2/fluid/tests/test_lookup_table_op.py         | 14 ++++++++
 10 files changed, 118 insertions(+), 34 deletions(-)

diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc
index b0fd4d2750..5074e8f5a0 100644
--- a/paddle/framework/attribute.cc
+++ b/paddle/framework/attribute.cc
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
       }
       return val;
     }
+    case proto::AttrType::LONG: {
+      return attr_desc.l();
+    }
     default:
       PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
   }
diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index c1c63d9cb1..bcff9bc4c4 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
   const std::string& attr_name_;
 };
 
+template <>
+struct ExtractAttribute<int64_t> {
+  explicit ExtractAttribute(const std::string& attr_name)
+      : attr_name_(attr_name) {}
+
+  int64_t* operator()(Attribute& attr) const {
+    if (attr.type() == typeid(int)) {  // NOLINT
+      int val = boost::get<int>(attr);
+      attr = static_cast<int64_t>(val);
+    } else if (attr.type() == typeid(float)) {  // NOLINT
+      float val = boost::get<float>(attr);
+      attr = static_cast<int64_t>(val);
+    }
+    int64_t* attr_value = nullptr;
+    try {
+      attr_value = &boost::get<int64_t>(attr);
+    } catch (boost::bad_get& bad_get) {
+      PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
+                   attr_name_, attr.type().name());
+    }
+    return attr_value;
+  }
+
+  const std::string& attr_name_;
+};
+
 // check whether a certain attribute fit its limits
 // an attribute can have more than one limits
 template <typename T>
diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto
index ea69b87e2a..5b6ef03f61 100644
--- a/paddle/framework/framework.proto
+++ b/paddle/framework/framework.proto
@@ -26,6 +26,7 @@ enum AttrType {
   BOOLEAN = 6;
   BOOLEANS = 7;
   BLOCK = 8;
+  LONG = 9;
 }
 
 // OpDesc describes an instance of a C++ framework::OperatorBase
@@ -44,6 +45,7 @@ message OpDesc {
     optional bool b = 10;
     repeated bool bools = 11;
     optional int32 block_idx = 12;
+    optional int64 l = 13;
   };
 
   message Var {
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index 1c0372bb16..43e5f79735 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -282,6 +282,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
     VectorToRepeated(v, attr_->mutable_bools());
   }
   void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
+  void operator()(int64_t v) const { attr_->set_l(v); }
   void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
 };
 
diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h
index d834d34375..1eedbbc419 100644
--- a/paddle/framework/type_defs.h
+++ b/paddle/framework/type_defs.h
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
 using Attribute =
     boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                    std::vector<float>, std::vector<std::string>, bool,
-                   std::vector<bool>, BlockDesc*>;
+                   std::vector<bool>, BlockDesc*, int64_t>;
 
 using AttributeMap = std::unordered_map<std::string, Attribute>;
 
diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc
index 54c326c1d9..2405852f53 100644
--- a/paddle/operators/lookup_table_op.cc
+++ b/paddle/operators/lookup_table_op.cc
@@ -66,11 +66,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(boolean, default false) "
                   "Sparse update")
         .SetDefault(false);
-    AddAttr<int64_t>(
-        "padding_idx",
-        "(int64_t, default -1) "
-        " If given, pads the output with zeros whenever it encounters "
-        "the index.")
+    AddAttr<int64_t>("padding_idx",
+                     "(int64, default -1) "
+                     "If the value is -1, it has no effect on the lookup. "
+                     "Otherwise, the output is padded with zeros whenever "
+                     "the given index is encountered in Ids.")
         .SetDefault(-1);
     AddComment(R"DOC(
 Lookup Table Operator.
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
index 261a28da69..0482746ec8 100644
--- a/paddle/operators/lookup_table_op.cu
+++ b/paddle/operators/lookup_table_op.cu
@@ -21,9 +21,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
+          bool PaddingFlag>
 __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
-                            const int64_t N, const int64_t K, const int64_t D) {
+                            const int64_t N, const int64_t K, const int64_t D,
+                            const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;
 
@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
     T* out = output + idy * D;
     const T* tab = table + id * D;
     for (int i = idx; i < D; i += BlockDimX) {
-      out[i] = tab[i];
+      if (PaddingFlag) {
+        if (id == padding_idx)
+          out[i] = static_cast<T>(0);
+        else
+          out[i] = tab[i];
+      } else {
+        out[i] = tab[i];
+      }
     }
     idy += BlockDimY * GridDimX;
   }
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto* table_t = context.Input<LoDTensor>("W");
     auto* ids_t = context.Input<LoDTensor>("Ids");
     auto* output_t = context.Output<LoDTensor>("Out");
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
 
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<
-        T, 128, 8,
-        8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-        output, table, ids, N, K, D);
+
+    if (padding_idx == -1)
+      LookupTable<
+          T, 128, 8, 8,
+          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
+    else
+      LookupTable<
+          T, 128, 8, 8,
+          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
   }
 };
 
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     bool is_sparse = context.Attr<bool>("is_sparse");
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index 2fa45e2437..0842c422f7 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -39,14 +39,23 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto* ids = ids_t->data<int64_t>();
     auto* table = table_t->data<T>();
     auto* output = output_t->mutable_data<T>(context.GetPlace());
-    for (int64_t i = 0; i < ids_t->numel(); ++i) {
-      if (ids[i] == padding_idx) {
-        memset(output + i * D, 0, D * sizeof(T));
-      } else {
+
+    if (padding_idx == -1) {
+      for (int64_t i = 0; i < ids_t->numel(); ++i) {
         PADDLE_ENFORCE_LT(ids[i], N);
         PADDLE_ENFORCE_GE(ids[i], 0);
         memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
       }
+    } else {
+      for (int64_t i = 0; i < ids_t->numel(); ++i) {
+        if (ids[i] == padding_idx) {
+          memset(output + i * D, 0, D * sizeof(T));
+        } else {
+          PADDLE_ENFORCE_LT(ids[i], N);
+          PADDLE_ENFORCE_GE(ids[i], 0);
+          memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+        }
+      }
     }
   }
 };
@@ -56,8 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     bool is_sparse = context.Attr<bool>("is_sparse");
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");
@@ -70,9 +79,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       framework::Vector<int64_t> new_rows;
       new_rows.reserve(ids_dim[0]);
       for (int64_t i = 0; i < ids_dim[0]; i++) {
-        if (ids_data[i] == padding_idx)
-          continue;  // Paddings are not trainable and the gradient are not
-                     // necessary.
         new_rows.push_back(ids_data[i]);
       }
       d_table->set_rows(new_rows);
@@ -106,9 +112,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       memset(d_table_data, 0, d_table->numel() * sizeof(T));
 
       for (int64_t i = 0; i < ids->numel(); ++i) {
-        if (ids_data[i] == padding_idx)
-          continue;  // Paddings are not trainable and the gradient are not
-                     // necessary.
         PADDLE_ENFORCE_LT(ids_data[i], N);
         PADDLE_ENFORCE_GE(ids_data[i], 0);
         for (int j = 0; j < D; ++j) {
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index fc4c22e152..c556e315b9 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -176,22 +176,35 @@ def fc(input,
     return helper.append_activation(pre_activation)
 
 
-def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
+def embedding(input,
+              size,
+              is_sparse=False,
+              padding_idx=None,
+              param_attr=None,
+              dtype='float32'):
     """
     **Embedding Layer**
 
-    This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
-    The result of this lookup is the embedding of each ID in the *input*.
+    This layer is used to look up embeddings of IDs, provided by :attr:`input`,
+    in a lookup table. The result of this lookup is the embedding of each ID
+    in the :attr:`input`.
 
     All the input variables are passed in as local variables to the LayerHelper
     constructor.
 
     Args:
-        input(Variable): Input to the function
-        size(tuple|list|None): Shape of the look up table parameter
-        is_sparse(bool): Boolean flag that specifying whether the input is sparse
-        param_attr(ParamAttr): Parameters for this layer
-        dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
+        input(Variable): The tensor variable containing the IDs.
+        size(tuple|list): The shape of the lookup table parameter. It should
+            have two elements which indicate the size of the dictionary of
+            embeddings and the size of each embedding vector respectively.
+        is_sparse(bool): The flag indicating whether to use sparse update.
+        padding_idx(int|long|None): If :attr:`None`, it has no effect on the
+            lookup. Otherwise the output is padded with zeros whenever the
+            given :attr:`padding_idx` is encountered in :attr:`input`. If
+            :math:`padding_idx < 0`, the index used in the lookup is
+            :math:`size[0] + padding_idx`.
+        param_attr(ParamAttr): Parameters for this layer.
+        dtype(np.dtype|core.DataType|str): The type of data: float32, float_16, int etc.
 
     Returns:
         Variable: The tensor variable storing the embeddings of the \
@@ -209,12 +222,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
     w = helper.create_parameter(
         attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
     tmp = helper.create_tmp_variable(dtype)
+    padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
+        size[0] + padding_idx)
     helper.append_op(
         type='lookup_table',
         inputs={'Ids': input,
                 'W': w},
        outputs={'Out': tmp},
-        attrs={'is_sparse': is_sparse})
+        attrs={'is_sparse': is_sparse,
+               'padding_idx': padding_idx})
     return tmp
 
 
diff --git a/python/paddle/v2/fluid/tests/test_lookup_table_op.py b/python/paddle/v2/fluid/tests/test_lookup_table_op.py
index 1ff6b305bc..d60a8d3deb 100644
--- a/python/paddle/v2/fluid/tests/test_lookup_table_op.py
+++ b/python/paddle/v2/fluid/tests/test_lookup_table_op.py
@@ -32,5 +32,19 @@ class TestLookupTableOp(OpTest):
         self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
 
 
+class TestLookupTableOpWithPadding(TestLookupTableOp):
+    def test_check_output(self):
+        ids = np.squeeze(self.inputs['Ids'])
+        padding_idx = np.random.choice(ids, 1)[0]
+        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
+        self.attrs = {'padding_idx': long(padding_idx)}
+        self.check_output()
+
+    def test_check_grad(self):
+        # Since paddings are not trainable and fixed in forward, the gradient of
+        # paddings makes no sense and we don't test the gradient here.
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab
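
Usage sketch (not part of the patch): the snippet below exercises the new
padding_idx argument through fluid.layers.embedding as defined above. The
data layer name, vocabulary size, and embedding width are illustrative
assumptions, as is the fluid.layers.data signature of this era of the API.

    import paddle.v2.fluid as fluid

    # Word IDs are int64 with shape [1] plus the implicit batch dimension,
    # matching what lookup_table_op expects for its Ids input.
    words = fluid.layers.data(name='words', shape=[1], dtype='int64')

    # size = [dict_size, embedding_dim]; 10000 and 32 are made-up values.
    # With padding_idx=0, every position where `words` contains 0 yields an
    # all-zero row in `emb` instead of a row read from the table. A negative
    # padding_idx is normalized to size[0] + padding_idx by the layer code.
    emb = fluid.layers.embedding(
        input=words, size=[10000, 32], padding_idx=0, is_sparse=True)

Note that the padding row is zero-filled only in the forward pass; as the
comments in the gradient kernels state, the patch deliberately removes the
special-casing of padding_idx from backward, since the padding vector is
fixed in forward and not meant to be trainable.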