diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 4a078258d23be25e4a71fda6c9c19e7c87fbfe40..82a23797d4720e74a7adb8d88b41555b4c3eb71a 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -342,10 +342,9 @@ class InferShapeContext { PADDLE_ENFORCE_LT(j, OutputSize(out)); auto* in_var = MultiInputVar(in)[i]; auto* out_var = MultiOutputVar(out)[j]; - PADDLE_ENFORCE(in_var->IsType(), - "The %d-th input of Input(%s) must be LoDTensor.", in); + if (!in_var->IsType()) return; PADDLE_ENFORCE(out_var->IsType(), - "The %d-th output of Output(%s) must be LoDTensor.", out); + "The %d-th output of Output(%s) must be LoDTensor.", j, out); auto in_tensor = in_var->Get(); auto* out_tensor = out_var->GetMutable(); out_tensor->set_lod(in_tensor.lod()); @@ -363,6 +362,13 @@ template <> const std::vector InferShapeContext::MultiInput( const std::string& name) const; +template <> +Tensor* InferShapeContext::Output(const std::string& name) const; + +template <> +std::vector InferShapeContext::MultiOutput( + const std::string& name) const; + template struct EigenDeviceConverter;