An exception occurs when I use an Operator I wrote myself in a model. How can I fix it?
Created by: zzhzz
The exception message is as follows:

```
paddle.fluid.core.EnforceNotMet: Enforce failed. Expected arg_names.size() == 1UL, but received arg_names.size():0 != 1UL:1.
Output(NodesVector@GRAD) should hold one element, but now it holds 0 at [/mnt/tree_conv/paddle/fluid/framework/shape_inference.cc:59]
PaddlePaddle Call Stacks:
0x7f4c05b574dcp paddle::framework::InferShapeContext::SetOutputDim(std::string const&, paddle::framework::DDim const&) + 316
2 0x7f4c058a8d72p paddle::operators::TreeConvGradOp::InferShape(paddle::framework::InferShapeContext*) const + 162
3 0x7f4c04d00246p paddle::framework::OpDesc::InferShape(paddle::framework::BlockDesc const&) const + 886
4 0x7f4c04cabb65p ZZN8pybind1112cpp_function10initializeIZNS0_C1IvN6paddle9framework6OpDescEIRKNS4_9BlockDescEEINS_4nameENS_9is_methodENS_7siblingEEEEMT0_KFT_DpT1_EDpRKT2_EUlPKS5_S8_E_vISN_S8_EIS9_SA_SB_EEEvOSD_PFSC_SF_ESL_ENUlRNS_6detail13function_callEE1_4_FUNESU + 213
5 0x7f4c04c5f864p pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 2596
6 0x7f4c36925ce8p PyEval_EvalFrameEx + 28264
7 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
8 0x7f4c36925d70p PyEval_EvalFrameEx + 28400
9 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
10 0x7f4c36925d70p PyEval_EvalFrameEx + 28400
11 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
12 0x7f4c36925d70p PyEval_EvalFrameEx + 28400
13 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
14 0x7f4c36925d70p PyEval_EvalFrameEx + 28400
15 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
16 0x7f4c36925d70p PyEval_EvalFrameEx + 28400
17 0x7f4c36925e9ep PyEval_EvalFrameEx + 28702
18 0x7f4c3692837dp PyEval_EvalCodeEx + 2061
19 0x7f4c369284b2p PyEval_EvalCode + 50
20 0x7f4c369521c2p PyRun_FileExFlags + 146
21 0x7f4c36953559p PyRun_SimpleFileExFlags + 217
22 0x7f4c369691ddp Py_Main + 3149
23 0x7f4c35bfcd1dp __libc_start_main + 253
24 0x4006b1p
```
Here is the InferShape function for my backward (gradient) computation:
```cpp
void InferShape(framework::InferShapeContext *ctx) const override {
  // The gradient of the forward output must be provided to the grad op.
  PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                 "the gradient of output must not be null");

  // NodesVector@GRAD has the same shape as the forward input NodesVector.
  if (ctx->HasOutput(framework::GradVarName("NodesVector"))) {
    ctx->SetOutputDim(framework::GradVarName("NodesVector"),
                      ctx->GetInputDim("NodesVector"));
  }
  // Filter@GRAD has the same shape as the forward input Filter.
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"),
                      ctx->GetInputDim("Filter"));
  }
}
```
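The error text says the `NodesVector@GRAD` output "should hold one element, but now it holds 0", yet the `HasOutput` guard above still let execution reach `SetOutputDim`. The toy model below is a sketch of how that combination can arise, assuming an op's outputs can be viewed as a map from slot name to a list of variable names. Every name in it (`SlotMap`, `HasSlotKey`, `HasSingleVar`, `SetOutputDimToy`, `Demo`) is hypothetical and is not Paddle's API; it does not claim to describe how Paddle's `HasOutput` is actually implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an op's outputs: slot name -> variable names.
using SlotMap = std::map<std::string, std::vector<std::string>>;

// Presence check that only tests whether the slot key exists.
bool HasSlotKey(const SlotMap& outputs, const std::string& slot) {
  return outputs.count(slot) > 0;
}

// Stricter presence check: the slot exists AND holds exactly one variable.
bool HasSingleVar(const SlotMap& outputs, const std::string& slot) {
  auto it = outputs.find(slot);
  return it != outputs.end() && it->second.size() == 1;
}

// Mimics the enforce in the reported error: the slot must hold exactly one
// element. Returns the variable name whose dims would be set.
std::string SetOutputDimToy(const SlotMap& outputs, const std::string& slot) {
  auto it = outputs.find(slot);
  std::size_t n = (it == outputs.end()) ? 0 : it->second.size();
  if (n != 1) {
    throw std::runtime_error("Output(" + slot +
                             ") should hold one element, but now it holds " +
                             std::to_string(n));
  }
  return it->second[0];
}

// Usage: an empty NodesVector@GRAD slot next to a filled Filter@GRAD slot.
void Demo() {
  SlotMap outs{{"NodesVector@GRAD", {}}, {"Filter@GRAD", {"Filter@GRAD"}}};
  assert(HasSlotKey(outs, "NodesVector@GRAD"));     // key-only check passes...
  assert(!HasSingleVar(outs, "NodesVector@GRAD"));  // ...but the slot is empty
  assert(HasSingleVar(outs, "Filter@GRAD"));
}
```

In this model, either requiring exactly one name before setting dims or making sure whatever builds the grad op actually puts a variable into the slot avoids the throw. The trace alone does not establish which of these applies to the Paddle code in question.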
How can I fix this?