Commit c45cee03 authored by tensor-tang

refine infershape and forward

Parent c7c25067
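In outline, the diff below leaves InferShape to check only what is known statically: the summed widths of Inputs(X) must equal FCWeight's height, the optional FCBias must be 1 x D, Out is set to {ins_dims[0][0], D}, and the LoD of the first input is shared to Out. The LoD-dependent checks and the FCOut resize move into the kernel, since LoD is not available at infer-shape time. The sketch that follows restates the computation the refined forward pass implements, in plain C++ with no Paddle APIs; the function name seq_concat_fc_ref, the flat std::vector layout, and the toy numbers in main are illustrative assumptions, not code from this repository.

```cpp
// Minimal reference sketch of the fused "sequence concat along axis 1 + FC" semantics.
// x0:     total_T x M0, a sequence batch with lod row offsets
// others: each N x Mi, one row per sequence (seq_len == 1), broadcast over time steps
// w:      (M0 + M1 + ... + Mn) x D stacked weight, row-major
// bias:   optional 1 x D
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> seq_concat_fc_ref(
    const std::vector<float>& x0, const std::vector<size_t>& lod, int M0,
    const std::vector<std::vector<float>>& others, const std::vector<int>& Ms,
    const std::vector<float>& w, int D, const std::vector<float>* bias) {
  const int N = static_cast<int>(lod.size()) - 1;  // batch size (number of sequences)
  const int total_T = static_cast<int>(lod[N]);    // total time steps across the batch
  std::vector<float> out(static_cast<size_t>(total_T) * D, 0.f);
  for (int n = 0; n < N; ++n) {
    for (size_t t = lod[n]; t < lod[n + 1]; ++t) {
      for (int d = 0; d < D; ++d) {
        float acc = bias ? (*bias)[d] : 0.f;
        int row = 0;  // row offset into the stacked weight
        // time-level input X0: one row per time step t
        for (int k = 0; k < M0; ++k) acc += x0[t * M0 + k] * w[(row + k) * D + d];
        row += M0;
        // sequence-level inputs X1..Xn: one row per sequence, broadcast over t
        for (size_t i = 0; i < others.size(); ++i) {
          for (int k = 0; k < Ms[i]; ++k)
            acc += others[i][n * Ms[i] + k] * w[(row + k) * D + d];
          row += Ms[i];
        }
        // the op would additionally apply fc_activation elementwise; omitted here
        out[t * D + d] = acc;
      }
    }
  }
  return out;
}

int main() {
  // Two sequences of lengths 2 and 1 (total_T = 3), M0 = 2, one extra input with M1 = 1, D = 2.
  std::vector<size_t> lod = {0, 2, 3};
  std::vector<float> x0 = {1, 2, 3, 4, 5, 6};           // 3 x 2, time-level input
  std::vector<std::vector<float>> others = {{10, 20}};  // 2 x 1, one row per sequence
  std::vector<float> w = {1, 0, 0, 1, 1, 1};            // (2 + 1) x 2, row-major
  auto out = seq_concat_fc_ref(x0, lod, 2, others, {1}, w, 2, nullptr);
  // Row 0 equals concat([1, 2], [10]) * W = [1 + 10, 2 + 10] = [11, 12].
  assert(out.size() == 6 && out[0] == 11.f && out[1] == 12.f);
  return 0;
}
```

The fused kernel reaches the same result without materializing the concatenated matrix; that is what the per-block pointer arithmetic on w_data in the diff below is doing.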
@@ -23,91 +23,36 @@ namespace paddle {
 namespace operators {
 
 void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of FusionSeqConcatFC should not be null.");
+  PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
+                    "Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
   PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
                  "Input(FCWeight) of FusionSeqConcatFC should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Out"),
                  "Output(Out) of FusionSeqConcatFC should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
                  "Output(FCOut) of FusionSeqConcatFC should not be null.");
-  // need check fc height = all inputs width sum
-  auto x_dims = ctx->GetInputDim("X");
-  const int M = x_dims[1];
-  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
-
-  auto w_dims = ctx->GetInputDim("LSTMWeight");
-  const int D = w_dims[1] / 4;
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
-
-  auto b_dims = ctx->GetInputDim("LSTMBias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
-
-  auto c_dims = ctx->GetInputDim("C0");
-  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
-    auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE(h_dims == c_dims,
-                   "The dimension of Input(H0) and Input(C0) "
-                   "should be the same.");
-  }
-
-  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
-  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
-                    "Input(AttentionWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
-    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
-    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
-                      "Input(AttentionBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
-                      "AttentionBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
-                      "AttentionBias shapes must be 1 * 1.");
-  }
-
-  if (ctx->HasInput("AttentionScalar")) {
-    auto dims = ctx->GetInputDim("AttentionScalar");
-    PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalar)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
-  }
-
-  if (ctx->HasInput("AttentionScalarBias")) {
-    auto dims = ctx->GetInputDim("AttentionScalarBias");
-    PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
-        "AttentionScalar should not be null when have AttentionScalarBias.");
-    PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalarBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
-  }
-
-  framework::DDim out_dims({x_dims[0], D});
-  ctx->SetOutputDim("Hidden", out_dims);
-  ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
-  ctx->SetOutputDim("LSTMX", {1, M});
-  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
-  // AttentionFCOut should be reshape as (maxseqlen,1) in runtime
-  ctx->ShareLoD("X", "Hidden");
-  ctx->ShareLoD("X", "Cell");
-
-  ctx->SetOutputDim("Out", out_dims);
-  ctx->ShareLoD("X", /*->*/ "Out");
+  auto ins_dims = ctx->GetInputsDim("X");
+  auto w_dims = ctx->GetInputDim("FCWeight");  // (M0+M1+M2+..) x D
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
+  const int D = w_dims[1];
+  int sum = ins_dims[0][1];
+  for (size_t i = 1; i < ins_dims.size(); ++i) {
+    sum += ins_dims[i][1];
+  }
+  PADDLE_ENFORCE_EQ(sum, w_dims[0],
+                    "FC height should be sum of all inputs width.");
+  if (ctx->HasInput("FCBias")) {
+    auto b_dims = ctx->GetInputDim("FCBias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D);
+    PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D);
+  }
+
+  ctx->SetOutputDim("Out", {ins_dims[0][0], D});
+  // fcout should be reshape when run since can not get lod in infershape
+  // explicit share the ref lod
+  ctx->ShareLoD("X", "Out", 0);
 }
 
 framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
@@ -154,46 +99,46 @@ The concat axis should be 1.
 )DOC");
 }
 
-// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
-template <typename T>
-inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
-  if (bias) {
-    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
-    math::vec_relu<T, platform::jit::avx>(n, y, y);
-  } else {
-    math::vec_relu<T, platform::jit::avx>(n, x, y);
-  }
-}
-
-template <typename T>
-inline void vec_softmax(const int n, const T* x, T* y) {
-  T scalar = x[0];
-  // max
-  for (int i = 1; i < n; ++i) {
-    scalar = scalar < x[i] ? x[i] : scalar;
-  }
-  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
-  math::vec_exp<T>(n, y, y);                                    // exp
-  // sum
-  scalar = T(0);
-  for (int i = 0; i < n; ++i) {
-    scalar += y[i];
-  }
-  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
-}
-
 template <typename T>
 class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using DeviceContext = paddle::platform::CPUDeviceContext;
-    auto* ins = ctx.Input<LoDTensor>("X");
+    auto ins = ctx.MultiInput<LoDTensor>("X");
     auto* w = ctx.Input<Tensor>("FCWeight");
     auto* b = ctx.Input<Tensor>("FCBias");
     auto* out = ctx.Output<LoDTensor>("Out");
     auto* fc_out = ctx.Output<Tensor>("FCOUT");
+
+    auto* ref_in = ins[0];
+    auto ref_lod = ref_in->lod();
+    auto in1_lod = ins[1]->lod();
+    auto ref_dims = ref_in->dims();  // T x M0
+    auto in1_dims = ins[1]->dims();  // N x M1
+    auto w_dims = w->dims();
+    const int N = ref_lod[0].size() - 1;
+    const int total_T = ref_dims[0];
+    const int M0 = ref_dims[1];
+    const int M1 = in1_dims[1];
+    const int D = w_dims[1];
+
+    // some check and fcout should be reshape here
+    // since infershape can not get lod info
+    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
+    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
+    PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
+                      "Batch size of all inputs should be equal.");
+    PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
+                      "Seq_length of other inputs should be 1.");
+    PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
+    for (size_t i = 2; i < ins.size(); ++i) {
+      PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
+                        "All other inputs height should be equal");
+      PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
+                        "All other inputs should have same lod");
+    }
+    fc_out->Resize({N, D});
+
     std::function<void(const int, const T*, T*)> fc_act;
     auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
     if (platform::jit::MayIUse(platform::jit::avx)) {
@@ -204,19 +149,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
       fc_act = act_functor(fc_act_str);
     }
-    PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must larger than 1.");
-    auto* ref_in = ins[0];
-    auto ref_in_lod = ref_in->lod();
-    const int N = ref_in_lod[0].size() - 1;
-    auto ref_in_dims = ref_in->dims();  // T x M0
-    auto w_dims = w->dims();            // (M0+M1+M2+..) x D
-    const int total_T = ref_in_dims[0];
-    const int M0 = ref_in_dims[1];
-    const int M1 = ins[1]->dims()[1];
-    const int D = w_dims[1];
-    const T* ref_in_data =
-        ref_in->data<T>();  // size should be check at infershape
+    const T* ref_in_data = ref_in->data<T>();
     const T* in1_data = ins[1]->data<T>();
     const T* w_data = w->data<T>();
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
@@ -226,11 +159,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
                                       out_data, b ? b->data<T>() : NULL);
     w_data = w_data + M0 * D;
-
     // first one use write on
     blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
     w_data = w_data + M1 * D;
-    for (int i = 2; i < ins.size(); ++i) {
+    for (size_t i = 2; i < ins.size(); ++i) {
       // add on
       const T* in_data = ins[i]->data<T>();
       const int K = ins[i]->dims()[1];
@@ -240,7 +172,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
     }
 
     for (int i = 0; i < N; ++i) {
-      int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i];
+      int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
       T* src = fc_out_data + i * D;
       for (int step = 0; step < seq_len; ++step) {
         blas.VADD(D, out_data, src, out_data);
@@ -248,7 +180,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
       }
     }
 
-    fc_act(out_dims[0] * out_dims[1], out_data, out_data);
+    fc_act(total_T * D, out_data, out_data);
   }
 };
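A note on the equivalence the kernel relies on: with the FC weight stacked row-wise as W = [W0; W1; ...; Wn], where block Wi has Mi rows and D columns,

    concat([X0, X1, ..., Xn], axis=1) * W = X0 * W0 + X1 * W1 + ... + Xn * Wn.

So instead of materializing the concatenated matrix, the forward pass runs FCCompute on the time-level input against the first M0 x D weight block (adding the bias and writing straight into Out), accumulates the sequence-level inputs against their own weight blocks into the N x D FCOut buffer, broadcast-adds each sequence's FCOut row onto all of that sequence's time steps with blas.VADD, and finally applies fc_activation over the whole total_T x D output. Because every non-reference input must contribute exactly one row per sequence, the checks that depend on LoD (batch size N, seq_len 1, matching lod) now run in the kernel, where the LoD is actually available.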