diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index fe4d1546ea1d7f88cca1f61f98ceb98e85ee8dda..98799e049dff384c17304e3ba50a3bcfaf2c9a2c 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -305,11 +305,18 @@ void build_op_func_list(const platform::Place& place, RuntimeContext runtime_context({}, {}); runtime_context.inputs.swap(ins_map); runtime_context.outputs.swap(outs_map); - InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); - // TODO(Aurelius84): In case of control flow ops, they are NOT inheritted - // from OperatorWithKernel. - static_cast(op)->InferShape( - &infer_shape_ctx); + + // see OperatorWithKernel::RunImpl in operator.cc for why + if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && + op->Attr(kAllKernelsMustComputeRuntimeShape))) { + InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); + // TODO(Aurelius84): In case of control flow ops, they are NOT + // inheritted + // from OperatorWithKernel. + static_cast(op)->InferShape( + &infer_shape_ctx); + } + auto kernels_iter = all_op_kernels.find(op->Type()); PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(),