未验证 提交 9f2ccf5b 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13237 from tensor-tang/refine/op/peephole

refine fusion lstm/peephole and fusion gru
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
......
...@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"Input(WeightX) of GRU should not be null."); "Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null."); "Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null."); "Output(Hidden) of GRU should not be null.");
...@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
} }
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Hidden");
int xx_width; int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) { if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
} }
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
......
...@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest): ...@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self.M = 8 self.M = 8
self.D = 16 self.D = 16
self.has_initial_state = False self.has_initial_state = False
self.use_peepholes = False
self.is_reverse = False self.is_reverse = False
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.use_peepholes = False
self.use_seq = False
self.set_conf() self.set_conf()
T = sum(self.lod[0]) T = sum(self.lod[0])
...@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest): ...@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
} }
self.attrs = { self.attrs = {
'use_peepholes': self.use_peepholes, 'use_peepholes': self.use_peepholes,
'use_seq': self.use_seq,
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'cell_activation': self.act_cell, 'cell_activation': self.act_cell,
...@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): ...@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self.is_reverse = True self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp):
def set_conf(self): def set_conf(self):
self.use_peepholes = True self.use_peepholes = True
self.lod = [[3]]
self.D = 16
class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.is_reverse = True
class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True self.has_initial_state = True
self.is_reverse = True self.is_reverse = True
class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp): class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
def set_conf(self): def set_conf(self):
self.use_seq = True
self.use_peepholes = True self.use_peepholes = True
self.lod = [[2]]
self.D = 8
class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.has_initial_state = True
class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.is_reverse = True
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册