提交 555083ae 编写于 作者: T tensor-tang

enforce only used

上级 90b5be85
......@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
......@@ -99,17 +89,26 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims);
}
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册