未验证 提交 62a98210 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13346 from tensor-tang/refine/infershape

Refine/infershape
...@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) { // has only one input
const auto& ins = op_.Inputs();
auto it = ins.find(name);
if (it == ins.end()) {
return false; return false;
} }
auto& ins = Inputs(name); const auto& in = it->second;
size_t length = ins.size(); if (in.size() == 0 || in[0] == kEmptyVarName) {
if (length == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name); "Input %s should not have more than one inputs", name);
auto ipt = ins[0]; return scope_.FindVar(in[0]) != nullptr;
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) { // has only one output
const auto& outs = op_.Outputs();
auto it = outs.find(name);
if (it == outs.end()) {
return false; return false;
} }
auto& outs = Outputs(name); const auto& out = it->second;
size_t length = outs.size(); if (out.size() == 0 || out[0] == kEmptyVarName) {
if (length == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one inputs", name); "Output %s should not have more than one outputs", name);
auto ipt = outs[0]; return scope_.FindVar(out[0]) != nullptr;
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
......
...@@ -24,28 +24,28 @@ namespace operators { ...@@ -24,28 +24,28 @@ namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null."); "Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null."); "Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null."); "Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null."); "Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null."); "Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null."); "Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null."); "Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null."); "Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null."); "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null."); "Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null."); "Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1]; const int M = x_dims[1];
......
...@@ -25,14 +25,14 @@ namespace paddle { ...@@ -25,14 +25,14 @@ namespace paddle {
namespace operators { namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null."); "Assert only one Input(WeightX) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null."); "Assert only one Input(WeightH) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null."); "Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
...@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null."); "Assert only one Output(ReorderedH0) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null."); "Assert only one Output(BatchedInput) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null."); "Assert only one Output(BatchedOut) of GRU.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
} }
......
...@@ -24,20 +24,17 @@ namespace paddle { ...@@ -24,20 +24,17 @@ namespace paddle {
namespace operators { namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null."); "Assert only one Input(WeightX) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null."); "Assert only one Input(WeightH) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("Bias"), PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
"Input(Bias) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null."); "Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null."); "Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
...@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null."); "Assert only one Output(BatchedInput) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null."); "Assert only one Output(BatchedHidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null."); "Assert only one Output(BatchedCell) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null."); "Assert only one Output(ReorderedH0) of LSTM");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null."); "Assert only one Output(ReorderedC0) of LSTM.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册