未验证 提交 9f2ccf5b 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13237 from tensor-tang/refine/op/peephole

refine fusion lstm/peephole and fusion gru
......@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
......
......@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
......@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
}
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
}
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
......
......@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self.M = 8
self.D = 16
self.has_initial_state = False
self.use_peepholes = False
self.is_reverse = False
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
self.use_seq = False
self.set_conf()
T = sum(self.lod[0])
......@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
}
self.attrs = {
'use_peepholes': self.use_peepholes,
'use_seq': self.use_seq,
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
......@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp):
class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
self.lod = [[3]]
self.D = 16
class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.is_reverse = True
class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
self.is_reverse = True
class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp):
class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.has_initial_state = True
class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.is_reverse = True
self.lod = [[2]]
self.D = 8
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册