未验证 提交 b2390438 编写于 作者: Z zyfncg 提交者: GitHub

Fix problem of infermeta with vector output (#41646)

* remove stack_grad infershape

* fix bug of output with null

* fix bug
上级 5f2c5b9e
......@@ -597,7 +597,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
for (auto& out_name : output_names) {
if (ctx->HasOutputs(out_name)) {
if (ctx->HasOutputs(out_name, true)) {
auto output_var = ctx->GetOutputVarPtrs(out_name);
if (output_var.size() == 1) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
......@@ -606,8 +606,18 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
outputs.reserve(output_var.size());
for (const auto& out : output_var) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
if (ctx->IsRuntime()) {
if (BOOST_GET_CONST(Variable*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
continue;
}
} else if (BOOST_GET_CONST(VarDesc*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
continue;
}
outputs.emplace_back(nullptr);
}
infer_meta_context.EmplaceBackOutputs(std::move(outputs));
}
......
......@@ -93,19 +93,24 @@ bool InterpretercoreInferShapeContext::HasInputs(
return true;
}
bool InterpretercoreInferShapeContext::HasOutputs(
const std::string& name) const {
bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
bool allow_null) const {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) return true;
}
return false;
} else {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
return true;
}
return true;
}
AttrReader InterpretercoreInferShapeContext::Attrs() const {
......
......@@ -58,7 +58,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name) const override;
bool HasOutputs(const std::string& name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
......
......@@ -39,7 +39,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasInputs(const std::string &name) const override;
bool HasOutputs(const std::string &name) const override;
bool HasOutputs(const std::string &name,
bool allow_null = false) const override;
AttrReader Attrs() const override;
......@@ -882,7 +883,8 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
return true;
}
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
bool allow_null) const {
if (op_.Outputs().find(name) == op_.Outputs().end()) {
return false;
}
......@@ -890,10 +892,17 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
if (output_names.empty()) {
return false;
}
for (auto &output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
if (allow_null) {
for (auto &output : output_names) {
if (block_.HasVarRecursive(output)) return true;
}
return false;
} else {
for (auto &output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}
return true;
}
AttrReader CompileTimeInferShapeContext::Attrs() const {
......
......@@ -718,18 +718,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
return true;
}
bool HasOutputs(const std::string& name) const override {
bool HasOutputs(const std::string& name,
bool allow_null = false) const override {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) return true;
}
return false;
} else {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
return true;
}
return true;
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
......
......@@ -69,7 +69,8 @@ class InferShapeContext {
const std::string &name) const = 0;
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name,
bool allow_null = false) const = 0;
virtual DDim GetInputDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
......
......@@ -95,17 +95,27 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return true;
}
bool HasOutputs(const std::string& name) const override {
bool HasOutputs(const std::string& name,
bool allow_null = false) const override {
auto it = var_map_out_->find(name);
if (it == var_map_out_->end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) {
return true;
}
}
return false;
} else {
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
return true;
}
return true;
}
framework::AttrReader Attrs() const override {
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace plat = paddle::platform;
......@@ -68,44 +69,6 @@ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inp
class StackOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Y")), true,
platform::errors::InvalidArgument("Input(Y@Grad) not exist."));
int axis = ctx->Attrs().Get<int>("axis");
auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = dy_dim.size();
PADDLE_ENFORCE_GE(
axis, -rank,
platform::errors::InvalidArgument(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d.",
rank, axis));
PADDLE_ENFORCE_LT(
axis, rank,
platform::errors::InvalidArgument(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d.",
rank, axis));
if (axis < 0) axis += rank;
PADDLE_ENFORCE_EQ(
ctx->Outputs(framework::GradVarName("X")).size(),
static_cast<size_t>(dy_dim[axis]),
platform::errors::InvalidArgument(
"Number of Outputs(X@Grad) is equal to dy dim at axis, but"
" received outputs size is:%d, dy dims is:%d.",
ctx->Outputs(framework::GradVarName("X")).size(),
static_cast<size_t>(dy_dim[axis])));
auto vec = phi::vectorize<int>(dy_dim);
vec.erase(vec.begin() + axis);
ctx->SetOutputsDim(
framework::GradVarName("X"),
std::vector<framework::DDim>(dy_dim[axis], phi::make_ddim(vec)));
}
};
template <typename T>
......@@ -127,8 +90,10 @@ class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INFER_SHAPE_FUNCTOR(stack, StackInferMetaFunctor,
PD_INFER_META(phi::StackInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(stack_grad, StackGradInferMetaFunctor,
PD_INFER_META(phi::StackGradInferMeta));
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops::StackGradOpMaker<paddle::framework::OpDesc>,
ops::StackGradOpMaker<paddle::imperative::OpBase>,
StackInferMetaFunctor);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad, StackGradInferMetaFunctor);
......@@ -541,8 +541,10 @@ void StackGradInferMeta(const MetaTensor& out_grad,
vec.erase(vec.begin() + axis);
for (size_t i = 0; i < x_grad.size(); ++i) {
x_grad[i]->set_dims(phi::make_ddim(vec));
x_grad[i]->set_dtype(out_grad.dtype());
if (x_grad[i]) {
x_grad[i]->set_dims(phi::make_ddim(vec));
x_grad[i]->set_dtype(out_grad.dtype());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册