提交 31ccaf09 编写于 作者: L luotao1

add all_kernels_must_compute_runtime_shape example for speedup infershape

test=develop
上级 ad80bde8
......@@ -926,8 +926,15 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
// If Op has attribute all_kernels_must_compute_runtime_shape,
// all the kernels of this Op would compute runtime shape,
// and skip infershape in runtime for speedup.
// TODO(luotao): Note that it is a temporal attribute, after all ops
// implement computing runtime shape, this attribute would be deleted.
if (!HasAttr("all_kernels_must_compute_runtime_shape")) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
}
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(
......
......@@ -23,9 +23,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
......@@ -91,6 +88,14 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>(
"all_kernels_must_compute_runtime_shape",
"(boolean, default true) "
"An attribute to speed up OperatorWithKernel::RunImpl."
"If true, all the kernels of this Op would compute runtime "
"shape, but skip infershape in runtime. Note that it is a temporal "
"attribute, please do DOT set it in python layer.")
.SetDefault(true);
AddComment(R"DOC(
FusedEmbeddingSeqPool Operator.
......
......@@ -26,9 +26,6 @@ class HashOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of HashOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -57,6 +54,14 @@ $$Out = scale * X$$
)DOC");
AddAttr<int>("num_hash", "").SetDefault(1);
AddAttr<int>("mod_by", "").SetDefault(100000);
AddAttr<bool>(
"all_kernels_must_compute_runtime_shape",
"(boolean, default true) "
"An attribute to speed up OperatorWithKernel::RunImpl."
"If true, all the kernels of this Op would compute runtime "
"shape, but skip infershape in runtime. Note that it is a temporal "
"attribute, please do DOT set it in python layer.")
.SetDefault(true);
}
};
......
......@@ -22,9 +22,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
......@@ -62,6 +59,14 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0);
AddAttr<bool>(
"all_kernels_must_compute_runtime_shape",
"(boolean, default true) "
"An attribute to speed up OperatorWithKernel::RunImpl."
"If true, all the kernels of this Op would compute runtime "
"shape, but skip infershape in runtime. Note that it is a temporal "
"attribute, please do DOT set it in python layer.")
.SetDefault(true);
AddComment(R"DOC(
Sequence Enumerate Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册