未验证 提交 62a98210 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13346 from tensor-tang/refine/infershape

Refine/infershape
......@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) {
// has only one input
const auto& ins = op_.Inputs();
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
const auto& in = it->second;
if (in.size() == 0 || in[0] == kEmptyVarName) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
return scope_.FindVar(in[0]) != nullptr;
}
bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
// has only one output
const auto& outs = op_.Outputs();
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
const auto& out = it->second;
if (out.size() == 0 || out[0] == kEmptyVarName) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one outputs", name);
return scope_.FindVar(out[0]) != nullptr;
}
bool HasInputs(const std::string& name) const override {
......
......@@ -24,28 +24,28 @@ namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null.");
"Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null.");
"Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null.");
"Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null.");
"Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null.");
"Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null.");
"Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null.");
"Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null.");
"Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
"Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null.");
"Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null.");
"Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
......
......@@ -25,14 +25,14 @@ namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
"Assert only one Input(WeightX) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
"Assert only one Input(WeightH) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
"Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
......@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
"Assert only one Output(ReorderedH0) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
"Assert only one Output(BatchedInput) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
"Assert only one Output(BatchedOut) of GRU.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
}
......
......@@ -24,20 +24,17 @@ namespace paddle {
namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null.");
"Assert only one Input(WeightX) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
"Assert only one Input(WeightH) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
"Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
"Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
......@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
"Assert only one Output(BatchedInput) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
"Assert only one Output(BatchedHidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
"Assert only one Output(BatchedCell) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
"Assert only one Output(ReorderedH0) of LSTM");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
"Assert only one Output(ReorderedC0) of LSTM.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册