提交 b9acbcc8 编写于 作者: T tensor-tang

init lstm kernel

上级 c260bf94
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <functional>
#include <string> #include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() { ...@@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() {
return g_jit_kernels; return g_jit_kernels;
} }
template <>
LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
const std::string& act_cand_str,
const std::string& act_cell_str)
: Kernel(), d_(d) {
if (platform::jit::MayIUse(platform::jit::avx512_common)) {
math::VecActivations<float, platform::jit::avx512_common> act_functor;
act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str);
act_cand_ = act_functor(act_cand_str);
} else if (platform::jit::MayIUse(platform::jit::avx2)) {
math::VecActivations<float, platform::jit::avx2> act_functor;
act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str);
act_cand_ = act_functor(act_cand_str);
} else if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<float, platform::jit::avx> act_functor;
act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str);
act_cand_ = act_functor(act_cand_str);
} else {
math::VecActivations<float, platform::jit::isa_any> act_functor;
act_gate_ = act_functor(act_gate_str);
act_cell_ = act_functor(act_cell_str);
act_cand_ = act_functor(act_cand_str);
}
}
template <> template <>
const std::shared_ptr<LSTMKernel<float>> const std::shared_ptr<LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&, KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
const std::string&>(int d, const std::string& act_gate, const std::string&>(int d, const std::string& act_gate,
const std::string& act_cand, const std::string& act_cand,
const std::string& act_cell) { const std::string& act_cell) {
return nullptr; std::string key = "f" + std::to_string(d) + act_gate + act_cand + act_cell;
if (kers_.find(key) == kers_.end()) {
auto p =
std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p;
}
return std::dynamic_pointer_cast<LSTMKernel<float>>(kers_.at(key));
} }
} // namespace jitkernel } // namespace jitkernel
......
...@@ -14,10 +14,9 @@ limitations under the License. */ ...@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <functional> #include <functional>
#include <map>
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> #include <string>
#include <vector> #include <unordered_map>
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet. // Note: Only support on CPU yet.
...@@ -27,23 +26,43 @@ namespace math { ...@@ -27,23 +26,43 @@ namespace math {
namespace jitkernel { namespace jitkernel {
class Kernel { class Kernel {
public:
Kernel() {}
virtual ~Kernel() = default;
private:
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
class KernelPool { class KernelPool {
public: public:
static KernelPool &Instance(); static KernelPool& Instance();
template <typename Ker, typename... ARGS> template <typename Ker, typename... ARGS>
const std::shared_ptr<Ker> Get(ARGS... args); const std::shared_ptr<Ker> Get(ARGS... args);
private: private:
KernelPool() = default; KernelPool() = default;
// std::unordered_map<std::string, Kernel> kers_; std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_;
DISABLE_COPY_AND_ASSIGN(KernelPool); DISABLE_COPY_AND_ASSIGN(KernelPool);
}; };
template <typename T>
class LSTMKernel : public Kernel {
public:
explicit LSTMKernel(int d, const std::string& act_gate,
const std::string& act_cand, const std::string& act_cell);
void ComputeCtHt(T* gates, const T* ct_1, T* ct);
void ComputeCtHt_NoC0H0(T* gates, const T* ct_1, T* ct);
private:
int d_;
std::function<void(const int, const T *, T *)> act_gate_, act_cell_,
act_cand_;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -21,12 +21,7 @@ limitations under the License. */ ...@@ -21,12 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {} // namespace jitkernel
template <typename T>
class LSTMKernel : public Kernel {};
} // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册