Unverified commit b2390438, authored by zyfncg, committed by GitHub

Fix problem of infermeta with vector output (#41646)

* remove stack_grad infershape

* fix bug of output with null

* fix bug
Parent 5f2c5b9e
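In short: when a vector output such as X@GRAD has slots whose variables do not exist, the infermeta path used to either reject the whole output or dereference a null slot. The fix threads an allow_null flag through HasOutputs and makes the meta-tensor collection and phi::StackGradInferMeta null-tolerant. A minimal standalone sketch of the intended behavior (plain C++; MetaTensor and the function name here are stand-ins, not Paddle's actual classes):

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for phi::MetaTensor; all names in this sketch are hypothetical.
struct MetaTensor {
  std::vector<int64_t> dims;
};

// Null-tolerant meta inference for a vector output: dead slots stay
// nullptr and are skipped instead of being dereferenced.
void StackGradInferMetaSketch(const std::vector<int64_t>& out_grad_dims,
                              int axis, std::vector<MetaTensor*>* x_grad) {
  std::vector<int64_t> vec = out_grad_dims;
  vec.erase(vec.begin() + axis);
  for (MetaTensor* t : *x_grad) {
    if (t) t->dims = vec;  // the pre-fix code wrote unconditionally
  }
}

int main() {
  MetaTensor a, b;
  std::vector<MetaTensor*> x_grad{&a, nullptr, &b};  // middle slot absent
  StackGradInferMetaSketch({3, 2, 4}, /*axis=*/0, &x_grad);
  std::printf("a: rank %zu, b: rank %zu\n", a.dims.size(), b.dims.size());
  return 0;
}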
@@ -597,7 +597,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   }
 
   for (auto& out_name : output_names) {
-    if (ctx->HasOutputs(out_name)) {
+    if (ctx->HasOutputs(out_name, true)) {
       auto output_var = ctx->GetOutputVarPtrs(out_name);
       if (output_var.size() == 1) {
         infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
@@ -606,8 +606,18 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
         paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
         outputs.reserve(output_var.size());
         for (const auto& out : output_var) {
-          outputs.emplace_back(
-              std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+          if (ctx->IsRuntime()) {
+            if (BOOST_GET_CONST(Variable*, out)) {
+              outputs.emplace_back(
+                  std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+              continue;
+            }
+          } else if (BOOST_GET_CONST(VarDesc*, out)) {
+            outputs.emplace_back(
+                std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+            continue;
+          }
+          outputs.emplace_back(nullptr);
         }
         infer_meta_context.EmplaceBackOutputs(std::move(outputs));
       }
...
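Each entry of output_var holds a Variable* at runtime or a VarDesc* at compile time; the new branch tests whichever pointer is active and keeps a nullptr placeholder for dead slots, so positions in a vector output stay aligned with the declared slots. A rough std::variant analogue (hypothetical stand-in types, not the actual BOOST_GET_CONST machinery):

#include <memory>
#include <variant>
#include <vector>

struct Variable {};          // runtime variable (stand-in)
struct VarDesc {};           // compile-time descriptor (stand-in)
struct CompatMetaTensor {};  // stand-in for the adapter above

using VarPtr = std::variant<Variable*, VarDesc*>;

// Collect meta tensors, preserving nullptr for slots whose underlying
// variable is absent; assumes is_runtime matches the active alternative.
std::vector<std::shared_ptr<CompatMetaTensor>> CollectOutputs(
    const std::vector<VarPtr>& output_var, bool is_runtime) {
  std::vector<std::shared_ptr<CompatMetaTensor>> outputs;
  outputs.reserve(output_var.size());
  for (const auto& out : output_var) {
    const bool live = is_runtime ? std::get<Variable*>(out) != nullptr
                                 : std::get<VarDesc*>(out) != nullptr;
    outputs.push_back(live ? std::make_shared<CompatMetaTensor>() : nullptr);
  }
  return outputs;
}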
@@ -93,19 +93,24 @@ bool InterpretercoreInferShapeContext::HasInputs(
   return true;
 }
 
-bool InterpretercoreInferShapeContext::HasOutputs(
-    const std::string& name) const {
+bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
+                                                  bool allow_null) const {
   const auto& outs = ctx_.outputs;
   auto it = outs.find(name);
   if (it == outs.end() || it->second.empty()) {
     return false;
   }
-  for (auto& output : it->second) {
-    if (output == nullptr) {
-      return false;
-    }
-  }
-  return true;
+  if (allow_null) {
+    for (auto& output : it->second) {
+      if (output != nullptr) return true;
+    }
+    return false;
+  } else {
+    for (auto& output : it->second) {
+      if (output == nullptr) return false;
+    }
+    return true;
+  }
 }
 
 AttrReader InterpretercoreInferShapeContext::Attrs() const {
...
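This any-versus-all split is the heart of the change and recurs verbatim in every InferShapeContext below: allow_null = false keeps the old contract (every slot must be live), while allow_null = true accepts the output as long as at least one slot is live. Distilled into a hypothetical free function over raw pointers:

#include <algorithm>
#include <vector>

bool HasOutputsSketch(const std::vector<const void*>& outs,
                      bool allow_null = false) {
  if (outs.empty()) return false;
  auto live = [](const void* p) { return p != nullptr; };
  return allow_null ? std::any_of(outs.begin(), outs.end(), live)   // >= 1 live
                    : std::all_of(outs.begin(), outs.end(), live);  // all live
}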
@@ -58,7 +58,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
   bool HasInputs(const std::string& name) const override;
 
-  bool HasOutputs(const std::string& name) const override;
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override;
 
   AttrReader Attrs() const override;
...
@@ -39,7 +39,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   bool HasInputs(const std::string &name) const override;
 
-  bool HasOutputs(const std::string &name) const override;
+  bool HasOutputs(const std::string &name,
+                  bool allow_null = false) const override;
 
   AttrReader Attrs() const override;
@@ -882,7 +883,8 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
   return true;
 }
 
-bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
+bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
+                                              bool allow_null) const {
   if (op_.Outputs().find(name) == op_.Outputs().end()) {
     return false;
   }
@@ -890,10 +892,17 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
   if (output_names.empty()) {
     return false;
   }
-  for (auto &output : output_names) {
-    if (!block_.HasVarRecursive(output)) return false;
-  }
-  return true;
+  if (allow_null) {
+    for (auto &output : output_names) {
+      if (block_.HasVarRecursive(output)) return true;
+    }
+    return false;
+  } else {
+    for (auto &output : output_names) {
+      if (!block_.HasVarRecursive(output)) return false;
+    }
+    return true;
+  }
 }
 
 AttrReader CompileTimeInferShapeContext::Attrs() const {
...
@@ -718,18 +718,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return true;
   }
 
-  bool HasOutputs(const std::string& name) const override {
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override {
     const auto& outs = ctx_.outputs;
     auto it = outs.find(name);
    if (it == outs.end() || it->second.empty()) {
       return false;
     }
-    for (auto& output : it->second) {
-      if (output == nullptr) {
-        return false;
-      }
-    }
-    return true;
+    if (allow_null) {
+      for (auto& output : it->second) {
+        if (output != nullptr) return true;
+      }
+      return false;
+    } else {
+      for (auto& output : it->second) {
+        if (output == nullptr) return false;
+      }
+      return true;
+    }
   }
 
   AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
...
@@ -69,7 +69,8 @@ class InferShapeContext {
                          const std::string &name) const = 0;
 
   virtual bool HasInputs(const std::string &name) const = 0;
-  virtual bool HasOutputs(const std::string &name) const = 0;
+  virtual bool HasOutputs(const std::string &name,
+                          bool allow_null = false) const = 0;
 
   virtual DDim GetInputDim(const std::string &name) const = 0;
   virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
...
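One C++ subtlety worth noting about this signature: default arguments are bound to the static type at the call site, not dispatched virtually, which is why every override in this commit repeats allow_null = false exactly. A mismatched default would silently win when calling through the base class, as this deliberately broken illustration shows:

#include <cstdio>

struct Base {
  virtual bool Has(bool allow_null = false) const { return allow_null; }
  virtual ~Base() = default;
};

struct Derived : Base {
  // Deliberately mismatched default, for illustration only.
  bool Has(bool allow_null = true) const override { return allow_null; }
};

int main() {
  Derived d;
  Base* b = &d;
  // Dispatches to Derived::Has, but uses Base's default: prints 0.
  std::printf("%d\n", static_cast<int>(b->Has()));
  return 0;
}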
@@ -95,17 +95,27 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return true;
   }
 
-  bool HasOutputs(const std::string& name) const override {
+  bool HasOutputs(const std::string& name,
+                  bool allow_null = false) const override {
     auto it = var_map_out_->find(name);
     if (it == var_map_out_->end() || it->second.empty()) {
       return false;
     }
-    for (auto& output : it->second) {
-      if (output == nullptr) {
-        return false;
-      }
-    }
-    return true;
+    if (allow_null) {
+      for (auto& output : it->second) {
+        if (output != nullptr) {
+          return true;
+        }
+      }
+      return false;
+    } else {
+      for (auto& output : it->second) {
+        if (output == nullptr) {
+          return false;
+        }
+      }
+      return true;
+    }
   }
 
   framework::AttrReader Attrs() const override {
...
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/multiary.h"
 
 namespace plat = paddle::platform;
@@ -68,44 +69,6 @@ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inp
 class StackOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput(framework::GradVarName("Y")), true,
-        platform::errors::InvalidArgument("Input(Y@Grad) not exist."));
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = dy_dim.size();
-    PADDLE_ENFORCE_GE(
-        axis, -rank,
-        platform::errors::InvalidArgument(
-            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
-            "but received axis is:%d.",
-            rank, axis));
-    PADDLE_ENFORCE_LT(
-        axis, rank,
-        platform::errors::InvalidArgument(
-            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
-            "but received axis is:%d.",
-            rank, axis));
-
-    if (axis < 0) axis += rank;
-    PADDLE_ENFORCE_EQ(
-        ctx->Outputs(framework::GradVarName("X")).size(),
-        static_cast<size_t>(dy_dim[axis]),
-        platform::errors::InvalidArgument(
-            "Number of Outputs(X@Grad) is equal to dy dim at axis, but"
-            " received outputs size is:%d, dy dims is:%d.",
-            ctx->Outputs(framework::GradVarName("X")).size(),
-            static_cast<size_t>(dy_dim[axis])));
-
-    auto vec = phi::vectorize<int>(dy_dim);
-    vec.erase(vec.begin() + axis);
-    ctx->SetOutputsDim(
-        framework::GradVarName("X"),
-        std::vector<framework::DDim>(dy_dim[axis], phi::make_ddim(vec)));
-  }
 };
 
 template <typename T>
@@ -127,8 +90,10 @@ class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
 
 DECLARE_INFER_SHAPE_FUNCTOR(stack, StackInferMetaFunctor,
                             PD_INFER_META(phi::StackInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(stack_grad, StackGradInferMetaFunctor,
+                            PD_INFER_META(phi::StackGradInferMeta));
 
 REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                   ops::StackGradOpMaker<paddle::framework::OpDesc>,
                   ops::StackGradOpMaker<paddle::imperative::OpBase>,
                   StackInferMetaFunctor);
-REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
+REGISTER_OPERATOR(stack_grad, ops::StackOpGrad, StackGradInferMetaFunctor);
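For a concrete feel of what moved into phi::StackGradInferMeta (shapes below are made-up examples): stacking three [2, 4] tensors along axis 0 gives Y of shape [3, 2, 4], so stack_grad must split dY of shape [3, 2, 4] into dY.dims[0] = 3 grads of shape [2, 4] — the same arithmetic the deleted InferShape performed. A hypothetical helper that mirrors it:

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the deleted StackOpGrad::InferShape arithmetic: each grad shape
// is dy's shape with `axis` erased, replicated dy_dim[axis] times.
std::vector<std::vector<int64_t>> StackGradShapes(
    std::vector<int64_t> dy_dim, int axis) {
  const int rank = static_cast<int>(dy_dim.size());
  if (axis < 0) axis += rank;  // normalize negative axis, as the op did
  const int64_t n = dy_dim[axis];
  dy_dim.erase(dy_dim.begin() + axis);
  return std::vector<std::vector<int64_t>>(n, dy_dim);
}

int main() {
  auto shapes = StackGradShapes({3, 2, 4}, 0);
  std::printf("%zu grads, each of rank %zu\n", shapes.size(),
              shapes[0].size());
  return 0;
}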
@@ -541,8 +541,10 @@ void StackGradInferMeta(const MetaTensor& out_grad,
   vec.erase(vec.begin() + axis);
 
   for (size_t i = 0; i < x_grad.size(); ++i) {
-    x_grad[i]->set_dims(phi::make_ddim(vec));
-    x_grad[i]->set_dtype(out_grad.dtype());
+    if (x_grad[i]) {
+      x_grad[i]->set_dims(phi::make_ddim(vec));
+      x_grad[i]->set_dtype(out_grad.dtype());
+    }
   }
 }
...