提交 a5556d44 编写于 作者: T tensor-tang

Refine AttentionLSTM InferShape

上级 e0436ad8
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h"
#include <string>
#include "paddle/fluid/framework/shape_runtime_infer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
......@@ -23,29 +24,60 @@ namespace paddle {
namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null.");
// Downcast the generic context to the runtime implementation. dynamic_cast
// yields nullptr when ctx is any other InferShapeContext subclass.
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
if (runtime_ctx == nullptr) {
// NOTE(review): presumably LOG(FATAL) terminates here, so the dereferences
// below are never reached with a null runtime_ctx -- confirm.
LOG(FATAL) << "Should have runtime infer context";
}
// Bind the operator's input/output maps and the scope by reference once,
// so the lookups that follow do not repeat these accessor calls.
const auto& ins = runtime_ctx->OpBase().Inputs();
const auto& outs = runtime_ctx->OpBase().Outputs();
const auto& scope = runtime_ctx->InferScope();
// End iterators captured up front; later find() results are compared
// against these.
const auto ins_end = ins.end();
const auto outs_end = outs.end();
// Reports whether the input slot `name` is usable: it must be present in
// `ins`, hold exactly one variable name that is not the empty-variable
// placeholder, and that variable must be found in `scope`.
auto fair_input = [&](const std::string& name) -> bool {
  const auto found = ins.find(name);
  if (found != ins_end) {
    const auto& var_names = found->second;
    // The size test comes first so var_names[0] is only read when it exists.
    const bool single_named =
        var_names.size() == 1 && var_names[0] != framework::kEmptyVarName;
    if (single_named) {
      return scope.FindVar(var_names[0]) != nullptr;
    }
  }
  return false;
};
// Reports whether the output slot `name` is usable: it must be present in
// `outs`, hold exactly one variable name that is not the empty-variable
// placeholder, and that variable must be found in `scope`.
auto fair_output = [&](const std::string& name) -> bool {
  const auto found = outs.find(name);
  if (found != outs_end) {
    const auto& var_names = found->second;
    // The size test comes first so var_names[0] is only read when it exists.
    const bool single_named =
        var_names.size() == 1 && var_names[0] != framework::kEmptyVarName;
    if (single_named) {
      return scope.FindVar(var_names[0]) != nullptr;
    }
  }
  return false;
};
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("C0"),
"Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("LSTMWeight"),
"Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("LSTMBias"),
"Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("AttentionWeight"),
"Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("Hidden"),
"Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("Cell"),
"Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("AttentionedX"),
"Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("AttentionFCOut"),
"Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("LSTMX"),
"Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("LSTMOUT"),
"Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
......@@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->HasInput("H0")) {
if (fair_input("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
......@@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
if (fair_input("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
......@@ -89,7 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionBias shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalar")) {
if (fair_input("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
......@@ -97,10 +129,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalarBias")) {
if (fair_input("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
fair_input("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册