Commit a5556d44 authored by T tensor-tang

refine attentionlstm infershape

Parent e0436ad8
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -23,29 +24,60 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Output(AttentionedX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("C0"),
+                 "Assert only one Input(C0) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMWeight"),
+                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMBias"),
+                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("AttentionWeight"),
+                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+
+  PADDLE_ENFORCE(fair_output("Hidden"),
+                 "Assert only one Output(Hidden) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("Cell"),
+                 "Assert only one Output(Cell) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionedX"),
+                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionFCOut"),
+                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMX"),
+                 "Assert only one Output(LSTMX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMOUT"),
+                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
@@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
   PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
+  if (fair_input("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
@@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
+  if (fair_input("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                       "Input(AttentionBias)'s rank must be 2.");
@@ -89,7 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                       "AttentionBias shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalar")) {
+  if (fair_input("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalar)'s rank must be 2.");
@@ -97,10 +129,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
     PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalarBias")) {
+  if (fair_input("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
     PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
+        fair_input("AttentionScalar"),
        "AttentionScalar should not be null when have AttentionScalarBias.");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalarBias)'s rank must be 2.");
......
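Note on the pattern this commit introduces: instead of the generic ctx->HasInput/HasOutput checks, the fair_input/fair_output lambdas validate each slot against the runtime scope. A slot passes only if it is declared on the op, binds exactly one variable name, that name is not framework::kEmptyVarName, and the variable actually exists in the scope when InferShape runs. Below is a minimal, self-contained C++ sketch of that predicate; Scope, VarMap, and FairVar are simplified stand-ins invented for illustration, not the real framework::Scope or operator types.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Stand-in for framework::kEmptyVarName (the "@EMPTY@" placeholder in Paddle).
static const char kEmptyVarName[] = "@EMPTY@";

// Simplified stand-in for framework::Scope: a set of live variable names.
struct Scope {
  std::set<std::string> vars;
  bool FindVar(const std::string& name) const { return vars.count(name) > 0; }
};

// An op's input (or output) map: slot name -> bound variable names.
using VarMap = std::map<std::string, std::vector<std::string>>;

// Mirrors the fair_input/fair_output logic: the slot must exist, bind exactly
// one non-empty variable name, and that variable must be present in the scope.
bool FairVar(const VarMap& vars, const Scope& scope, const std::string& name) {
  auto it = vars.find(name);
  if (it == vars.end()) return false;  // slot not declared on the op
  const auto& bound = it->second;
  if (bound.size() != 1 || bound[0] == kEmptyVarName) return false;
  return scope.FindVar(bound[0]);  // must exist before InferShape runs
}

int main() {
  Scope scope{{"x0", "c0"}};
  VarMap ins{{"X", {"x0"}}, {"C0", {"c0"}}, {"H0", {kEmptyVarName}}};
  std::cout << FairVar(ins, scope, "X") << "\n";   // 1: bound and live
  std::cout << FairVar(ins, scope, "H0") << "\n";  // 0: optional input left empty
  std::cout << FairVar(ins, scope, "W") << "\n";   // 0: slot absent entirely
  return 0;
}

This is also why the optional slots (H0, AttentionBias, AttentionScalar, AttentionScalarBias) switch from ctx->HasInput to fair_input in the diff: a slot bound only to the empty placeholder is treated as absent rather than as a present-but-null input.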