diff --git a/paddle/fluid/eager/legacy/infer_shape_context.h b/paddle/fluid/eager/legacy/infer_shape_context.h index a1032fd404f8551c124c9589067c6adeab48fc75..0979abc63d65870e1a2aabdc14116a55d786ed00 100644 --- a/paddle/fluid/eager/legacy/infer_shape_context.h +++ b/paddle/fluid/eager/legacy/infer_shape_context.h @@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext { paddle::framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template? std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_in_->find(name); + PADDLE_ENFORCE_NE(it, tensor_in_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in inputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_out_->find(name); + PADDLE_ENFORCE_NE(it, tensor_out_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in outputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override { diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index a16ad1688fbacd4c218c501da213205a1b853897..71f7fb7387effe68ae63d5a3c5236e9a9a108d2f 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext { (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template? std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("Can not find [%s] in inputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("Can not find [%s] in outputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override {