提交 209e9c3d 编写于 作者: T tensor-tang

refine peephole

test=develop
上级 01fda934
......@@ -77,10 +77,12 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
......@@ -90,6 +92,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetOutput("CheckedCell", {CheckedCell});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
......
......@@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes",
frame_size, frame_size);
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
ctx->SetOutputDim("CheckedCell", {2, frame_size});
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes",
frame_size);
}
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
......@@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() {
AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
.AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
......@@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D3 = D * 3; \
const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/ \
Tensor checked_cell; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
#define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
}
/// Compute LSTM
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册