diff --git a/paddle/fluid/operators/jit/gen/act.h b/paddle/fluid/operators/jit/gen/act.h index c35579c3adf25947c986f44ac8cac8ca972fe07a..3ebd785d509556fb0f73f902cba2c102d7364987 100644 --- a/paddle/fluid/operators/jit/gen/act.h +++ b/paddle/fluid/operators/jit/gen/act.h @@ -59,43 +59,12 @@ extern int g_tmp_mem[]; #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) -class VActJitCode : public JitCode { +class VActFunc : public JitCode { public: - explicit VActJitCode(int d, operand_type type, size_t code_size, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d), type_(type) { - if (!(type_ == operand_type::relu || type_ == operand_type::exp || - type_ == operand_type::sigmoid || type_ == operand_type::tanh || - type_ == operand_type::identity)) { - LOG(FATAL) << "Do not support this operand type: " << type_; - } - this->genCode(); - } - - const char* name() const override { - std::string base = "VActJitCode"; - switch (type_) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - return base.c_str(); - } - void genCode() override; + explicit VActFunc(size_t code_size, void* code_ptr) + : JitCode(code_size, code_ptr) {} + virtual const char* name() const = 0; + virtual void genCode() = 0; protected: // compute relu with ymm, xmm @@ -272,10 +241,49 @@ class VActJitCode : public JitCode { identity_jmm(dst, src, 15); break; default: - LOG(FATAL) << "Do not support this operand type: " << type_; + LOG(FATAL) << "Do not support this operand type: " << type; break; } } +}; + +class VActJitCode : public VActFunc { + public: + explicit VActJitCode(int d, operand_type type, size_t code_size, + void* code_ptr = nullptr) + : VActFunc(code_size, code_ptr), num_(d), type_(type) { + if (!(type_ == operand_type::relu || type_ == operand_type::exp || + type_ == operand_type::sigmoid || type_ == operand_type::tanh || + type_ == operand_type::identity)) { + LOG(FATAL) << "Do not support this operand type: " << type_; + } + this->genCode(); + } + + const char* name() const override { + std::string base = "VActJitCode"; + switch (type_) { + case operand_type::relu: + base += "_Relu"; + break; + case operand_type::exp: + base += "_Exp"; + break; + case operand_type::sigmoid: + base += "_Sigmoid"; + break; + case operand_type::tanh: + base += "_Tanh"; + break; + case operand_type::identity: + base += "_Identity"; + break; + default: + break; + } + return base.c_str(); + } + void genCode() override; protected: int num_; diff --git a/paddle/fluid/operators/jit/gen/gru.h b/paddle/fluid/operators/jit/gen/gru.h index bab1c6a4eee5900994a19c73c76f457bcf5ba7c9..1a1c9f297347a3e90ced42b10ee3faa3959c0303 100644 --- a/paddle/fluid/operators/jit/gen/gru.h +++ b/paddle/fluid/operators/jit/gen/gru.h @@ -24,13 +24,11 @@ namespace operators { namespace jit { namespace gen { -class GRUJitCode : public VActJitCode { +class GRUJitCode : public VActFunc { public: explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - id_(id) { + : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) { auto typeExchange = [](KernelType type) -> gen::operand_type { if (type == KernelType::vsigmoid) { return operand_type::sigmoid; @@ -45,7 +43,6 @@ class GRUJitCode : public VActJitCode { } return operand_type::identity; }; - num_ = attr.d; act_gate_ = typeExchange(attr.act_gate); act_cand_ = typeExchange(attr.act_cand); diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 898d7df34510e7335facaa98b49b360f3e7a0179..b2eeb9b65e413de57a53bb7474a8a8de5d0d7609 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -62,7 +62,9 @@ typedef enum { class JitCode : public GenBase, public Xbyak::CodeGenerator { public: explicit JitCode(size_t code_size, void* code_ptr = nullptr) - : Xbyak::CodeGenerator((code_size < 4096 ? 4096 : code_size), code_ptr) {} + : Xbyak::CodeGenerator( + (code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size), + code_ptr) {} virtual const char* name() const = 0; virtual void genCode() = 0; diff --git a/paddle/fluid/operators/jit/gen/lstm.h b/paddle/fluid/operators/jit/gen/lstm.h index cb8705c6d95c38e1836bd9879a8328df36748232..b2493878ceacc7e4b62517c43bae2029602476e8 100644 --- a/paddle/fluid/operators/jit/gen/lstm.h +++ b/paddle/fluid/operators/jit/gen/lstm.h @@ -24,13 +24,14 @@ namespace operators { namespace jit { namespace gen { -class LSTMJitCode : public VActJitCode { +class LSTMJitCode : public VActFunc { public: explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, size_t code_size, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - compute_c1h1_(compute_c1h1) { + : VActFunc(code_size, code_ptr), + num_(attr.d), + compute_c1h1_(compute_c1h1), + use_peephole_(attr.use_peephole) { auto typeExchange = [](KernelType type) -> gen::operand_type { if (type == KernelType::vsigmoid) { return operand_type::sigmoid; @@ -45,8 +46,6 @@ class LSTMJitCode : public VActJitCode { } return operand_type::identity; }; - num_ = attr.d; - use_peephole_ = attr.use_peephole; act_gate_ = typeExchange(attr.act_gate); act_cand_ = typeExchange(attr.act_cand); act_cell_ = typeExchange(attr.act_cell); diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 9ba0a9583131de0c0251bd6da099fba318047aa3..bdc9c1250e82718ce07e6eabfac62531d452b18b 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -80,7 +80,7 @@ struct rnn_attr_s { int d; KernelType act_gate, act_cand; rnn_attr_s() = default; - rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand) + explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand) : d(_d), act_gate(_act_gate), act_cand(_act_cand) {} }; @@ -88,8 +88,8 @@ struct lstm_attr_s : public rnn_attr_s { bool use_peephole; KernelType act_cell; lstm_attr_s() = default; - lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand, - KernelType _act_cell, bool _use_peephole = false) + explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand, + KernelType _act_cell, bool _use_peephole = false) : rnn_attr_s(_d, _act_gate, _act_cand), use_peephole(_use_peephole), act_cell(_act_cell) {} @@ -145,6 +145,8 @@ class Kernel { template class KernelImpl : public Kernel { + // TODO(TJ): rename KernelImpl to KernelMore which seems only used in more + // and add name interface for more implements easy for debug using T = typename KernelTuples::data_type; using Func = typename KernelTuples::func_type; using Attr = typename KernelTuples::attr_type;