Commit 1a3b38a4 authored by minqiyang

Polish code

test=develop
Parent 133bac2b
...@@ -81,6 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddAttr<bool>("grad_inplace",
"(boolean, default false) "
"If the grad op reuse the input's variable.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Lookup Table Operator. Lookup Table Operator.
......
...@@ -119,9 +119,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
// FIXME(minqiyang):
// memory optimization will NOT reuse Tensor with SelectedRows
// so we could just share the tensor here directly.
// However, the InferVarType method sometimes infers the output SelectedRows
// to be a Tensor, which is a bug, so we add an attribute here to mark the
// inplace reuse and will remove it once the InferVarType bug is fixed.
bool grad_inplace = context.Attr<bool>("grad_inplace");
if (grad_inplace) {
d_table_value->ShareDataWith(*d_output);
} else {
d_table_value->mutable_data<T>(context.GetPlace());
d_table->set_height(table_dim[0]);
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->data<T>();
auto d_output_dims = d_output->dims();
PADDLE_ENFORCE_EQ(
d_table_value->dims(),
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
} else {
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
......
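
The two branches above differ only in how the gradient rows reach d_table_value: when grad_inplace is true the kernel aliases the upstream gradient's storage via ShareDataWith, otherwise it allocates its own buffer and memcpys the same bytes. The sketch below is a minimal, self-contained illustration of that aliasing-vs-copy distinction in plain C++; it is not Paddle code, and the Buffer struct and helper names are illustrative assumptions.

// Minimal sketch of the two gradient paths, assuming a refcounted "Buffer"
// that stands in for a tensor's underlying allocation (not Paddle API).
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

struct Buffer {
  std::shared_ptr<std::vector<float>> holder;  // refcounted storage
  float *data() { return holder->data(); }
};

// grad_inplace == true: alias the upstream gradient, no allocation, no copy.
Buffer ShareDataWith(const Buffer &d_output) { return Buffer{d_output.holder}; }

// grad_inplace == false: allocate fresh storage and copy the bytes,
// analogous to the memcpy branch in LookupTableGradKernel.
Buffer CopyFrom(const Buffer &d_output) {
  Buffer d_table_value{
      std::make_shared<std::vector<float>>(d_output.holder->size())};
  std::memcpy(d_table_value.data(), d_output.holder->data(),
              sizeof(float) * d_output.holder->size());
  return d_table_value;
}

int main() {
  Buffer d_output{std::make_shared<std::vector<float>>(
      std::vector<float>{1.f, 2.f, 3.f})};

  Buffer shared = ShareDataWith(d_output);  // same storage
  Buffer copied = CopyFrom(d_output);       // independent storage

  d_output.holder->at(0) = 42.f;            // mutate the upstream gradient
  std::cout << shared.data()[0] << " "      // 42: the alias sees the change
            << copied.data()[0] << "\n";    // 1: the copy is unaffected
  return 0;
}

Note that the copy path in the diff also re-checks, via PADDLE_ENFORCE_EQ, that d_table_value's shape matches the flattened d_output shape before the memcpy, a check the aliasing path gets implicitly by sharing the same buffer.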