未验证 提交 6938e6cf 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13603 from tensor-tang/refine/peephole

refine peephole
...@@ -77,10 +77,12 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -77,10 +77,12 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
const std::string BatchedCellPreAct = const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct"); patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate"); const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>(); scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>(); scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>(); scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {}); op_desc.SetInput("C0", {});
...@@ -90,6 +92,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -90,6 +92,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("BatchedGate", {BatchedGate}); op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct}); op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput}); op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetOutput("CheckedCell", {CheckedCell});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes")); op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr // TODO(TJ): get from attr
......
...@@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); "The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ( if (ctx->Attrs().Get<bool>("use_peepholes")) {
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size, PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or" "7 * %d if enable peepholes connection",
"4 * %d if disable peepholes", frame_size);
frame_size, frame_size); ctx->SetOutputDim("CheckedCell", {2, frame_size});
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes",
frame_size);
}
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
...@@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() { ...@@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() {
AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate(); AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate(); AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate(); AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
.AsIntermediate();
AddAttr<bool>("use_peepholes", AddAttr<bool>("use_peepholes",
"(bool, defalut: True) " "(bool, defalut: True) "
"whether to enable diagonal/peephole connections.") "whether to enable diagonal/peephole connections.")
...@@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D3 = D * 3; \ const int D3 = D * 3; \
const int D4 = wh_dims[1]; const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \ #define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wc_data = bias->data<T>() + D4; \ const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
Tensor checked_cell; \ T* checked_cell_data = nullptr; \
T* checked_cell_data = nullptr; \ auto place = ctx.GetPlace(); \
auto place = ctx.GetPlace(); \ if (use_peepholes) { \
if (use_peepholes) { \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} }
/// Compute LSTM /// Compute LSTM
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册