From ec24bc989c3f55695c281dc177930c7a0e74c09a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 17:10:28 +0800 Subject: [PATCH] add get inout var ptr for dygraph (#39134) --- .../fluid/eager/legacy/infer_shape_context.h | 23 +++++++++++++++---- paddle/fluid/imperative/infer_shape_context.h | 23 +++++++++++++++---- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/eager/legacy/infer_shape_context.h b/paddle/fluid/eager/legacy/infer_shape_context.h index a1032fd404f..0979abc63d6 100644 --- a/paddle/fluid/eager/legacy/infer_shape_context.h +++ b/paddle/fluid/eager/legacy/infer_shape_context.h @@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext { paddle::framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template? std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_in_->find(name); + PADDLE_ENFORCE_NE(it, tensor_in_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in inputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_out_->find(name); + PADDLE_ENFORCE_NE(it, tensor_out_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in outputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override { diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index a16ad1688fb..71f7fb7387e 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext { (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template? std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("Can not find [%s] in inputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("Can not find [%s] in outputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override { -- GitLab