未验证 提交 9ae1523e 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #7719 from guoshengCS/enhance-lookup_table_op-padidx

Enhance lookup_table_op to support padding_idx.
...@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { ...@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
} }
return val; return val;
} }
case proto::AttrType::LONG: {
return attr_desc.l();
}
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
} }
......
...@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> { ...@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
const std::string& attr_name_; const std::string& attr_name_;
}; };
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, attr.type().name());
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
......
...@@ -26,6 +26,7 @@ enum AttrType { ...@@ -26,6 +26,7 @@ enum AttrType {
BOOLEAN = 6; BOOLEAN = 6;
BOOLEANS = 7; BOOLEANS = 7;
BLOCK = 8; BLOCK = 8;
LONG = 9;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -44,6 +45,7 @@ message OpDesc { ...@@ -44,6 +45,7 @@ message OpDesc {
optional bool b = 10; optional bool b = 10;
repeated bool bools = 11; repeated bool bools = 11;
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13;
}; };
message Var { message Var {
......
...@@ -283,6 +283,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -283,6 +283,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated(v, attr_->mutable_bools()); VectorToRepeated(v, attr_->mutable_bools());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
......
...@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>; std::vector<bool>, BlockDesc*, int64_t>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) " "(boolean, default false) "
"Sparse update") "Sparse update")
.SetDefault(false); .SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
Lookup Table Operator. Lookup Table Operator.
......
...@@ -21,9 +21,11 @@ limitations under the License. */ ...@@ -21,9 +21,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag>
__global__ void LookupTable(T* output, const T* table, const int64_t* ids, __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const int64_t N, const int64_t K, const int64_t D) { const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX; int idy = blockIdx.x + threadIdx.y * GridDimX;
...@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids, ...@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
T* out = output + idy * D; T* out = output + idy * D;
const T* tab = table + id * D; const T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
out[i] = tab[i]; if (PaddingFlag) {
if (id == padding_idx)
out[i] = static_cast<T>(0);
else
out[i] = tab[i];
} else {
out[i] = tab[i];
}
} }
idy += BlockDimY * GridDimX; idy += BlockDimY * GridDimX;
} }
...@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W"); auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids"); auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out"); auto* output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
...@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
LookupTable<
T, 128, 8, if (padding_idx == -1)
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>( LookupTable<
output, table, ids, N, K, D); T, 128, 8, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
} }
}; };
...@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>(); context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W"); auto* table = context.Input<LoDTensor>("W");
......
...@@ -32,16 +32,30 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -32,16 +32,30 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
auto* ids = ids_t->data<int64_t>(); auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>(); auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); if (padding_idx == -1) {
PADDLE_ENFORCE_GE(ids[i], 0); for (int64_t i = 0; i < ids_t->numel(); ++i) {
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
} else {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
}
} }
} }
}; };
...@@ -51,6 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -51,6 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W"); auto* table = context.Input<LoDTensor>("W");
......
...@@ -64,6 +64,8 @@ std::string AttrType(paddle::framework::proto::AttrType at) { ...@@ -64,6 +64,8 @@ std::string AttrType(paddle::framework::proto::AttrType at) {
return "bool array"; return "bool array";
case paddle::framework::proto::BLOCK: case paddle::framework::proto::BLOCK:
return "block id"; return "block id";
case paddle::framework::proto::LONG:
return "long";
} }
return "UNKNOWN"; // not possible return "UNKNOWN"; // not possible
} }
......
...@@ -185,22 +185,35 @@ def fc(input, ...@@ -185,22 +185,35 @@ def fc(input,
return helper.append_activation(pre_activation) return helper.append_activation(pre_activation)
def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'): def embedding(input,
size,
is_sparse=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
""" """
**Embedding Layer** **Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table. This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
The result of this lookup is the embedding of each ID in the *input*. a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper All the input variables are passed in as local variables to the LayerHelper
constructor. constructor.
Args: Args:
input(Variable): Input to the function input(Variable): The tensor variable containing the IDs.
size(tuple|list|None): Shape of the look up table parameter size(tuple|list): The shape of the look up table parameter. It should
is_sparse(bool): Boolean flag that specifying whether the input is sparse have two elements which indicate the size of the dictionary of
param_attr(ParamAttr): Parameters for this layer embeddings and the size of each embedding vector respectively.
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc is_sparse(bool): The flag indicating whether to use sparse update.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
Returns: Returns:
Variable: The tensor variable storing the embeddings of the \ Variable: The tensor variable storing the embeddings of the \
...@@ -218,12 +231,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'): ...@@ -218,12 +231,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
w = helper.create_parameter( w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False) attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_tmp_variable(dtype)
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx)
helper.append_op( helper.append_op(
type='lookup_table', type='lookup_table',
inputs={'Ids': input, inputs={'Ids': input,
'W': w}, 'W': w},
outputs={'Out': tmp}, outputs={'Out': tmp},
attrs={'is_sparse': is_sparse}) attrs={'is_sparse': is_sparse,
'padding_idx': padding_idx})
return tmp return tmp
......
...@@ -33,5 +33,19 @@ class TestLookupTableOp(OpTest): ...@@ -33,5 +33,19 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
class TestLookupTableOpWithPadding(TestLookupTableOp):
def test_check_output(self):
ids = np.squeeze(self.inputs['Ids'])
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': long(padding_idx)}
self.check_output()
def test_check_grad(self):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册