diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index ecc5fbdcf945d82f022ebf7dc012ffabddd1c846..17acbde2a09e72a7ac9886e994a416bb4279d6bb 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -597,7 +597,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   }
 
   for (auto& out_name : output_names) {
-    if (ctx->HasOutputs(out_name)) {
+    if (ctx->HasOutputs(out_name, true)) {
       auto output_var = ctx->GetOutputVarPtrs(out_name);
       if (output_var.size() == 1) {
         infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
@@ -606,8 +606,18 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
         paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
         outputs.reserve(output_var.size());
         for (const auto& out : output_var) {
-          outputs.emplace_back(
-              std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+          if (ctx->IsRuntime()) {
+            if (BOOST_GET_CONST(Variable*, out)) {
+              outputs.emplace_back(
+                  std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+              continue;
+            }
+          } else if (BOOST_GET_CONST(VarDesc*, out)) {
+            outputs.emplace_back(
+                std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+            continue;
+          }
+          outputs.emplace_back(nullptr);
         }
         infer_meta_context.EmplaceBackOutputs(std::move(outputs));
       }
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc
index ccdd9dc9d50ced8d1fb0ec57b24ee878637dd5a4..089e68fe48c527971cd896373efdd4f41bafa75a 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.cc
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -93,19 +93,24 @@ bool InterpretercoreInferShapeContext::HasInputs(
   return true;
 }
 
-bool InterpretercoreInferShapeContext::HasOutputs(
-    const std::string& name) const {
+bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
+                                                  bool allow_null) const {
   const auto& outs = ctx_.outputs;
   auto it = outs.find(name);
   if (it == outs.end() || it->second.empty()) {
     return false;
   }
-  for (auto& output : it->second) {
-    if (output == nullptr) {
-      return false;
+  if (allow_null) {
+    for (auto& output : it->second) {
+      if (output != nullptr) return true;
+    }
+    return false;
+  } else {
+    for (auto& output : it->second) {
+      if (output == nullptr) return false;
     }
+    return true;
   }
-  return true;
 }
 
 AttrReader InterpretercoreInferShapeContext::Attrs() const {
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index 5704fa414bbb2b195c66a7d85e0cd587403e04fc..aab32cfa06d4042e3181d50e66edaa02cc67b17c 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -58,7 +58,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
 
   bool HasInputs(const std::string& name) const override;
 
-  bool HasOutputs(const std::string& name) const override;
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override;
 
   AttrReader Attrs() const override;
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index f31fefcfade89886acfc475ddd99e8310f677090..15b979086d1eb8ead1e38d1be681d258cb1f8182 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -39,7 +39,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
 
   bool HasInputs(const std::string &name) const override;
 
-  bool HasOutputs(const std::string &name) const override;
+  bool HasOutputs(const std::string &name,
+                  bool allow_null = false) const override;
 
   AttrReader Attrs() const override;
@@ -882,7 +883,8 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
   return true;
 }
 
-bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
+bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
+                                              bool allow_null) const {
   if (op_.Outputs().find(name) == op_.Outputs().end()) {
     return false;
   }
@@ -890,10 +892,17 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
   if (output_names.empty()) {
     return false;
   }
-  for (auto &output : output_names) {
-    if (!block_.HasVarRecursive(output)) return false;
+  if (allow_null) {
+    for (auto &output : output_names) {
+      if (block_.HasVarRecursive(output)) return true;
+    }
+    return false;
+  } else {
+    for (auto &output : output_names) {
+      if (!block_.HasVarRecursive(output)) return false;
+    }
+    return true;
   }
-  return true;
 }
 
 AttrReader CompileTimeInferShapeContext::Attrs() const {
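Note for reviewers: every `HasOutputs` override in this patch follows the same contract. With `allow_null = false` (the default) the name is reported present only if *all* output slots bound to it are non-null; with `allow_null = true` it is enough that *at least one* slot is non-null. A minimal standalone sketch of that contract (the `Var` struct and the map layout are illustrative stand-ins, not Paddle's real types):

```cpp
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

struct Var {};  // stand-in for the per-context variable type

bool HasOutputs(const std::unordered_map<std::string, std::vector<Var*>>& outs,
                const std::string& name, bool allow_null = false) {
  auto it = outs.find(name);
  if (it == outs.end() || it->second.empty()) return false;
  if (allow_null) {
    // Relaxed check: at least one output slot must be bound.
    return std::any_of(it->second.begin(), it->second.end(),
                       [](const Var* v) { return v != nullptr; });
  }
  // Strict check: every output slot must be bound.
  return std::all_of(it->second.begin(), it->second.end(),
                     [](const Var* v) { return v != nullptr; });
}
```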
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index e6577f662ae7b22fb7078ab5aa697c8a3da5feb2..d9704d70b45ec60258d50040afdb99dd0378913e 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -718,18 +718,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return true;
   }
 
-  bool HasOutputs(const std::string& name) const override {
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override {
     const auto& outs = ctx_.outputs;
     auto it = outs.find(name);
     if (it == outs.end() || it->second.empty()) {
       return false;
     }
-    for (auto& output : it->second) {
-      if (output == nullptr) {
-        return false;
+    if (allow_null) {
+      for (auto& output : it->second) {
+        if (output != nullptr) return true;
+      }
+      return false;
+    } else {
+      for (auto& output : it->second) {
+        if (output == nullptr) return false;
       }
+      return true;
     }
-    return true;
   }
 
   AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 31e3929362a04114f36c9ff33a0c3e29c7706606..6ba60590cf8f370b3e983ea9c925ec8387eb2fae 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -69,7 +69,8 @@ class InferShapeContext {
                                   const std::string &name) const = 0;
 
   virtual bool HasInputs(const std::string &name) const = 0;
-  virtual bool HasOutputs(const std::string &name) const = 0;
+  virtual bool HasOutputs(const std::string &name,
+                          bool allow_null = false) const = 0;
 
   virtual DDim GetInputDim(const std::string &name) const = 0;
   virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h
index f871e77fdf6e2966de843e8c29274c35021efe81..1e5b112ece21f606b995d922a30b520096d0907b 100644
--- a/paddle/fluid/imperative/infer_shape_context.h
+++ b/paddle/fluid/imperative/infer_shape_context.h
@@ -95,17 +95,27 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return true;
  }
 
-  bool HasOutputs(const std::string& name) const override {
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override {
    auto it = var_map_out_->find(name);
    if (it == var_map_out_->end() || it->second.empty()) {
      return false;
    }
-    for (auto& output : it->second) {
-      if (output == nullptr) {
-        return false;
+    if (allow_null) {
+      for (auto& output : it->second) {
+        if (output != nullptr) {
+          return true;
+        }
       }
+      return false;
+    } else {
+      for (auto& output : it->second) {
+        if (output == nullptr) {
+          return false;
+        }
+      }
+      return true;
     }
-    return true;
   }
 
   framework::AttrReader Attrs() const override {
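One C++ subtlety worth flagging: the `allow_null = false` default is written both on the pure-virtual declaration in `shape_inference.h` and on each override, and those defaults must stay in sync, because default arguments are bound against the *static* type at the call site, not the dynamic one. A self-contained sketch of the pitfall (toy `Base`/`Derived` names, unrelated to Paddle):

```cpp
#include <iostream>

struct Base {
  virtual ~Base() = default;
  virtual bool HasOutputs(bool allow_null = false) const = 0;
};

struct Derived : Base {
  // If this default ever diverged from Base's, calls through Base*
  // and calls on Derived directly would silently disagree.
  bool HasOutputs(bool allow_null = false) const override {
    return allow_null;
  }
};

int main() {
  Derived d;
  const Base* b = &d;
  std::cout << b->HasOutputs() << "\n";  // default taken from Base
  std::cout << d.HasOutputs() << "\n";   // default taken from Derived
  return 0;
}
```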
diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc
index a9fa78c4e49430dafc56b90c9c6873a72d5d3d91..6fc80ca379f3f9ee93fd9d327f7b97deadc1152f 100644
--- a/paddle/fluid/operators/stack_op.cc
+++ b/paddle/fluid/operators/stack_op.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/multiary.h"
 
 namespace plat = paddle::platform;
@@ -68,44 +69,6 @@ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inp
 class StackOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput(framework::GradVarName("Y")), true,
-        platform::errors::InvalidArgument("Input(Y@Grad) not exist."));
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = dy_dim.size();
-    PADDLE_ENFORCE_GE(
-        axis, -rank,
-        platform::errors::InvalidArgument(
-            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
-            "but received axis is:%d.",
-            rank, axis));
-    PADDLE_ENFORCE_LT(
-        axis, rank,
-        platform::errors::InvalidArgument(
-            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
-            "but received axis is:%d.",
-            rank, axis));
-
-    if (axis < 0) axis += rank;
-    PADDLE_ENFORCE_EQ(
-        ctx->Outputs(framework::GradVarName("X")).size(),
-        static_cast<size_t>(dy_dim[axis]),
-        platform::errors::InvalidArgument(
-            "Number of Outputs(X@Grad) is equal to dy dim at axis, but"
-            " received outputs size is:%d, dy dims is:%d.",
-            ctx->Outputs(framework::GradVarName("X")).size(),
-            static_cast<int>(dy_dim[axis])));
-
-    auto vec = phi::vectorize<int>(dy_dim);
-    vec.erase(vec.begin() + axis);
-    ctx->SetOutputsDim(
-        framework::GradVarName("X"),
-        std::vector<framework::DDim>(dy_dim[axis], phi::make_ddim(vec)));
-  }
 };
 
 template <typename T>
 class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
@@ -127,8 +90,10 @@
 
 DECLARE_INFER_SHAPE_FUNCTOR(stack, StackInferMetaFunctor,
                             PD_INFER_META(phi::StackInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(stack_grad, StackGradInferMetaFunctor,
+                            PD_INFER_META(phi::StackGradInferMeta));
 REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                   ops::StackGradOpMaker<paddle::framework::OpDesc>,
                   ops::StackGradOpMaker<paddle::imperative::OpBase>,
                   StackInferMetaFunctor);
-REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
+REGISTER_OPERATOR(stack_grad, ops::StackOpGrad, StackGradInferMetaFunctor);
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index efbf02e3314333f1e12a1b65856309822a3d2465..84db67978fc23ca0b8e49b5cab0fa8207393a0f0 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -541,8 +541,10 @@ void StackGradInferMeta(const MetaTensor& out_grad,
   vec.erase(vec.begin() + axis);
 
   for (size_t i = 0; i < x_grad.size(); ++i) {
-    x_grad[i]->set_dims(phi::make_ddim(vec));
-    x_grad[i]->set_dtype(out_grad.dtype());
+    if (x_grad[i]) {
+      x_grad[i]->set_dims(phi::make_ddim(vec));
+      x_grad[i]->set_dtype(out_grad.dtype());
+    }
   }
 }
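Why the guard in `StackGradInferMeta` is needed: with `HasOutputs(out_name, true)`, `BuildInferMetaContext` now forwards unbound gradient slots as null `MetaTensor`s (for example when some inputs of `stack` have `stop_gradient = true`), so the infermeta must skip them rather than dereference. A minimal standalone sketch of that pattern (`MetaTensorLike` is an illustrative stand-in, not `phi::MetaTensor`):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct MetaTensorLike {  // stand-in for phi::MetaTensor
  std::vector<int64_t> dims;
  void set_dims(const std::vector<int64_t>& d) { dims = d; }
};

// Mirrors the guarded loop at the end of StackGradInferMeta: slots that
// were never bound arrive as nullptr and are skipped.
void SetGradDims(const std::vector<MetaTensorLike*>& x_grad,
                 const std::vector<int64_t>& grad_dims) {
  for (size_t i = 0; i < x_grad.size(); ++i) {
    if (x_grad[i]) {
      x_grad[i]->set_dims(grad_dims);
    }
  }
}

int main() {
  MetaTensorLike a, c;
  // Middle slot unbound, e.g. that input had stop_gradient = true.
  std::vector<MetaTensorLike*> x_grad = {&a, nullptr, &c};
  SetGradDims(x_grad, {2, 3});
  std::printf("a rank=%zu, c rank=%zu\n", a.dims.size(), c.dims.size());
  return 0;
}
```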