提交 b9acbcc8 编写于 作者: T tensor-tang

init lstm kernel

上级 c260bf94
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <functional>
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
......@@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() {
return g_jit_kernels;
}
template <>
LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
                              const std::string& act_cand_str,
                              const std::string& act_cell_str)
    : Kernel(), d_(d) {
  // Probe the CPU for the widest supported SIMD level and bind the three
  // activation functors (gate / cell / candidate) to that implementation.
  // Order matters: check from the widest ISA down to the plain fallback.
  namespace jit = platform::jit;
  if (jit::MayIUse(jit::avx512_common)) {
    math::VecActivations<float, jit::avx512_common> acts;
    act_gate_ = acts(act_gate_str);
    act_cell_ = acts(act_cell_str);
    act_cand_ = acts(act_cand_str);
  } else if (jit::MayIUse(jit::avx2)) {
    math::VecActivations<float, jit::avx2> acts;
    act_gate_ = acts(act_gate_str);
    act_cell_ = acts(act_cell_str);
    act_cand_ = acts(act_cand_str);
  } else if (jit::MayIUse(jit::avx)) {
    math::VecActivations<float, jit::avx> acts;
    act_gate_ = acts(act_gate_str);
    act_cell_ = acts(act_cell_str);
    act_cand_ = acts(act_cand_str);
  } else {
    // No SIMD extension available: fall back to the generic implementation.
    math::VecActivations<float, jit::isa_any> acts;
    act_gate_ = acts(act_gate_str);
    act_cell_ = acts(act_cell_str);
    act_cand_ = acts(act_cand_str);
  }
}
template <>
const std::shared_ptr<LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
                const std::string&>(int d, const std::string& act_gate,
                                    const std::string& act_cand,
                                    const std::string& act_cell) {
  // BUGFIX: a stray `return nullptr;` at the top of this function made the
  // caching logic below unreachable, so every caller received a null kernel.
  // Cache key: 'f' marks float, followed by the dims and activation names,
  // so distinct configurations map to distinct cached kernels.
  std::string key = "f" + std::to_string(d) + act_gate + act_cand + act_cell;
  // Single lookup instead of find() + at().
  auto it = kers_.find(key);
  if (it == kers_.end()) {
    // First request for this configuration: build the kernel and memoize it
    // type-erased as a Kernel so the pool can hold heterogeneous kernels.
    auto p =
        std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
    return p;
  }
  // Cached: downcast back to the concrete kernel type.
  return std::dynamic_pointer_cast<LSTMKernel<float>>(it->second);
}
} // namespace jitkernel
......
......@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <memory> // for shared_ptr
#include <string>
#include <vector>
#include <unordered_map>
#include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet.
......@@ -27,23 +26,43 @@ namespace math {
namespace jitkernel {
// Abstract base for all JIT kernels. Instances are non-copyable and are
// handed out only through shared_ptr handles (see KernelPool::kers_), so the
// virtual destructor guarantees correct destruction through the base type.
class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;

 private:
  DISABLE_COPY_AND_ASSIGN(Kernel);
};
// Process-wide singleton cache of JIT kernels, keyed by a string encoding the
// kernel type and its construction parameters.
// BUGFIX: `Instance()` was declared twice (a leftover from a reference-style
// change) and a dead commented-out member remained; both removed.
class KernelPool {
 public:
  // Returns the single global pool instance.
  static KernelPool& Instance();

  // Returns the cached kernel built from `args`, creating and caching it on
  // first use. Specialized per kernel type in the .cc file.
  template <typename Ker, typename... ARGS>
  const std::shared_ptr<Ker> Get(ARGS... args);

 private:
  KernelPool() = default;

  // Cached kernels, type-erased to the Kernel base.
  std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_;

  DISABLE_COPY_AND_ASSIGN(KernelPool);
};
// JIT LSTM cell kernel: applies the gate/cell/candidate activations to the
// pre-computed gate values and produces the new cell state.
template <typename T>
class LSTMKernel : public Kernel {
 public:
  // d: hidden dimension; the three strings name the activations used for the
  // gates, the candidate, and the cell output respectively.
  explicit LSTMKernel(int d, const std::string& act_gate,
                      const std::string& act_cand, const std::string& act_cell);

  void ComputeCtHt(T* gates, const T* ct_1, T* ct);
  void ComputeCtHt_NoC0H0(T* gates, const T* ct_1, T* ct);

 private:
  // Vectorized activation: (length, input, output).
  using ActFn = std::function<void(const int, const T*, T*)>;

  int d_;
  ActFn act_gate_;
  ActFn act_cell_;
  ActFn act_cand_;
};
} // namespace jitkernel
} // namespace math
} // namespace operators
......
......@@ -21,12 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
// NOTE(review): empty placeholder definition — appears superseded by the full
// LSTMKernel declaration in the jit kernel header; confirm this stub is still
// needed and remove it if both files are ever included together (it would be
// a redefinition).
template <typename T>
class LSTMKernel : public Kernel {};
} // namespace jitkernel
namespace jitkernel {} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册