未验证 提交 17b8335b 编写于 作者: J jiangcheng 提交者: GitHub

fix cinn graph may hasn't input problem (#40814)

上级 db41e39e
......@@ -24,7 +24,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnInstructionRun");
const CinnCompiledObject& compiled_object =
......@@ -43,6 +45,53 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
});
ctx->SetOutputsDim(kOutputs, output_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Why we need override GetExpectedKernelType?
// A cinn-graph may has no inpute var, if we use the base function,
// it will check wheter input tensors is initialized. Here we rewrite
// the function so that we can infer kernel type by output date type.
if (ctx.InputSize(kX)) {
// if the instruction has input, infer kernel type by input date type:
return OperatorWithKernel::GetExpectedKernelType(ctx);
}
// Else infer kernel type by output date type:
// The `OutputVar` will check wheter the kOutputs iff has one output var
const framework::Variable* var = ctx.OutputVar(kOutputs);
PADDLE_ENFORCE_NE(
var, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Variable should not empty."));
const framework::Tensor* tensor = nullptr;
if (var->IsType<framework::Tensor>()) {
tensor = &var->Get<framework::Tensor>();
} else if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
tensor = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = &var->Get<framework::LoDTensorArray>();
PADDLE_ENFORCE_EQ(t_arr->size(), 1UL,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op should just has One "
"Output when Input empty."));
tensor = &(t_arr->front());
}
PADDLE_ENFORCE_NE(
tensor, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Tensor should not empty."));
VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is "
<< paddle::framework::DataType2String(tensor->dtype());
auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype());
return framework::OpKernelType(output_type, ctx.device_context());
}
};
class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -87,9 +87,12 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
"Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
"CinnLaunchOp");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
// "Input", string::format_string("%s|%s", kX,
// kNoNeedBufferX),
// "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册