未验证 提交 b809be1a 编写于 作者: W WangZhen 提交者: GitHub

Support zero dims input for eager run program OP (#44273)

上级 0470e9da
...@@ -58,13 +58,14 @@ static void CheckInputVarStatus(const Tensor &tensor) { ...@@ -58,13 +58,14 @@ static void CheckInputVarStatus(const Tensor &tensor) {
"wrong type. Expect type is DenseTensor.", "wrong type. Expect type is DenseTensor.",
tensor.name())); tensor.name()));
PADDLE_ENFORCE_EQ(tensor.initialized(), PADDLE_ENFORCE_EQ(
true, static_cast<phi::DenseTensor *>(tensor.impl().get())->IsInitialized(),
paddle::platform::errors::InvalidArgument( true,
"The tensor in input tensor %s of " paddle::platform::errors::InvalidArgument(
"RunProgram(Grad)Op " "The tensor in input tensor %s of "
"is not initialized.", "RunProgram(Grad)Op "
tensor.name())); "is not initialized.",
tensor.name()));
} }
static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
...@@ -84,7 +85,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, ...@@ -84,7 +85,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
"RunProgram(Grad)Op's internal scope holds " "RunProgram(Grad)Op's internal scope holds "
"wrong type. Expect type is DenseTensor", "wrong type. Expect type is DenseTensor",
name)); name));
PADDLE_ENFORCE_EQ(src_tensor.initialized(), PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(),
true, true,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The tensor in output tensor %s get from " "The tensor in output tensor %s get from "
...@@ -120,7 +121,7 @@ static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors, ...@@ -120,7 +121,7 @@ static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
paddle::framework::Scope *scope) { paddle::framework::Scope *scope) {
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto name = tensors[i].name(); auto name = tensors[i].name();
if (name == "Fake_var" || !tensors[i].initialized()) { if (name == "Fake_var") {
continue; continue;
} }
auto *var = scope->Var(name); auto *var = scope->Var(name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册