diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 8a540108302f77e1ca3bfe1db0013d76a22d5eb4..2b8c758a032fd7edff0d4b7e23bd8e685eb3ab15 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1) USE_JITKERNEL_GEN(kGRUHtPart1) USE_JITKERNEL_GEN(kGRUHtPart2) USE_JITKERNEL_GEN(kNCHW16CMulNC) +USE_JITKERNEL_GEN(kSeqPool) diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce6801b0307663d9b6d2fd827521fc93d7abf414 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/seqpool.h" +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void SeqPoolJitCode::genCode() { + constexpr int block = YMM_FLOAT_BLOCK; + constexpr int max_num_regs = 8; + const int num_block = w_ / block; + const int num_groups = num_block / max_num_regs; + int rest_num_regs = num_block % max_num_regs; + if (type_ == SeqPoolType::kAvg) { + float scalar = 1.f / h_; + mov(reg32_scalar, scalar); + } else if (type_ == SeqPoolType::kSqrt) { + float scalar = 1.f / std::sqrt(static_cast(h_)); + mov(reg32_scalar, scalar); + } + + // TODO(TJ): make height load from params + const int group_len = max_num_regs * block * sizeof(float); + for (int g = 0; g < num_groups; ++g) { + pool_height(g * group_len, block, max_num_regs); + } + if (rest_num_regs > 0) { + pool_height(num_groups * group_len, block, rest_num_regs); + } + + // rest part + const int rest = w_ % block; + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float); + for (int h = 0; h < h_; ++h) { + int offset = h * w_ * sizeof(float) + w_offset; + const int shift_regs = (h == 0) ? 0 : max_num_regs; + int reg_idx = 0; + if (has_block4) { + vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]); + reg_idx++; + } + rest_num_regs = reg_idx; + if (h > 0) { + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + } + } + // save right now + int offset = w_offset; + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar); + for (int i = 0; i < rest_num_regs; ++i) { + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1)); + } + } + int reg_idx = 0; + if (has_block4) { + vmovups(ptr[param2 + offset], xmm_t(reg_idx)); + offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(ptr[param2 + offset], xmm_t(reg_idx)); + offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(ptr[param2 + offset], xmm_t(reg_idx)); + } + ret(); +} + +class SeqPoolCreator : public JitCodeCreator { + public: + bool UseMe(const seq_pool_attr_t& attr) const override { + return platform::MayIUse(platform::avx); + } + size_t CodeSize(const seq_pool_attr_t& attr) const override { + // TODO(TJ): remove attr.h when enabled height + bool yes = + attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt; + return 96 /* basic */ + + ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */ + * (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) * + 8; + } + std::unique_ptr CreateJitCode( + const seq_pool_attr_t& attr) const override { + PADDLE_ENFORCE_GT(attr.w, 0); + PADDLE_ENFORCE_GT(attr.h, 0); + return make_unique(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator); diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h new file mode 100644 index 0000000000000000000000000000000000000000..eb2d19138267300a8e685c1af482891efea89d00 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class SeqPoolJitCode : public JitCode { + public: + explicit SeqPoolJitCode(const seq_pool_attr_t& attr, + size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) { + if (type_ != SeqPoolType::kSum) { + LOG(FATAL) << "Only support sum pool yet "; + } + this->genCode(); + } + + virtual const char* name() const { + std::string base = "SeqPoolJitCode"; + if (type_ == SeqPoolType::kSum) { + base += "_Sum"; + } else if (type_ == SeqPoolType::kAvg) { + base += "_Avg"; + } else if (type_ == SeqPoolType::kSqrt) { + base += "_Sqrt"; + } + base += ("_W" + std::to_string(w_)); + // TODO(TJ): make h load from params + base += ("_H" + std::to_string(h_)); + return base.c_str(); + } + void genCode() override; + + protected: + template + void pool_height(int w_offset, int block, int max_num_regs) { + for (int h = 0; h < h_; ++h) { + int offset = h * w_ * sizeof(float) + w_offset; + const int shift_regs = (h == 0) ? 0 : max_num_regs; + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + shift_regs), ptr[param1 + offset]); + offset += sizeof(float) * block; + } + if (h > 0) { + // sum anyway + for (int i = 0; i < max_num_regs; ++i) { + vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + } + } + } + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vbroadcastss(JMM(max_num_regs), reg32_scalar); + } + int offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vmulps(JMM(i), JMM(i), JMM(max_num_regs)); + } + vmovups(ptr[param2 + offset], JMM(i)); + offset += sizeof(float) * block; + } + } + + private: + int h_; + int w_; + SeqPoolType type_; + reg64_t param1{abi_param1}; + reg64_t param2{abi_param2}; + reg64_t param3{abi_param3}; + reg32_t reg32_scalar{r8d}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 6b0025a75a4a6bba53e695c344e57e9e325a64fa..db78ed8ad8293fb804623fc6af8b3d21a19f00b4 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -44,8 +44,11 @@ size_t JitCodeKey(const gru_attr_t& attr) { template <> size_t JitCodeKey(const seq_pool_attr_t& attr) { - size_t key = static_cast(attr.type); - return key + (attr.w << act_type_shift); + size_t key = attr.w; + // TODO(TJ): support height, then removed it from key + constexpr int w_shift = 30; + return (key << act_type_shift) + static_cast(attr.type) + + (static_cast(attr.h) << (act_type_shift + w_shift)); } } // namespace jit diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 98707c936dac710f0c4596a3bdc95b264eab78b1..283e2e251a45788df537352919e55c2987c56276 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -255,11 +255,11 @@ class SequencePoolFunctor { jit::seq_pool_attr_t attr; attr.w = input.numel() / input.dims()[0]; attr.type = jit::SeqPoolType::kSum; - auto seqpool = - jit::Get, platform::CPUPlace>( - attr); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { attr.h = static_cast(lod[i + 1] - lod[i]); + auto seqpool = + jit::Get, platform::CPUPlace>( + attr); seqpool(src, dst, &attr); dst += attr.w; src += attr.h * attr.w;