未验证 提交 ec24bc98 编写于 作者: C Chen Weihang 提交者: GitHub

add get inout var ptr for dygraph (#39134)

上级 7ac2f80f
......@@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
paddle::framework::DataLayout::kMKLDNN));
}
// TODO(paddle-dev): Can this be template?
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(it, tensor_in_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in inputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}
std::vector<paddle::framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in outputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}
DDim GetInputDim(const std::string& name) const override {
......
......@@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}
// TODO(paddle-dev): Can this be template?
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("Can not find [%s] in inputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("Can not find [%s] in outputs.", name));
for (auto& var : it->second) {
res.emplace_back(var->MutableVar());
}
return res;
}
DDim GetInputDim(const std::string& name) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册