From 31ccaf091641b991af885427eb3071a276ccc70e Mon Sep 17 00:00:00 2001 From: luotao1 Date: Mon, 11 Mar 2019 19:58:41 +0800 Subject: [PATCH] add all_kernels_must_compute_runtime_shape example for speedup infershape test=develop --- paddle/fluid/framework/operator.cc | 11 +++++++++-- .../operators/fused/fused_embedding_seq_pool_op.cc | 11 ++++++++--- paddle/fluid/operators/hash_op.cc | 11 ++++++++--- .../operators/sequence_ops/sequence_enumerate_op.cc | 11 ++++++++--- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index df1689764d2..9f48b8cb9e7 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -926,8 +926,15 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } - RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); - this->InferShape(&infer_shape_ctx); + // If Op has attribute all_kernels_must_compute_runtime_shape, + // all the kernels of this Op would compute runtime shape, + // and skip infershape in runtime for speedup. + // TODO(luotao): Note that it is a temporal attribute, after all ops + // implement computing runtime shape, this attribute would be deleted. + if (!HasAttr("all_kernels_must_compute_runtime_shape")) { + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); + this->InferShape(&infer_shape_ctx); + } // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. kernel_iter->second( diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 80caf70b08e..17a81d3e880 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -23,9 +23,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - if (ctx->IsRuntime()) { - return; - } PADDLE_ENFORCE(ctx->HasInput("W"), "Input W of FusedEmbeddingSeqPoolOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Ids"), @@ -91,6 +88,14 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { "(boolean, default false) " "Sparse update.") .SetDefault(false); + AddAttr( + "all_kernels_must_compute_runtime_shape", + "(boolean, default true) " + "An attribute to speed up OperatorWithKernel::RunImpl." + "If true, all the kernels of this Op would compute runtime " + "shape, but skip infershape in runtime. Note that it is a temporal " + "attribute, please do DOT set it in python layer.") + .SetDefault(true); AddComment(R"DOC( FusedEmbeddingSeqPool Operator. diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc index 7a29f80ff1c..b39eba081ec 100644 --- a/paddle/fluid/operators/hash_op.cc +++ b/paddle/fluid/operators/hash_op.cc @@ -26,9 +26,6 @@ class HashOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { - if (ctx->IsRuntime()) { - return; - } PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of HashOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -57,6 +54,14 @@ $$Out = scale * X$$ )DOC"); AddAttr("num_hash", "").SetDefault(1); AddAttr("mod_by", "").SetDefault(100000); + AddAttr( + "all_kernels_must_compute_runtime_shape", + "(boolean, default true) " + "An attribute to speed up OperatorWithKernel::RunImpl." + "If true, all the kernels of this Op would compute runtime " + "shape, but skip infershape in runtime. Note that it is a temporal " + "attribute, please do DOT set it in python layer.") + .SetDefault(true); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc index d3dcd1f96a9..63e95e86544 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc @@ -22,9 +22,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - if (ctx->IsRuntime()) { - return; - } PADDLE_ENFORCE( ctx->HasInput("X"), "Input(X) of SequecceEnumerate operator should not be null."); @@ -62,6 +59,14 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { }); AddAttr("pad_value", "(int) The enumerate sequence padding value.") .SetDefault(0); + AddAttr( + "all_kernels_must_compute_runtime_shape", + "(boolean, default true) " + "An attribute to speed up OperatorWithKernel::RunImpl." + "If true, all the kernels of this Op would compute runtime " + "shape, but skip infershape in runtime. Note that it is a temporal " + "attribute, please do DOT set it in python layer.") + .SetDefault(true); AddComment(R"DOC( Sequence Enumerate Operator. -- GitLab