From 17b8335bbbb1f0c420b0da6a776d9c9aae872381 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Wed, 23 Mar 2022 11:03:22 +0800 Subject: [PATCH] fix cinn graph may hasn't input problem (#40814) --- .../operators/cinn/cinn_instruction_run_op.cc | 51 ++++++++++++++++++- paddle/fluid/operators/cinn/cinn_launch_op.cc | 9 ++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc index edf854a9c9..8139530b80 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc @@ -24,7 +24,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun"); + // The cinn-graph may hasn't input for CINN now support fill_constant, + // and its all inputs may generated by fill_constant instead of by fetch. + // OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun"); OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, "CinnInstructionRun"); const CinnCompiledObject& compiled_object = @@ -43,6 +45,53 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { }); ctx->SetOutputsDim(kOutputs, output_dims); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // Why we need override GetExpectedKernelType? + // A cinn-graph may has no inpute var, if we use the base function, + // it will check wheter input tensors is initialized. Here we rewrite + // the function so that we can infer kernel type by output date type. + if (ctx.InputSize(kX)) { + // if the instruction has input, infer kernel type by input date type: + return OperatorWithKernel::GetExpectedKernelType(ctx); + } + + // Else infer kernel type by output date type: + // The `OutputVar` will check wheter the kOutputs iff has one output var + const framework::Variable* var = ctx.OutputVar(kOutputs); + PADDLE_ENFORCE_NE( + var, nullptr, + platform::errors::InvalidArgument( + "The cinn_instruction_run Op's Output Variable should not empty.")); + + const framework::Tensor* tensor = nullptr; + if (var->IsType()) { + tensor = &var->Get(); + } else if (var->IsType()) { + tensor = &var->Get(); + } else if (var->IsType()) { + tensor = &(var->Get().value()); + } else if (var->IsType()) { + auto t_arr = &var->Get(); + PADDLE_ENFORCE_EQ(t_arr->size(), 1UL, + platform::errors::InvalidArgument( + "The cinn_instruction_run Op should just has One " + "Output when Input empty.")); + tensor = &(t_arr->front()); + } + + PADDLE_ENFORCE_NE( + tensor, nullptr, + platform::errors::InvalidArgument( + "The cinn_instruction_run Op's Output Tensor should not empty.")); + + VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is " + << paddle::framework::DataType2String(tensor->dtype()); + auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype()); + return framework::OpKernelType(output_type, ctx.device_context()); + } }; class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc index d918b7216c..5d006a947b 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -87,9 +87,12 @@ class CinnLaunchOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX), - "Input", string::format_string("%s|%s", kX, kNoNeedBufferX), - "CinnLaunchOp"); + // The cinn-graph may hasn't input for CINN now support fill_constant, + // and its all inputs may generated by fill_constant instead of by fetch. + // OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX), + // "Input", string::format_string("%s|%s", kX, + // kNoNeedBufferX), + // "CinnLaunchOp"); OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, "CinnLaunchOp"); } -- GitLab