提交 ae676a60 编写于 作者: G guosheng

Enhance lookup_table_op to support padding_idx

上级 9247aee7
......@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
return val;
}
case proto::AttrType::LONG: {
return attr_desc.l();
}
default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
}
......
......@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, attr.type().name());
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template <typename T>
......
......@@ -26,6 +26,7 @@ enum AttrType {
BOOLEAN = 6;
BOOLEANS = 7;
BLOCK = 8;
LONG = 9;
}
// OpDesc describes an instance of a C++ framework::OperatorBase
......@@ -44,6 +45,7 @@ message OpDesc {
optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
optional int64 l = 13;
};
message Var {
......
......@@ -282,6 +282,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
......
......@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>;
std::vector<bool>, BlockDesc*, int64_t>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
......
......@@ -66,11 +66,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update")
.SetDefault(false);
AddAttr<int64_t>(
"padding_idx",
"(int64_t, default -1) "
" If given, pads the output with zeros whenever it encounters "
"the index.")
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
AddComment(R"DOC(
Lookup Table Operator.
......
......@@ -21,9 +21,11 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag>
__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
const int64_t N, const int64_t K, const int64_t D) {
const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
......@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
T* out = output + idy * D;
const T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
out[i] = tab[i];
if (PaddingFlag) {
if (idx == padding_idx)
out[i] = static_cast<T>(0);
else
out[i] = tab[i];
} else {
out[i] = tab[i];
}
}
idy += BlockDimY * GridDimX;
}
......@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
......@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<
T, 128, 8,
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D);
if (padding_idx == -1)
LookupTable<
T, 128, 8, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
};
......@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
......
......@@ -39,14 +39,23 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
} else {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
}
}
}
};
......@@ -56,8 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
......@@ -70,9 +79,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework::Vector<int64_t> new_rows;
new_rows.reserve(ids_dim[0]);
for (int64_t i = 0; i < ids_dim[0]; i++) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows);
......@@ -106,9 +112,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset(d_table_data, 0, d_table->numel() * sizeof(T));
for (int64_t i = 0; i < ids->numel(); ++i) {
if (ids_data[i] == padding_idx)
continue; // Paddings are not trainable and the gradient are not
// necessary.
PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0);
for (int j = 0; j < D; ++j) {
......
......@@ -471,6 +471,7 @@ class Operator(object):
self.desc.set_serialized_attr(
attr_name, attrs[attr_name].serialize_to_string())
else:
# print 'haha', attrs[attr_name], type(attrs[attr_name])
self.desc.set_attr(attr_name, attrs[attr_name])
self.desc.check_attrs()
......
......@@ -176,22 +176,35 @@ def fc(input,
return helper.append_activation(pre_activation)
def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
def embedding(input,
size,
is_sparse=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
"""
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
The result of this lookup is the embedding of each ID in the *input*.
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
input(Variable): Input to the function
size(tuple|list|None): Shape of the look up table parameter
is_sparse(bool): Boolean flag that specifying whether the input is sparse
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
input(Variable): The tensor variable containing the IDs.
size(tuple|list): The shape of the look up table parameter. It should
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
Returns:
Variable: The tensor variable storing the embeddings of the \
......@@ -209,12 +222,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx)
helper.append_op(
type='lookup_table',
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp},
attrs={'is_sparse': is_sparse})
attrs={'is_sparse': is_sparse,
'padding_idx': padding_idx})
return tmp
......
......@@ -32,5 +32,19 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
class TestLookupTableOpWithPadding(TestLookupTableOp):
def test_check_output(self):
ids = np.squeeze(self.inputs['Ids'])
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': long(padding_idx)}
self.check_output()
def test_check_grad(self):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册