Commit fe78a92e authored by luotao1

refine with comments

test=develop
Parent 5d20954a
...
@@ -926,12 +926,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     dev_ctx = pool.Get(expected_kernel_key.place_);
   }
-  // If Op has attribute all_kernels_must_compute_runtime_shape,
-  // all the kernels of this Op would compute runtime shape,
-  // and skip infershape in runtime for speedup.
-  // TODO(luotao): Note that it is a temporal attribute, after all ops
-  // implement computing runtime shape, this attribute would be deleted.
-  if (!HasAttr("all_kernels_must_compute_runtime_shape")) {
+  if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
     RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
     this->InferShape(&infer_shape_ctx);
   }
...
@@ -62,6 +62,15 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
 /// Variables with this suffix are the new Gradient.
 constexpr char kNewGradSuffix[] = "@NEWGRAD@";
+/// If an Op has this attribute, all its kernels should calculate output
+/// variable's shape in the corresponding Compute() function. And
+/// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
+/// function in its runtime for speedup.
+/// TODO(luotao): Note that this temporal attribute would be deleted after all
+/// ops contain it.
+constexpr char kAllKernelsMustComputeRuntimeShape[] =
+    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";
 // define some kernel priority
 /* Define multiple kernel type fallback order*/
 extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
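As background for the new kAllKernelsMustComputeRuntimeShape constant, the following is a minimal, self-contained toy model (not PaddlePaddle code; every name except the constant is hypothetical) of the control flow the attribute enables: when an op carries it, the separate InferShape() pass is skipped and the kernel's Compute() sets the output shape itself at run time.

// Toy model of the mechanism in this commit (illustrative only, not the
// Paddle implementation).
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

struct Tensor {
  std::vector<int64_t> dims;
  std::vector<float> data;
};

class ToyOp {
 public:
  explicit ToyOp(std::unordered_set<std::string> attrs)
      : attrs_(std::move(attrs)) {}

  bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }

  // Mirrors the RunImpl() change above: InferShape() runs only for ops that
  // do NOT carry the attribute; kernels of opted-in ops size outputs themselves.
  void Run(const Tensor& x, Tensor* out) const {
    if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
      InferShape(x, out);
    }
    Compute(x, out);
  }

 private:
  static void InferShape(const Tensor& x, Tensor* out) { out->dims = x.dims; }

  // Kernel that computes the runtime shape inside Compute(), making the
  // separate InferShape() pass redundant for this op.
  static void Compute(const Tensor& x, Tensor* out) {
    out->dims = x.dims;  // runtime shape decided here
    out->data.resize(x.data.size());
    for (size_t i = 0; i < x.data.size(); ++i) out->data[i] = 2.f * x.data[i];
  }

  std::unordered_set<std::string> attrs_;
};

int main() {
  Tensor x{{2, 3}, {1, 2, 3, 4, 5, 6}};
  Tensor out;
  ToyOp op({kAllKernelsMustComputeRuntimeShape});
  op.Run(x, &out);  // InferShape() is skipped; Compute() sets out.dims
  assert(out.dims == x.dims);
  return 0;
}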
...
@@ -88,13 +88,7 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(boolean, default false) "
                   "Sparse update.")
         .SetDefault(false);
-    AddAttr<bool>(
-        "all_kernels_must_compute_runtime_shape",
-        "(boolean, default true) "
-        "An attribute to speed up OperatorWithKernel::RunImpl."
-        "If true, all the kernels of this Op would compute runtime "
-        "shape, but skip infershape in runtime. Note that it is a temporal "
-        "attribute, please do DOT set it in python layer.")
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
         .SetDefault(true);
     AddComment(R"DOC(
 FusedEmbeddingSeqPool Operator.
...
@@ -54,13 +54,7 @@ $$Out = scale * X$$
 )DOC");
     AddAttr<int>("num_hash", "").SetDefault(1);
     AddAttr<int>("mod_by", "").SetDefault(100000);
-    AddAttr<bool>(
-        "all_kernels_must_compute_runtime_shape",
-        "(boolean, default true) "
-        "An attribute to speed up OperatorWithKernel::RunImpl."
-        "If true, all the kernels of this Op would compute runtime "
-        "shape, but skip infershape in runtime. Note that it is a temporal "
-        "attribute, please do DOT set it in python layer.")
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
         .SetDefault(true);
   }
 };
...
@@ -59,13 +59,7 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
     });
     AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
         .SetDefault(0);
-    AddAttr<bool>(
-        "all_kernels_must_compute_runtime_shape",
-        "(boolean, default true) "
-        "An attribute to speed up OperatorWithKernel::RunImpl."
-        "If true, all the kernels of this Op would compute runtime "
-        "shape, but skip infershape in runtime. Note that it is a temporal "
-        "attribute, please do DOT set it in python layer.")
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
         .SetDefault(true);
     AddComment(R"DOC(
 Sequence Enumerate Operator.
...
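Taken together, the three operator diffs above suggest the opt-in pattern for any further op that wants this speed-up: declare the attribute through the shared framework constant in the OpMaker, and have every kernel of that op size its outputs inside Compute(). The sketch below is a hedged illustration against the Paddle Fluid 1.x C++ API of this era; the op name MyOp, its input/output names, and the chosen output shape are made up for illustration.

// Hypothetical operator sketch ("MyOp", "X", "Out" are illustrative names).
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) input");
    AddOutput("Out", "(LoDTensor) output");
    // Opt in: every kernel of this op promises to compute runtime shapes,
    // so RunImpl() skips the InferShape() call for it.
    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
        .SetDefault(true);
    AddComment("Illustrative op whose kernels size outputs in Compute().");
  }
};

template <typename T>
class MyOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    // The output shape is decided here, at run time, because InferShape()
    // is not called for ops carrying kAllKernelsMustComputeRuntimeShape.
    out->Resize(framework::make_ddim({x->dims()[0], 1}));
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    // ... fill out_data ...
    (void)out_data;
  }
};

}  // namespace operators
}  // namespace paddle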