From 34404f9c318975d82443f97892cd1bc1690871e1 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Wed, 27 Feb 2019 17:23:10 +0800 Subject: [PATCH] refine infershape of sequence_enumerate, hash and fuse_emb_seq_pool test=develop --- .../fused/fused_embedding_seq_pool_op.cc | 40 +++++-------------- .../fused/fused_embedding_seq_pool_op.h | 20 ++++++++++ paddle/fluid/operators/hash_op.cc | 15 +++---- paddle/fluid/operators/hash_op.h | 25 ++++++++++-- .../sequence_ops/sequence_enumerate_op.cc | 9 +++-- .../sequence_ops/sequence_enumerate_op.cu | 2 + .../sequence_ops/sequence_enumerate_op.h | 2 + 7 files changed, 68 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index fe4c73f47..80caf70b0 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -23,6 +23,9 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { + if (ctx->IsRuntime()) { + return; + } PADDLE_ENFORCE(ctx->HasInput("W"), "Input W of FusedEmbeddingSeqPoolOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Ids"), @@ -42,36 +45,15 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { // we only support sum now PADDLE_ENFORCE_EQ(combiner, "sum"); - int64_t last_dim = table_dims[1]; - for (int i = 1; i != ids_dims.size(); ++i) { - last_dim *= ids_dims[i]; - } - - if (ctx->IsRuntime()) { - framework::Variable* ids_var = - boost::get(ctx->GetInputVarPtrs("Ids")[0]); - const auto& ids_lod = ids_var->Get().lod(); + int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims); + // in compile time, the lod level of ids must be 1 + framework::VarDesc* ids_desc = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); + PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1); - // in run time, the LoD of ids must be 1 - PADDLE_ENFORCE(ids_lod.size(), 1u, - "The LoD level of Input(Ids) must be 1"); - PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty"); - - int64_t batch_size = ids_lod[0].size() - 1; - - // in run time, the shape from Ids -> output - // should be [seq_length, 1] -> [batch_size, embedding_size] - ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim})); - } else { - // in compile time, the lod level of ids must be 1 - framework::VarDesc* ids_desc = - boost::get(ctx->GetInputVarPtrs("Ids")[0]); - PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1); - - // in compile time, the shape from Ids -> output - // should be [-1, 1] -> [-1, embedding_size] - ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim})); - } + // in compile time, the shape from Ids -> output + // should be [-1, 1] -> [-1, embedding_size] + ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim})); } protected: diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index 33a1b47d1..2b0c1f560 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -61,6 +61,15 @@ struct EmbeddingVSumFunctor { } }; +inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims, + const framework::DDim &ids_dims) { + int64_t last_dim = table_dims[1]; + for (int i = 1; i != ids_dims.size(); ++i) { + last_dim *= ids_dims[i]; + } + return last_dim; +} + template class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { public: @@ -70,6 +79,17 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { const LoDTensor *table_var = context.Input("W"); const std::string &combiner_type = context.Attr("combiner"); + int64_t last_dim = + FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims()); + const auto &ids_lod = ids_t->lod(); + // in run time, the LoD of ids must be 1 + PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1"); + PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty"); + int64_t batch_size = ids_lod[0].size() - 1; + // in run time, the shape from Ids -> output + // should be [seq_length, 1] -> [batch_size, embedding_size] + output_t->Resize({batch_size, last_dim}); + if (combiner_type == "sum") { EmbeddingVSumFunctor functor; functor(context, table_var, ids_t, output_t); diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc index b2c2c7954..7a29f80ff 100644 --- a/paddle/fluid/operators/hash_op.cc +++ b/paddle/fluid/operators/hash_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/hash_op.h" #include -#include namespace paddle { namespace operators { @@ -27,6 +26,9 @@ class HashOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { + if (ctx->IsRuntime()) { + return; + } PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of HashOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -36,15 +38,8 @@ class HashOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(dims.size(), 2UL, "The input of hash_op's dimensions must be 2"); std::vector out_dims; - out_dims.reserve(dims.size() + 1); - // copy all dims except the last one - for (int i = 0u; i != dims.size() - 1; ++i) { - out_dims.emplace_back(dims[i]); - } int num_hash = ctx->Attrs().Get("num_hash"); - out_dims.emplace_back(num_hash); - // keep the last dim to 1 - out_dims.emplace_back(1); + HashOutputSize(dims, out_dims, num_hash); ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->ShareLoD("X", /*->*/ "Out"); @@ -71,4 +66,4 @@ $$Out = scale * X$$ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker); -REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel, ops::HashKerel); +REGISTER_OP_CPU_KERNEL(hash, ops::HashKernel, ops::HashKernel); diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h index 9781bb0f4..9e7ad5235 100644 --- a/paddle/fluid/operators/hash_op.h +++ b/paddle/fluid/operators/hash_op.h @@ -17,21 +17,34 @@ limitations under the License. */ extern "C" { #include } +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { -// template + +inline void HashOutputSize(const framework::DDim& in_dims, + std::vector& out_dims, // NOLINT + int num_hash) { + out_dims.reserve(in_dims.size() + 1); + // copy all dims except the last one + for (int i = 0u; i != in_dims.size() - 1; ++i) { + out_dims.emplace_back(in_dims[i]); + } + out_dims.emplace_back(num_hash); + // keep the last dim to 1 + out_dims.emplace_back(1); +} + template -class HashKerel : public framework::OpKernel { +class HashKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& context) const { auto* out_t = context.Output("Out"); auto* in_t = context.Input("X"); int mod_by = context.Attr("mod_by"); int num_hash = context.Attr("num_hash"); - auto* output = out_t->mutable_data(context.GetPlace()); auto in_dims = in_t->dims(); auto in_lod = in_t->lod(); @@ -39,6 +52,11 @@ class HashKerel : public framework::OpKernel { static_cast(in_dims[0]), in_lod[0].back(), "The actual input data's size mismatched with LoD information."); + std::vector out_dims; + HashOutputSize(in_dims, out_dims, num_hash); + out_t->Resize(framework::make_ddim(out_dims)); + auto* output = out_t->mutable_data(context.GetPlace()); + auto seq_length = in_dims[0]; auto last_dim = in_dims[in_dims.size() - 1]; auto* input = in_t->data(); @@ -49,6 +67,7 @@ class HashKerel : public framework::OpKernel { } input += last_dim; } + out_t->set_lod(in_t->lod()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc index 0932211ca..d3dcd1f96 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc @@ -22,6 +22,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { + if (ctx->IsRuntime()) { + return; + } PADDLE_ENFORCE( ctx->HasInput("X"), "Input(X) of SequecceEnumerate operator should not be null."); @@ -33,9 +36,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_dims.size(), 2, "Input(X) of SequenceEnumerate operator's rank should be 2."); - PADDLE_ENFORCE_EQ( - x_dims[1], 1, - "Input(X) of SequenceEnumerate operator's 2nd dimension should be 1."); + PADDLE_ENFORCE_EQ(x_dims[1], 1, + "Input(X) of SequenceEnumerate operator's 2nd " + "dimension should be 1."); const auto win_size = ctx->Attrs().Get("win_size"); ctx->SetOutputDim("Out", {x_dims[0], win_size}); diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu index 28821e712..d5deb7582 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cu @@ -65,6 +65,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { auto lod0 = in_lod[0]; auto in_len = in->numel(); auto in_data = in->data(); + out->Resize({in_dims[0], win_size}); auto out_data = out->mutable_data(context.GetPlace()); // Copy LoD to GPU const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace()); @@ -72,6 +73,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data); + out->set_lod(in->lod()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h index dc18d9b20..18da69993 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h @@ -39,6 +39,7 @@ class SequenceEnumerateKernel : public framework::OpKernel { // Generate enumerate sequence set auto lod0 = in_lod[0]; auto in_data = in->data(); + out->Resize({in_dims[0], win_size}); auto out_data = out->mutable_data(context.GetPlace()); for (size_t i = 0; i < lod0.size() - 1; ++i) { for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { @@ -49,6 +50,7 @@ class SequenceEnumerateKernel : public framework::OpKernel { } } } + out->set_lod(in->lod()); } }; -- GitLab