diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index fb985a0c0fe77a9527c7e15f8c4d1423b9ce499a..44821aadf6d1f6777484dab5d25d01ff3b42596b 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -926,8 +926,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     dev_ctx = pool.Get(expected_kernel_key.place_);
   }
 
-  RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
-  this->InferShape(&infer_shape_ctx);
+  if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
+    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
+    this->InferShape(&infer_shape_ctx);
+  }
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
   kernel_iter->second(
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 55629636a816982c4debe4b5b7138558ac309eb5..822bf5c9ceaa31e1283fa3cf1dbe42a43894a5dd 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -62,6 +62,15 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
 /// Variables with this suffix are the new Gradient.
 constexpr char kNewGradSuffix[] = "@NEWGRAD@";
 
+/// If an Op has this attribute, all its kernels should calculate the output
+/// variable's shape in the corresponding Compute() function, and
+/// OperatorWithKernel::RunImpl() will skip calling this Op's InferShape()
+/// at runtime for speedup.
+/// TODO(luotao): Note that this temporary attribute will be removed once
+/// all ops contain it.
+constexpr char kAllKernelsMustComputeRuntimeShape[] =
+    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";
+
 // define some kernel priority
 /* Define multiple kernel type fallback order*/
 extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
index 80caf70b08e65932d6ccb90a5293d072b2b2bc72..a0026427e2514735711f7eba26fcf861cb498d5e 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
@@ -23,9 +23,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (ctx->IsRuntime()) {
-      return;
-    }
     PADDLE_ENFORCE(ctx->HasInput("W"),
                    "Input W of FusedEmbeddingSeqPoolOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Ids"),
@@ -91,6 +88,8 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(boolean, default false) "
                   "Sparse update.")
         .SetDefault(false);
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
+        .SetDefault(true);
     AddComment(R"DOC(
 FusedEmbeddingSeqPool Operator.
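
For context, the convention behind kAllKernelsMustComputeRuntimeShape looks roughly like the sketch below. It is not part of this patch: the identity-copy kernel and every name in it are hypothetical. It only illustrates that once an op carries the attribute, each of its kernels must derive and set the output dims inside Compute(), because RunImpl() no longer calls InferShape() for that op at runtime.

#include <cstring>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Hypothetical identity-copy kernel (CPU-only sketch). Because its op
// declares kAllKernelsMustComputeRuntimeShape, InferShape() is skipped at
// runtime, so the kernel itself must resize Out before writing any data.
template <typename T>
class CopyLikeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::LoDTensor>("X");
    auto* out = context.Output<framework::LoDTensor>("Out");
    // Runtime shape: take the dims from the actual input tensor.
    out->Resize(in->dims());
    T* out_data = out->mutable_data<T>(context.GetPlace());
    std::memcpy(out_data, in->data<T>(), in->numel() * sizeof(T));
    out->set_lod(in->lod());  // keep the sequence information as well
  }
};

}  // namespace operators
}  // namespace paddle
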
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 5e2e336e7117cc4816a52405b7bc2689bc03dd46..4651c2b2ba81a404b64818fec81cef79634ff036 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -121,6 +121,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
       auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
+      // runtime shape
+      d_table->set_height(table_dim[0]);
 
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc
index 7a29f80ff1ce413519ea9cea6a35747bdced5885..f6395fb32feac175976cb96e1c0bee7347cb3ea8 100644
--- a/paddle/fluid/operators/hash_op.cc
+++ b/paddle/fluid/operators/hash_op.cc
@@ -26,9 +26,6 @@ class HashOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    if (ctx->IsRuntime()) {
-      return;
-    }
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of HashOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -57,6 +54,8 @@ $$Out = scale * X$$
 )DOC");
     AddAttr<int>("num_hash", "").SetDefault(1);
     AddAttr<int>("mod_by", "").SetDefault(100000);
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
+        .SetDefault(true);
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
index d3dcd1f96a986d2450c8af780a12183f7dfc66d5..f357c9c08d042b69259f229955922f2f11b52c63 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
@@ -22,9 +22,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (ctx->IsRuntime()) {
-      return;
-    }
     PADDLE_ENFORCE(
         ctx->HasInput("X"),
         "Input(X) of SequenceEnumerate operator should not be null.");
@@ -62,6 +59,8 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
     });
     AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
         .SetDefault(0);
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
+        .SetDefault(true);
     AddComment(R"DOC(
 Sequence Enumerate Operator.
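
The `// runtime shape` addition above is this convention at work on a sparse output: with InferShape() skipped, the SelectedRows output's height would otherwise never be set, so the grad kernel sets it from the table's dims itself. The hash and sequence_enumerate kernels do the analogous thing for dense outputs. Below is a simplified, hypothetical sketch of that pattern for a hash-style kernel; std::hash stands in for the real hash function, and the loop hashes one element per row, unlike the actual kernel in paddle/fluid/operators/hash_op.h.

#include <functional>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Simplified, hypothetical hash-style kernel: since InferShape() is skipped
// at runtime, Out's dims are computed here from X's actual dims.
template <typename T>
class HashLikeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::LoDTensor>("X");
    auto* out = context.Output<framework::LoDTensor>("Out");
    int num_hash = context.Attr<int>("num_hash");
    int mod_by = context.Attr<int>("mod_by");

    // Runtime shape: one row of `num_hash` hash values per input row.
    auto in_dims = in->dims();
    out->Resize(framework::make_ddim({in_dims[0], num_hash, 1}));
    int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());

    const T* in_data = in->data<T>();
    std::hash<T> hasher;  // stand-in for the op's real hash function
    for (int64_t i = 0; i < in_dims[0]; ++i) {
      for (int j = 0; j < num_hash; ++j) {
        out_data[i * num_hash + j] =
            static_cast<int64_t>((hasher(in_data[i]) + j) % mod_by);
      }
    }
    out->set_lod(in->lod());  // sequence information passes through unchanged
  }
};

}  // namespace operators
}  // namespace paddle
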