未验证 提交 7f868bd1 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15737 from tensor-tang/fix/name

cherry pick: fix jitcode name
...@@ -63,7 +63,6 @@ class VActFunc : public JitCode { ...@@ -63,7 +63,6 @@ class VActFunc : public JitCode {
public: public:
explicit VActFunc(size_t code_size, void* code_ptr) explicit VActFunc(size_t code_size, void* code_ptr)
: JitCode(code_size, code_ptr) {} : JitCode(code_size, code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0; virtual void genCode() = 0;
protected: protected:
...@@ -269,7 +268,7 @@ class VActJitCode : public VActFunc { ...@@ -269,7 +268,7 @@ class VActJitCode : public VActFunc {
this->genCode(); this->genCode();
} }
const char* name() const override { std::string name() const override {
std::string base = "VActJitCode"; std::string base = "VActJitCode";
switch (type_) { switch (type_) {
case operand_type::RELU: case operand_type::RELU:
...@@ -293,7 +292,7 @@ class VActJitCode : public VActFunc { ...@@ -293,7 +292,7 @@ class VActJitCode : public VActFunc {
default: default:
break; break;
} }
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -41,7 +41,7 @@ class VXXJitCode : public JitCode { ...@@ -41,7 +41,7 @@ class VXXJitCode : public JitCode {
this->genCode(); this->genCode();
} }
virtual const char* name() const { std::string name() const override {
std::string base = "VXXJitCode"; std::string base = "VXXJitCode";
if (scalar_index_ == 1) { if (scalar_index_ == 1) {
base += "_Scalar"; base += "_Scalar";
...@@ -62,7 +62,7 @@ class VXXJitCode : public JitCode { ...@@ -62,7 +62,7 @@ class VXXJitCode : public JitCode {
} }
base += (with_relu_ ? "_Relu" : ""); base += (with_relu_ ? "_Relu" : "");
base += "_D" + std::to_string(num_); base += "_D" + std::to_string(num_);
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -49,7 +49,7 @@ class GRUJitCode : public VActFunc { ...@@ -49,7 +49,7 @@ class GRUJitCode : public VActFunc {
this->genCode(); this->genCode();
} }
const char* name() const override { std::string name() const override {
std::string base = "GRUJitCode"; std::string base = "GRUJitCode";
if (id_ == 0) { if (id_ == 0) {
base += "_H1"; base += "_H1";
...@@ -81,7 +81,7 @@ class GRUJitCode : public VActFunc { ...@@ -81,7 +81,7 @@ class GRUJitCode : public VActFunc {
}; };
AddTypeStr(act_gate_); AddTypeStr(act_gate_);
AddTypeStr(act_cand_); AddTypeStr(act_cand_);
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -35,14 +35,14 @@ class HOPVJitCode : public JitCode { ...@@ -35,14 +35,14 @@ class HOPVJitCode : public JitCode {
this->genCode(); this->genCode();
} }
virtual const char* name() const { std::string name() const override {
std::string base = "VXXJitCode"; std::string base = "VXXJitCode";
if (type_ == operand_type::MAX) { if (type_ == operand_type::MAX) {
base += "_MAX"; base += "_MAX";
} else { } else {
base += "_SUM"; base += "_SUM";
} }
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <string>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -59,7 +60,7 @@ typedef enum { ...@@ -59,7 +60,7 @@ typedef enum {
} operand_type; } operand_type;
#define DECLARE_JIT_CODE(codename) \ #define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; } std::string name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator { class JitCode : public GenBase, public Xbyak::CodeGenerator {
public: public:
...@@ -68,7 +69,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { ...@@ -68,7 +69,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size), (code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {} code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0; virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); } size_t getSize() const override { return CodeGenerator::getSize(); }
......
...@@ -53,7 +53,7 @@ class LSTMJitCode : public VActFunc { ...@@ -53,7 +53,7 @@ class LSTMJitCode : public VActFunc {
this->genCode(); this->genCode();
} }
const char* name() const override { std::string name() const override {
std::string base = "LSTMJitCode"; std::string base = "LSTMJitCode";
if (use_peephole_) { if (use_peephole_) {
base += "_Peephole"; base += "_Peephole";
...@@ -85,7 +85,7 @@ class LSTMJitCode : public VActFunc { ...@@ -85,7 +85,7 @@ class LSTMJitCode : public VActFunc {
AddTypeStr(act_gate_); AddTypeStr(act_gate_);
AddTypeStr(act_cand_); AddTypeStr(act_cand_);
AddTypeStr(act_cell_); AddTypeStr(act_cell_);
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -36,11 +36,11 @@ class MatMulJitCode : public JitCode { ...@@ -36,11 +36,11 @@ class MatMulJitCode : public JitCode {
this->genCode(); this->genCode();
} }
virtual const char* name() const { std::string name() const override {
std::string base = "MatMulJitCode"; std::string base = "MatMulJitCode";
base = base + "_M" + std::to_string(m_) + "_N" + std::to_string(n_) + "_K" + base = base + "_M" + std::to_string(m_) + "_N" + std::to_string(n_) + "_K" +
std::to_string(k_); std::to_string(k_);
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -38,7 +38,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -38,7 +38,7 @@ class SeqPoolJitCode : public JitCode {
this->genCode(); this->genCode();
} }
virtual const char* name() const { std::string name() const override {
std::string base = "SeqPoolJitCode"; std::string base = "SeqPoolJitCode";
if (type_ == SeqPoolType::kSum) { if (type_ == SeqPoolType::kSum) {
base += "_Sum"; base += "_Sum";
...@@ -48,7 +48,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -48,7 +48,7 @@ class SeqPoolJitCode : public JitCode {
base += "_Sqrt"; base += "_Sqrt";
} }
base += ("_W" + std::to_string(w_)); base += ("_W" + std::to_string(w_));
return base.c_str(); return base;
} }
void genCode() override; void genCode() override;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -28,7 +29,7 @@ namespace jit { ...@@ -28,7 +29,7 @@ namespace jit {
class GenBase : public Kernel { class GenBase : public Kernel {
public: public:
virtual ~GenBase() = default; virtual ~GenBase() = default;
virtual const char* name() const = 0; virtual std::string name() const = 0;
virtual size_t getSize() const = 0; virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0; virtual const unsigned char* getCodeInternal() = 0;
template <typename Func> template <typename Func>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册