未验证 提交 cdf3a4c2 编写于 作者: C chengduo 提交者: GitHub

Fix concat_op InferShape (#13513)

* add ShareLoDs

* refine

* add Is EmptyVarName

* refine Sharedlod
上级 6537b175
......@@ -54,6 +54,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", in, i);
PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
......@@ -63,6 +67,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLoDLevel());
}
......
......@@ -46,6 +46,16 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]);
}
void InferShapeContext::ShareLoDs(const std::string &in,
const std::string &out) const {
PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
"The number of arguments in %s and %s is not equal.", in,
out);
for (size_t i = 0; i < in.size(); ++i) {
ShareLoD(in, out, i, i);
}
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
......
......@@ -56,6 +56,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
void ShareLoDs(const std::string &in, const std::string &out) const;
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
......
......@@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
auto in_x = "X";
auto out_x_g_n = framework::GradVarName(in_x);
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
auto &in_names = ctx->Inputs(in_x);
auto &out_names = ctx->Outputs(out_x_g_n);
PADDLE_ENFORCE_EQ(
in_names.size(), out_names.size(),
"The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
in_names.size(), out_x_g_n, out_names.size());
for (size_t i = 0; i < in_names.size(); ++i) {
if (out_names[i] != framework::kEmptyVarName) {
ctx->ShareLoD(in_x, out_x_g_n, i, i);
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册