提交 a7a4f053 编写于 作者: S sneaxiy

Merge develop

test=develop
......@@ -926,8 +926,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
}
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(
......
......@@ -62,6 +62,15 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";
/// If an Op has this attribute, all its kernels should calculate output
/// variable's shape in the corresponding Compute() function. And
/// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
/// function in its runtime for speedup.
/// TODO(luotao): Note that this temporal attribute would be deleted after all
/// ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
"@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";
// define some kernel priority
/* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
......
......@@ -23,9 +23,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
......@@ -91,6 +88,8 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
.SetDefault(true);
AddComment(R"DOC(
FusedEmbeddingSeqPool Operator.
......
......@@ -121,6 +121,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
// runtime shape
d_table->set_height(table_dim[0]);
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
......
......@@ -26,9 +26,6 @@ class HashOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of HashOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -57,6 +54,8 @@ $$Out = scale * X$$
)DOC");
AddAttr<int>("num_hash", "").SetDefault(1);
AddAttr<int>("mod_by", "").SetDefault(100000);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
.SetDefault(true);
}
};
......
......@@ -22,9 +22,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
......@@ -62,6 +59,8 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
.SetDefault(true);
AddComment(R"DOC(
Sequence Enumerate Operator.
......
......@@ -455,7 +455,11 @@ function assert_api_spec_approvals() {
# NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable.
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 2887803 35982308`
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 2887803 35982308 46782768 30176695`
if [ "${APPROVALS}" == "TRUE" ];then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 35982308`
fi
else
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
......@@ -463,7 +467,7 @@ function assert_api_spec_approvals() {
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
echo "You must have panyx0718 and shanyi15 approval for the api change! ${API_FILE}"
echo "You must have one RD (panyx0718 or chengduoZH or XiaoguangHu01) and one PM (shanyi15) approval for the api change! ${API_FILE}"
else
echo "You must have panyx0718 approval for the api change! ${API_FILE}"
fi
......
......@@ -51,6 +51,8 @@ def visit_member(parent_name, member):
all = (args, doc)
member_dict[cur_name] = all
except TypeError: # special for PyBind method
if cur_name in check_modules_list:
return
member_dict[cur_name] = " ".join([
line.strip() for line in pydoc.render_doc(member).split('\n')
if "->" in line
......@@ -78,6 +80,7 @@ def visit_all_module(mod):
visit_member(mod.__name__, instance)
check_modules_list = ["paddle.reader.ComposeNotAligned.__init__"]
modules = sys.argv[1].split(",")
for m in modules:
visit_all_module(importlib.import_module(m))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册