diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 3a34aa86b6331e4fe2813eea97cb6644323807c3..582c75872ab2818cdf834f9a46278db1d6f91d54 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -13,16 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fusion_gru_op.h" +#include // for memcpy #include -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" -#include "paddle/fluid/operators/math/detail/gru_kernel.h" +#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" -#include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -35,12 +32,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "Input(WeightH) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"), - "Output(BatchedGate) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"), - "Output(BatchResetHiddenPrev) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), - "Output(BatchedHidden) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), + "Output(ReorderedH0) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), + "Output(BatchedInput) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), + "Output(BatchedOut) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Output(Hidden) of GRU should not be null."); @@ -83,12 +80,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { } framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); - ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedHidden", out_dims); - ctx->SetOutputDim("BatchResetHiddenPrev", out_dims); + ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); + ctx->SetOutputDim("BatchedOut", out_dims); ctx->ShareLoD("X", "Hidden"); - int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + int xx_width; + if (ctx->Attrs().Get("use_seq")) { + xx_width = wx_dims[1]; + } else { + xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + } ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->ShareLoD("X", "XX"); } @@ -115,22 +116,29 @@ void FusionGRUOpMaker::Make() { "(Tensor) The FC weight with shape (M x 3D)," "where M is the dim size of x, D is the hidden size. "); AddInput("WeightH", - "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "); + "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. " + "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}" + "Acutally they are D x 2D and D x D two part weights." + "{W_update, W_reset; W_state}" + "{D x (D + D); D x D}"); AddInput("Bias", "(Tensor, optional) (1 x 3D)." "Almost same as GRUOp." "Note: if have FC bias it should be added on this bias.") .AsDispensable(); + AddOutput("ReorderedH0", "(Tensor) (N x D), which N is the min-batch size.") + .AsIntermediate(); AddOutput("XX", - "(LoDTensor) the result after X * WeightX (size is T x 4D)" + "(LoDTensor) the result after X * WeightX (size is T x 3D)" " or batched_X (size is T x M), this will be automatically chosen," " where T is the total time steps in this mini-batch," " D is the hidden size, M is the dim size of x input.") .AsIntermediate(); - AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate(); - AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.") + AddOutput("BatchedInput", + "(LoDTensor) This is the batched result of input X" + "or the batched result after fc, shape (T x 3D)") .AsIntermediate(); - AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.") + AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.") .AsIntermediate(); AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp"); AddAttr("activation", @@ -146,6 +154,10 @@ void FusionGRUOpMaker::Make() { "(bool, defalut: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("use_seq", + "(bool, defalut: True) " + "whether to use seq mode to compute GRU.") + .SetDefault(true); AddComment(R"DOC( The Fusion complete GRU Operator. This operator fuse the fully-connected operator into GRU, @@ -153,172 +165,261 @@ more details can refer to GRU op. )DOC"); } -template -inline void ReorderInitState(const DeviceContext& ctx, - const framework::Tensor& src, - framework::Vector index_lod, - framework::Tensor* dst, bool indexed_src) { - math::CopyMatrixRowsFunctor row_shuffle; - dst->mutable_data(src.dims(), ctx.GetPlace()); - row_shuffle(ctx, src, index_lod, dst, indexed_src); -} - -template +template class FusionGRUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + if (ctx.Attr("use_seq")) { + SeqCompute(ctx); + } else { + BatchCompute(ctx); + } + } + +#define INIT_VEC_FUNC \ + std::function act_gate, act_state; \ + std::function cross; \ + auto& act_gate_str = ctx.Attr("gate_activation"); \ + auto& act_state_str = ctx.Attr("activation"); \ + if (platform::jit::MayIUse(platform::jit::avx)) { \ + math::VecActivations act_functor; \ + act_gate = act_functor(act_gate_str); \ + act_state = act_functor(act_state_str); \ + cross = math::vec_cross; \ + } else { \ + math::VecActivations act_functor; \ + act_gate = act_functor(act_gate_str); \ + act_state = act_functor(act_state_str); \ + cross = math::vec_cross; \ + } + +#define INIT_BASE_INPUT_OUTPUT \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); + +#define INIT_BASE_SIZES \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_dims[0]; \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D3 = wh_dims[1]; \ + const int D2 = D * 2; + + void SeqCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); - auto* wx = ctx.Input("WeightX"); - auto* wh = ctx.Input("WeightH"); - auto* bias = ctx.Input("Bias"); - auto* h0 = ctx.Input("H0"); - - auto* xx = ctx.Output("XX"); - auto* batched_gate = ctx.Output("BatchedGate"); - auto* batch_reset_hidden_prev = - ctx.Output("BatchResetHiddenPrev"); - auto* batch_hidden = ctx.Output("BatchedHidden"); - auto* hidden_out = ctx.Output("Hidden"); - bool is_reverse = ctx.Attr("is_reverse"); + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES + INIT_VEC_FUNC + auto x_lod = x->lod(); + const int N = x_lod[0].size() - 1; + const T* x_data = x->data(); + const T* h0_data = h0 ? h0->data() : nullptr; + const T* wx_data = wx->data(); + const T* wh_data = wh->data(); + const T* wh_state_data = wh_data + D * D2; T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* batched_gate_data = batched_gate->mutable_data(ctx.GetPlace()); - batch_reset_hidden_prev->mutable_data(ctx.GetPlace()); - batch_hidden->mutable_data(ctx.GetPlace()); - hidden_out->mutable_data(ctx.GetPlace()); + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, total_T, D3, M, x_data, wx_data, + xx_data, + bias ? bias->data() : nullptr); + + int xx_offset = D3; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 3; + hidden_out_data = hidden_out_data + offset; + xx_offset = -D3; + gate_offset = -D; + } + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + }; + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; + const T* prev_hidden_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + bid * D; + } else { + // W: {W_update, W_reset; W_state} + // update gate + act_gate(D, xx_data, xx_data); + // state gate + act_state(D, xx_data + D2, xx_data + D2); + // out = a*b + blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data); + // save prev + prev_hidden_data = hidden_out_data; + tstart = 1; + move_step(); + } + for (int step = tstart; step < seq_len; ++step) { + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast(1), + prev_hidden_data, D, wh_data, D2, static_cast(1), xx_data, + D3); + act_gate(D2, xx_data, xx_data); + // rt = rt*ht_1 inplace result + blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data); + + // gemm rt * Ws + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast(1), + hidden_out_data, D, wh_state_data, D, static_cast(1), + xx_data + D2, D3); + act_state(D, xx_data + D2, xx_data + D2); + // out = zt*ht~ + (1-zt)*ht_1 + cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data); + // save prev + prev_hidden_data = hidden_out_data; + move_step(); + } + } + } + + void BatchCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + auto* x = ctx.Input("X"); + if (x->lod()[0].size() == 2) { + SeqCompute(ctx); + return; + } + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES + INIT_VEC_FUNC + + auto* reordered_h0 = ctx.Output("ReorderedH0"); + auto* batched_input = ctx.Output("BatchedInput"); + auto* batched_out = ctx.Output("BatchedOut"); const T* x_data = x->data(); const T* wx_data = wx->data(); const T* wh_data = wh->data(); - auto x_dims = x->dims(); - auto wx_dims = wx->dims(); + T* xx_data = xx->mutable_data(ctx.GetPlace()); + T* batched_input_data = batched_input->mutable_data(ctx.GetPlace()); + T* batched_out_data = batched_out->mutable_data(ctx.GetPlace()); + hidden_out->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; - if (x_dims[1] > wx_dims[1]) { - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - x_data, wx_data, xx_data, - bias ? bias->data() : NULL); - to_batch(dev_ctx, *xx, batched_gate, true, is_reverse); + if (M > D3) { + math::FCCompute(blas, total_T, D3, M, x_data, wx_data, + xx_data, + bias ? bias->data() : nullptr); + to_batch(dev_ctx, *xx, batched_input, true, is_reverse); } else { to_batch(dev_ctx, *x, xx, true, is_reverse); - batched_gate->set_lod(xx->lod()); - math::FCCompute(blas, x_dims[0], wx_dims[1], x_dims[1], - xx_data, wx_data, batched_gate_data, - bias ? bias->data() : NULL); + batched_input->set_lod(xx->lod()); + math::FCCompute(blas, total_T, D3, M, xx_data, wx_data, + batched_input_data, + bias ? bias->data() : nullptr); } - int frame_size = static_cast(wx_dims[1] / 3); - math::GRUMetaValue gru_value; - gru_value.gate_weight = const_cast(wh_data); - gru_value.state_weight = - const_cast(wh_data + 2 * frame_size * frame_size); - Tensor ordered_h0; - - framework::Vector order(batched_gate->lod()[2]); + auto batched_lod = batched_input->lod(); + const auto& seq_order = batched_lod[2]; + const int max_bs = seq_order.size(); + reordered_h0->Resize({max_bs, D}); + int tstart = 0; + T* prev_hidden_data = nullptr; if (h0) { - ReorderInitState( - ctx.template device_context(), *h0, order, &ordered_h0, - true); - gru_value.prev_out_value = ordered_h0.data(); + // reorder h0 + T* reordered_h0_data = reordered_h0->mutable_data(ctx.GetPlace()); + const T* h0_data = h0->data(); + prev_hidden_data = reordered_h0_data; + size_t sz = sizeof(T) * D; + for (int i = 0; i < max_bs; ++i) { + std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); + reordered_h0_data += D; + } } else { - gru_value.prev_out_value = nullptr; + // compute without h0 + T* cur_in_data = batched_input_data; + T* cur_out_data = batched_out_data; + // W: {W_update, W_reset; W_state} + for (int i = 0; i < max_bs; ++i) { + // update gate + act_gate(D, cur_in_data, cur_in_data); + // state gate + act_state(D, cur_in_data + D2, cur_in_data + D2); + // out = a*b + blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data); + // add offset + cur_in_data += D3; + cur_out_data += D; + } + tstart = 1; + prev_hidden_data = batched_out_data; } - auto batch_starts = batched_gate->lod()[0]; - size_t seq_len = batch_starts.size() - 1; - auto active_node = - math::detail::GetActivationType(ctx.Attr("activation")); - auto active_gate = math::detail::GetActivationType( - ctx.Attr("gate_activation")); - -#ifdef PADDLE_WITH_MKLML - // use MKL packed to speedup GEMM - if (FLAGS_paddle_num_threads >= 4) { - auto blas = math::GetBlas(dev_ctx); - T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size * 2 /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_gate); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, - frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, - packed_gate); - T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, - frame_size /*width of weight*/, - frame_size /*height of height*/); - PADDLE_ENFORCE(packed_state); - blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, - frame_size, T(1.0), gru_value.state_weight, frame_size, - packed_state); - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batched_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2, - frame_size, gru_value.prev_out_value, frame_size, packed_gate, - frame_size * 2, T(1), gru_value.gate_value, frame_size * 3); - } - - math::detail::forward_reset_output( - math::detail::forward::gru_resetOutput(), gru_value, frame_size, - cur_batch_size, active_gate); - - if (gru_value.prev_out_value) { - blas.GEMM_COMPUTE( - CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, - gru_value.reset_output_value, frame_size, packed_state, - frame_size, T(1), gru_value.gate_value + frame_size * 2, - frame_size * 3); - } - - math::detail::forward_final_output( - math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); - - gru_value.prev_out_value = gru_value.output_value; + // Then start from next + const T* wh_state_data = wh_data + D * D2; + const auto& batch_starts = batched_lod[0]; + const int max_seq_len = batch_starts.size() - 1; + batched_input_data = batched_input_data + tstart * max_bs * D3; + batched_out_data = batched_out_data + tstart * max_bs * D; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast(1), + prev_hidden_data, D, wh_data, D2, static_cast(1), + batched_input_data, D3); + + T* cur_batched_data = batched_input_data; + T* cur_out_data = batched_out_data; + T* cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + act_gate(D2, cur_batched_data, cur_batched_data); + // rt = rt*ht_1 inplace result + blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); + + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; } - blas.GEMM_FREE(packed_gate); - blas.GEMM_FREE(packed_state); - } else { -#endif - for (size_t n = 0; n < seq_len; n++) { - int bstart = static_cast(batch_starts[n]); - int bend = static_cast(batch_starts[n + 1]); - int cur_batch_size = bend - bstart; - - Tensor gate_t = batched_gate->Slice(bstart, bend); - Tensor reset_hidden_prev_t = - batch_reset_hidden_prev->Slice(bstart, bend); - Tensor hidden_t = batch_hidden->Slice(bstart, bend); - gru_value.output_value = hidden_t.data(); - gru_value.gate_value = gate_t.data(); - gru_value.reset_output_value = reset_hidden_prev_t.data(); - - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); - - gru_value.prev_out_value = gru_value.output_value; + cur_batched_data = batched_input_data; + cur_out_data = batched_out_data; + blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast(1), + cur_out_data, D, wh_state_data, D, static_cast(1), + cur_batched_data + D2, D3); + + cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + // ht~ = act_state(...) + act_state(D, cur_batched_data + D2, cur_batched_data + D2); + // out = zt*ht~ + (1-zt)*ht_1 + cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data, + cur_out_data); + + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; } -#ifdef PADDLE_WITH_MKLML + prev_hidden_data = batched_out_data; + batched_out_data = cur_out_data; + batched_input_data = cur_batched_data; } -#endif + math::Batch2LoDTensorFunctor to_seq; - batch_hidden->set_lod(batched_gate->lod()); - to_seq(dev_ctx, *batch_hidden, hidden_out); + batched_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_out, hidden_out); } +#undef INIT_VEC_FUNC +#undef INIT_BASE_SIZES +#undef INIT_BASE_INPUT_OUTPUT }; } // namespace operators @@ -327,6 +428,5 @@ class FusionGRUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL( - fusion_gru, ops::FusionGRUKernel, - ops::FusionGRUKernel); +REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel, + ops::FusionGRUKernel); diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 5693761e9ffd96b40040223b5498b63b0274bf0f..9560e3a3c15ca63892fbe3552679a22f027f11e2 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -132,6 +132,121 @@ inline void vec_scal(const int n, vec_scal(n, a, x, y); } +template +inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = a - x[i]; + } +} + +template <> +inline void vec_bias_sub(const int n, const float a, + const float* x, float* y) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { + vec_bias_sub(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_sub_ps(bias, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = a - x[i]; + } +#else + vec_bias_sub(n, a, x, y); +#endif +} + +template <> +inline void vec_bias_sub(const int n, const float a, + const float* x, float* y) { + vec_bias_sub(n, a, x, y); +} + +template <> +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_bias_sub(n, a, x, y); +} + +// out = x*y + (1-x)*z +template +inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { + for (int i = 0; i < n; ++i) { + out[i] = x[i] * y[i] + (static_cast(1) - x[i]) * z[i]; + } +} + +template <> +inline void vec_cross(const int n, const float* x, + const float* y, const float* z, + float* out) { +#ifdef __AVX__ + constexpr int block = AVX_FLOAT_BLOCK; + if (n < block) { + vec_cross(n, x, y, z, out); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(1.f); + __m256 tmpx, tmpy, tmpz; + for (i = 0; i < end; i += block) { + tmpx = _mm256_loadu_ps(x + i); + tmpy = _mm256_loadu_ps(y + i); + tmpz = _mm256_loadu_ps(z + i); + tmpy = _mm256_mul_ps(tmpx, tmpy); + tmpx = _mm256_sub_ps(bias, tmpx); + tmpz = _mm256_mul_ps(tmpx, tmpz); + tmpz = _mm256_add_ps(tmpy, tmpz); + _mm256_storeu_ps(out + i, tmpz); + } + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; + } +#else + vec_cross(n, x, y, z, out); +#endif +} + +template <> +inline void vec_cross(const int n, const float* x, + const float* y, + const float* z, float* out) { + vec_cross(n, x, y, z, out); +} + +template <> +inline void vec_cross( + const int n, const float* x, const float* y, const float* z, float* out) { + // TODO(TJ): enable me + vec_cross(n, x, y, z, out); +} + template inline void vec_add_bias(const int n, const T a, const T* x, T* y) { for (int i = 0; i < n; ++i) { diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h index 07372235a7c23832e528c3e852a4747f4244b833..a3186f82d0c0cc6c9585735ddf7e9bb4db7126cb 100644 --- a/paddle/fluid/operators/math/sequence2batch.h +++ b/paddle/fluid/operators/math/sequence2batch.h @@ -92,7 +92,7 @@ class LoDTensor2BatchFunctor { // Calculate the start position of each batch. // example: sequences = {s0, s1, s2} // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 - // num_batch = 5, + // max_seqlen = 5, // batchIndex = {b0, b1, b2, b3, b4} // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} @@ -109,7 +109,7 @@ class LoDTensor2BatchFunctor { // where 1 is the second sequence, // 0 is the first sequence, // 2 is the third sequence. - // The num_batch represents batch size after rearranging the + // The max_seqlen represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. paddle::framework::LoD batch_lods; @@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor { batch_lods.emplace_back(std::vector{0}); // batch_lods[0] is the start positions for batch LoDTensor - int num_batch = seq_info[0].length; - batch_lods[0].resize(static_cast(num_batch + 1)); + int max_seqlen = seq_info[0].length; + batch_lods[0].resize(static_cast(max_seqlen + 1)); // batch_lods[1] is the raw index in the input LoDTensor batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); // batch_lods[2] is the sort order for the input LoDTensor. @@ -128,7 +128,7 @@ class LoDTensor2BatchFunctor { size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; - for (int n = 0; n < num_batch; n++) { + for (int n = 0; n < max_seqlen; n++) { auto batch_id = static_cast(batch_starts[n]); for (size_t i = 0; i < seq_info.size(); ++i) { int seq_len = seq_info[i].length;