提交 c45cee03 编写于 作者: T tensor-tang

refine infershape and forward

上级 c7c25067
......@@ -23,91 +23,36 @@ namespace paddle {
namespace operators {
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqConcatFC should not be null.");
// need check fc height = all inputs width sum
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1.");
auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
const int D = w_dims[1];
int sum = ins_dims[0][1];
for (size_t i = 1; i < ins_dims.size(); ++i) {
sum += ins_dims[i][1];
}
if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(sum, w_dims[0],
"FC height should be sum of all inputs width.");
if (ctx->HasInput("FCBias")) {
auto b_dims = ctx->GetInputDim("FCBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D);
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D);
}
if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
}
framework::DDim out_dims({x_dims[0], D});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
ctx->SetOutputDim("LSTMX", {1, M});
ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
// AttentionFCOut should be reshape as (maxseqlen,1) in runtime
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", {ins_dims[0][0], D});
// fcout should be reshape when run since can not get lod in infershape
// explicit share the ref lod
ctx->ShareLoD("X", "Out", 0);
}
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
......@@ -154,46 +99,46 @@ The concat axis should be 1.
)DOC");
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
if (bias) {
math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
math::vec_relu<T, platform::jit::avx>(n, y, y);
} else {
math::vec_relu<T, platform::jit::avx>(n, x, y);
}
}
template <typename T>
inline void vec_softmax(const int n, const T* x, T* y) {
T scalar = x[0];
// max
for (int i = 1; i < n; ++i) {
scalar = scalar < x[i] ? x[i] : scalar;
}
math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y); // sub
math::vec_exp<T>(n, y, y); // exp
// sum
scalar = T(0);
for (int i = 0; i < n; ++i) {
scalar += y[i];
}
math::vec_scal<T>(n, static_cast<T>(1) / scalar, y); // scale
}
template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* ins = ctx.Input<LoDTensor>("X");
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("FCWeight");
auto* b = ctx.Input<Tensor>("FCBias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* fc_out = ctx.Output<Tensor>("FCOUT");
auto* ref_in = ins[0];
auto ref_lod = ref_in->lod();
auto in1_lod = ins[1]->lod();
auto ref_dims = ref_in->dims(); // T x M0
auto in1_dims = ins[1]->dims(); // N x M1
auto w_dims = w->dims();
const int N = ref_lod[0].size() - 1;
const int total_T = ref_dims[0];
const int M0 = ref_dims[1];
const int M1 = in1_dims[1];
const int D = w_dims[1];
// some check and fcout should be reshape here
// since infershape can not get lod info
PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
"Batch size of all inputs should be equal.");
PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
"Seq_length of other inputs should be 1.");
PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
for (size_t i = 2; i < ins.size(); ++i) {
PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
"All other inputs height should be equal");
PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
"All other inputs should have same lod");
}
fc_out->Resize({N, D});
std::function<void(const int, const T*, T*)> fc_act;
auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
......@@ -204,19 +149,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
fc_act = act_functor(fc_act_str);
}
PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must larger than 1.");
auto* ref_in = ins[0];
auto ref_in_lod = ref_in->lod();
const int N = ref_in_lod[0].size() - 1;
auto ref_in_dims = ref_in->dims(); // T x M0
auto w_dims = w->dims(); // (M0+M1+M2+..) x D
const int total_T = ref_in_dims[0];
const int M0 = ref_in_dims[1];
const int M1 = ins[1]->dims()[1];
const int D = w_dims[1];
const T* ref_in_data =
ref_in->data<T>(); // size should be check at infershape
const T* ref_in_data = ref_in->data<T>();
const T* in1_data = ins[1]->data<T>();
const T* w_data = w->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
......@@ -226,11 +159,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
out_data, b ? b->data<T>() : NULL);
w_data = w_data + M0 * D;
// first one use write on
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
w_data = w_data + M1 * D;
for (int i = 2; i < ins.size(); ++i) {
for (size_t i = 2; i < ins.size(); ++i) {
// add on
const T* in_data = ins[i]->data<T>();
const int K = ins[i]->dims()[1];
......@@ -240,7 +172,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
}
for (int i = 0; i < N; ++i) {
int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i];
int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
T* src = fc_out_data + i * D;
for (int step = 0; step < seq_len; ++step) {
blas.VADD(D, out_data, src, out_data);
......@@ -248,7 +180,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
}
}
fc_act(out_dims[0] * out_dims[1], out_data, out_data);
fc_act(total_T * D, out_data, out_data);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册