Commit 6db8c3bf authored by minqiyang

Implement the infer shape and infer var type

Parent c69d2bbe
@@ -18,34 +18,53 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
 
-class LookupTableOp : public framework::OperatorWithKernel {
+class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"),
-                   "Input(W) of LookupTableOp should not be null.");
+                   "Input W of FusedEmbeddingSeqPoolOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Ids"),
-                   "Input(Ids) of LookupTableOp should not be null.");
+                   "Input Ids of FusedEmbeddingSeqPoolOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of LookupTableOp should not be null.");
+                   "Output of FusedEmbeddingSeqPoolOp should not be null.");
 
     auto table_dims = ctx->GetInputDim("W");
     auto ids_dims = ctx->GetInputDim("Ids");
-    int ids_rank = ids_dims.size();
+    const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
 
     PADDLE_ENFORCE_EQ(table_dims.size(), 2);
-    PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
+    PADDLE_ENFORCE_GE(ids_dims.size(), 1u,
+                      "The dim size of the 'Ids' tensor must be at least 1.");
+    PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1,
                       "The last dimension of the 'Ids' tensor must be 1.");
+    // we only support the sum combiner for now
+    PADDLE_ENFORCE_EQ(combiner, "sum");
 
-    auto output_dims =
-        framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
-    output_dims.push_back(table_dims[1]);
-    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
-
-    if (ctx->GetOutputsVarType("Out")[0] ==
-        framework::proto::VarType::LOD_TENSOR) {
-      ctx->ShareLoD("Ids", /*->*/ "Out");
+    if (ctx->IsRuntime()) {
+      Variable* ids_var = boost::get<Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
+      const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
+
+      // at run time, the LoD level of Ids must be 1
+      PADDLE_ENFORCE_EQ(ids_lod.size(), 1u,
+                        "The LoD level of Input(Ids) must be 1");
+      PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+
+      int64_t batch_size = ids_lod[0].size() - 1;
+
+      // at run time, the shape from Ids -> output
+      // should be [seq_length, 1] -> [batch_size, embedding_size]
+      ctx->SetOutputDim("Out",
+                        framework::make_ddim({batch_size, table_dims[1]}));
+    } else {
+      // at compile time, the LoD level of Ids must be 1
+      VarDesc* ids_desc = boost::get<VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
+      PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
+
+      // at compile time, the shape from Ids -> output
+      // should be [-1, 1] -> [-1, embedding_size]
+      ctx->SetOutputDim("Out", framework::make_ddim({-1, table_dims[1]}));
     }
   }
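The runtime branch above sizes the output from the level-1 LoD of Ids rather than from its dense shape: an offset vector with n+1 entries delimits n sequences, so the pooled output gets one row per sequence. A minimal standalone sketch of that arithmetic, with a plain std::vector standing in for Paddle's LoD (all names here are illustrative, not Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A level-1 LoD: offsets into the flattened Ids tensor.
  // Three sequences [0,4), [4,5), [5,9) -> Ids has shape [9, 1].
  std::vector<size_t> lod = {0, 4, 5, 9};
  int64_t embedding_size = 64;  // plays the role of table_dims[1]

  // Mirrors ids_lod[0].size() - 1 in InferShape: one output row per sequence.
  int64_t batch_size = static_cast<int64_t>(lod.size()) - 1;

  // [seq_length, 1] -> [batch_size, embedding_size]
  std::cout << "Out shape: [" << batch_size << ", " << embedding_size << "]\n";
  return 0;
}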
@@ -57,7 +76,7 @@
   }
 };
 
-class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
+class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("W",
@@ -68,42 +87,44 @@
              "contains the ids to be looked up in W. "
              "The last dimension size must be 1.");
     AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddAttr<std::string>("combiner",
+                         "(string, default sum) "
+                         "A string specifying the reduction op. Currently only "
+                         "sum is supported: it computes the weighted sum of "
+                         "the embedding results for each row.")
+        .SetDefault("sum");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
                   "Sparse update.")
         .SetDefault(false);
-    AddAttr<bool>("is_distributed",
-                  "(boolean, default false) distributed lookup table.")
-        .SetDefault(false);
-    AddAttr<int64_t>("padding_idx",
-                     "(int64, default -1) "
-                     "If the value is -1, it makes no effect to lookup. "
-                     "Otherwise the given value indicates padding the output "
-                     "with zeros whenever lookup encounters it in Ids.")
-        .SetDefault(kNoPadding);
     AddComment(R"DOC(
-Lookup Table Operator.
+FusedEmbeddingSeqPool Operator.
+
+Computes embeddings for the given ids and weights.
 
 This operator is used to perform lookups on the parameter W,
-then concatenated into a dense tensor.
+then computes the weighted sum of the lookup results for each row
+and concatenates them into a dense tensor.
 
-The input Ids can carry the LoD (Level of Details) information,
-or not. And the output only shares the LoD information with input Ids.
+The input Ids should carry the LoD (Level of Details) information.
+The output does not share this LoD: each sequence in Ids is pooled
+into a single row of the output.
 
 )DOC");
   }
 };
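Concretely, the sum combiner described in the DOC string reduces each LoD-delimited sequence of looked-up rows to one output row. A toy standalone illustration under that reading (a dense double matrix stands in for the parameter W; nothing here is Paddle API):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Toy embedding table W: 4 rows, embedding_size 2.
  std::vector<std::vector<double>> W = {
      {0.1, 0.2}, {1.0, 1.0}, {2.0, 3.0}, {0.5, 0.5}};
  // Flattened Ids with level-1 LoD {0, 2, 4}: sequences {1,2} and {0,3}.
  std::vector<size_t> ids = {1, 2, 0, 3};
  std::vector<size_t> lod = {0, 2, 4};

  // One pooled row per sequence: Out[i] = sum of W[ids[j]], j in [lod[i], lod[i+1]).
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    std::vector<double> row(W[0].size(), 0.0);
    for (size_t j = lod[i]; j < lod[i + 1]; ++j)
      for (size_t k = 0; k < row.size(); ++k) row[k] += W[ids[j]][k];
    std::cout << "Out[" << i << "] = {" << row[0] << ", " << row[1] << "}\n";
  }
  return 0;  // prints {3, 4} and {0.6, 0.7}
}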
-class LookupTableOpGradDescMaker
+class FusedEmbeddingSeqPoolOpGradDescMaker
     : public framework::DefaultGradOpDescMaker<true> {
   using ::paddle::framework::DefaultGradOpDescMaker<
       true>::DefaultGradOpDescMaker;
 
  protected:
-  virtual std::string GradOpType() const { return "lookup_table_grad"; }
+  virtual std::string GradOpType() const {
+    return "fused_embedding_seq_pool_grad";
+  }
 };
 
-class LookupTableOpGrad : public framework::OperatorWithKernel {
+class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -120,7 +141,8 @@
   }
 };
 
-class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
+class FusedEmbeddingSeqPoolOpGradVarTypeInference
+    : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
@@ -128,13 +150,13 @@
     auto attr = op_desc.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
-      VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
-              << " is set to SelectedRows";
+      VLOG(3) << "fused_embedding_seq_pool_grad op "
+              << framework::GradVarName("W") << " is set to SelectedRows";
       block->Var(out_var_name)
           ->SetType(framework::proto::VarType::SELECTED_ROWS);
     } else {
-      VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
-              << " is set to LoDTensor";
+      VLOG(3) << "fused_embedding_seq_pool_grad op "
+              << framework::GradVarName("W") << " is set to LoDTensor";
       block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
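The is_sparse branch marks the gradient of W as SelectedRows because only the table rows actually referenced by Ids receive a nonzero gradient; storing just those row indices and their values avoids materializing a dense tensor the size of W. A rough standalone illustration of that layout (a toy struct, not Paddle's SelectedRows class):

#include <cstdint>
#include <iostream>
#include <vector>

// Sparse gradient for an embedding table: row ids plus one dense row of
// gradient values per touched row.
struct ToySelectedRows {
  std::vector<int64_t> rows;                // which rows of W are touched
  std::vector<std::vector<double>> values;  // gradient for each touched row
};

int main() {
  // Ids = {1, 2, 0, 3} from the pooling example: only four rows of a
  // possibly huge table receive gradient, so we record just those.
  ToySelectedRows grad_w;
  grad_w.rows = {1, 2, 0, 3};
  grad_w.values.assign(4, std::vector<double>(2, 1.0));  // toy values

  std::cout << "touched rows: " << grad_w.rows.size() << "\n";
  return 0;
}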
@@ -145,14 +167,16 @@
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
-                  ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker);
-REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
-                  ops::LookupTableOpGradVarTypeInference);
+REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp,
+                  ops::FusedEmbeddingSeqPoolOpGradDescMaker,
+                  ops::FusedEmbeddingSeqPoolOpMaker);
+REGISTER_OPERATOR(fused_embedding_seq_pool_grad,
+                  ops::FusedEmbeddingSeqPoolOpGrad,
+                  ops::FusedEmbeddingSeqPoolOpGradVarTypeInference);
 
-// REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
-//                        ops::LookupTableKernel<double>);
-// REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
-//                        ops::LookupTableGradKernel<double>);
-REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
-REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
+REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool,
+                       ops::FusedEmbeddingSeqPoolKernel<float>,
+                       ops::FusedEmbeddingSeqPoolKernel<double>);
+REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad,
+                       ops::FusedEmbeddingSeqPoolGradKernel<float>,
+                       ops::FusedEmbeddingSeqPoolGradKernel<double>);
@@ -31,8 +31,6 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
-constexpr int64_t kNoPadding = -1;
-
 template <typename T>
 class LookupTableKernel : public framework::OpKernel<T> {
  public:
...