Commit 6db8c3bf authored by minqiyang

Implement the infer shape and infer var type

Parent c69d2bbe
@@ -18,34 +18,53 @@ limitations under the License. */
namespace paddle {
namespace operators {
class LookupTableOp : public framework::OperatorWithKernel {
class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
"Input Ids of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupTableOp should not be null.");
"Output of FusedEmbeddingSeqPoolOp should not be null.");
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size();
const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
PADDLE_ENFORCE_GE(ids_dims.size(), 1u,
"The dim size of the 'Ids' tensor must be greater than or equal to 1.");
PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1,
"The last dimension of the 'Ids' tensor must be 1.");
// we only support sum now
PADDLE_ENFORCE_EQ(combiner, "sum");
auto output_dims =
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
output_dims.push_back(table_dims[1]);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
if (ctx->IsRuntime()) {
Variable* ids_var = boost::get<Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Ids", /*->*/ "Out");
// in run time, the LoD of ids must be 1
PADDLE_ENFORCE_EQ(ids_lod.size(), 1u,
"The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
size_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
ctx->SetOutputDim("Out",
framework::make_ddim({batch_size, table_dims[1]}));
} else {
// in compile time, the lod level of ids must be 1
VarDesc* ids_desc = boost::get<VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, table_dims[1]}));
}
}
@@ -57,7 +76,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
};
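To make the two InferShape branches above concrete, here is a minimal standalone sketch (plain C++ with illustrative values, not Paddle framework code) of how a level-1 LoD maps a [seq_length, 1] Ids tensor to a [batch_size, embedding_size] output shape at run time:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Level-1 LoD: offsets of each sequence inside the flattened Ids tensor.
  // Three sequences of lengths 3, 2, and 4, so seq_length is 9.
  std::vector<size_t> lod0 = {0, 3, 5, 9};

  const int64_t embedding_size = 64;  // stands in for table_dims[1]
  // Mirrors the runtime branch: batch_size = ids_lod[0].size() - 1.
  const int64_t batch_size = static_cast<int64_t>(lod0.size()) - 1;

  // Ids has shape [seq_length, 1] = [9, 1]; after pooling each sequence,
  // Out has shape [batch_size, embedding_size] = [3, 64].
  std::cout << "Out dims: [" << batch_size << ", " << embedding_size << "]\n";
  return 0;
}

At compile time the sequence layout is unknown, so the op can only promise [-1, embedding_size], which is exactly what the else branch records.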
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("W",
@@ -68,42 +87,44 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"contains the ids to be looked up in W. "
"The last dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<std::string>("combiner",
"(string, default sum) "
"A string specifying the reduction op. Currently sum "
"are supported, sum computes the weighted sum of the "
"embedding results for each row.")
.SetDefault("sum");
AddAttr<bool>("is_sparse",
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddComment(R"DOC(
Lookup Table Operator.
FusedEmbeddingSeqPool Operator.
Computes embeddings for the given ids and weights.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
then computes the weighted sum of the lookup results for each sequence
and concatenates them into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
The input Ids must carry the LoD (Level of Details) information.
The output does not inherit this LoD; it is a dense tensor of shape
[batch_size, embedding_size].
)DOC");
}
};
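As a hedged illustration of the semantics described in the DOC block, the toy function below (simplified containers, not the actual FusedEmbeddingSeqPoolKernel) looks up each id's row in W and sums the rows that belong to the same sequence, which is what combiner == "sum" does:

#include <cstdint>
#include <vector>

// Toy fused lookup + sum pooling: W is the [vocab, dim] embedding table,
// ids is the flattened id column of shape [seq_length, 1], and lod0 holds
// the level-1 LoD offsets of each sequence.
std::vector<std::vector<float>> EmbeddingSumPool(
    const std::vector<std::vector<float>>& W, const std::vector<int64_t>& ids,
    const std::vector<size_t>& lod0) {
  const size_t dim = W.front().size();
  std::vector<std::vector<float>> out;
  for (size_t b = 0; b + 1 < lod0.size(); ++b) {
    std::vector<float> row(dim, 0.0f);
    // combiner == "sum": accumulate the embedding row of every id that
    // falls inside sequence b.
    for (size_t i = lod0[b]; i < lod0[b + 1]; ++i) {
      for (size_t d = 0; d < dim; ++d) {
        row[d] += W[static_cast<size_t>(ids[i])][d];
      }
    }
    out.push_back(row);  // one pooled row per sequence
  }
  return out;  // shape [batch_size, dim]
}

Called with the LoD {0, 3, 5, 9} from the earlier sketch, it returns three pooled rows, one per sequence, matching the [batch_size, embedding_size] shape inferred above.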
class LookupTableOpGradDescMaker
class FusedEmbeddingSeqPoolOpGradDescMaker
: public framework::DefaultGradOpDescMaker<true> {
using ::paddle::framework::DefaultGradOpDescMaker<
true>::DefaultGradOpDescMaker;
protected:
virtual std::string GradOpType() const { return "lookup_table_grad"; }
virtual std::string GradOpType() const {
return "fused_embedding_seq_pool_grad";
}
};
class LookupTableOpGrad : public framework::OperatorWithKernel {
class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -120,7 +141,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
};
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
class FusedEmbeddingSeqPoolOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
@@ -128,13 +150,13 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
auto attr = op_desc.GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to SelectedRows";
block->Var(out_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
}
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
@@ -145,14 +167,16 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
} // namespace paddle
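The gradient var type inference above reduces to a single rule; this hypothetical sketch (a simplified enum, not the framework's proto types) captures the decision made for W's gradient:

#include <iostream>

// Simplified stand-in for framework::proto::VarType; not the real enum.
enum class VarType { LOD_TENSOR, SELECTED_ROWS };

// Mirrors FusedEmbeddingSeqPoolOpGradVarTypeInference: a sparse update
// stores only the touched rows of W's gradient in a SelectedRows variable,
// while a dense update keeps an ordinary LoDTensor gradient.
VarType InferWGradType(bool is_sparse) {
  return is_sparse ? VarType::SELECTED_ROWS : VarType::LOD_TENSOR;
}

int main() {
  std::cout << (InferWGradType(true) == VarType::SELECTED_ROWS) << "\n";  // 1
  std::cout << (InferWGradType(false) == VarType::LOD_TENSOR) << "\n";    // 1
  return 0;
}

SelectedRows pays off for large vocabularies: each step only a handful of rows of W receive gradient, so materializing a dense [vocab, dim] tensor would waste memory and bandwidth.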
namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableOpGradVarTypeInference);
// REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
// ops::LookupTableKernel<double>);
// REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
// ops::LookupTableGradKernel<double>);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp,
ops::FusedEmbeddingSeqPoolOpGradDescMaker,
ops::FusedEmbeddingSeqPoolOpMaker);
REGISTER_OPERATOR(fused_embedding_seq_pool_grad,
ops::FusedEmbeddingSeqPoolOpGrad,
ops::FusedEmbeddingSeqPoolOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool,
ops::FusedEmbeddingSeqPoolKernel<float>,
ops::FusedEmbeddingSeqPoolKernel<double>);
REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad,
ops::FusedEmbeddingSeqPoolGradKernel<float>,
ops::FusedEmbeddingSeqPoolGradKernel<double>);
@@ -31,8 +31,6 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1;
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
......