提交 916f42bc 编写于 作者: T tensor-tang

refine fusion gru infershape

上级 a5556d44
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h"
#include <cstring> // for memcpy
#include <string>
#include "paddle/fluid/framework/shape_runtime_infer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
......@@ -25,14 +26,46 @@ namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
if (runtime_ctx == nullptr) {
LOG(FATAL) << "Should have runtime infer context";
}
const auto& ins = runtime_ctx->OpBase().Inputs();
const auto& outs = runtime_ctx->OpBase().Outputs();
const auto& scope = runtime_ctx->InferScope();
const auto ins_end = ins.end();
const auto outs_end = outs.end();
auto fair_input = [&](const std::string& name) -> bool {
auto it = ins.find(name);
if (it == ins_end) {
return false;
}
const auto& in = it->second;
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
return false;
}
return scope.FindVar(in[0]) != nullptr;
};
auto fair_output = [&](const std::string& name) -> bool {
auto it = outs.find(name);
if (it == outs_end) {
return false;
}
const auto& out = it->second;
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
return false;
}
return scope.FindVar(out[0]) != nullptr;
};
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU.");
PADDLE_ENFORCE(fair_input("WeightX"),
"Assert only one Input(WeightX) of GRU.");
PADDLE_ENFORCE(fair_input("WeightH"),
"Assert only one Input(WeightH) of GRU.");
PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU.");
PADDLE_ENFORCE(fair_output("Hidden"),
"Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
......@@ -58,12 +91,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) {
if (fair_input("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
if (fair_input("Bias")) {
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
......@@ -79,12 +112,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null.");
PADDLE_ENFORCE(fair_output("ReorderedH0"),
"Assert only one Output(ReorderedH0) of GRU.");
PADDLE_ENFORCE(fair_output("BatchedInput"),
"Assert only one Output(BatchedInput) of GRU.");
PADDLE_ENFORCE(fair_output("BatchedOut"),
"Assert only one Output(BatchedOut) of GRU.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册