未验证 提交 7f868bd1 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15737 from tensor-tang/fix/name

cherry pick: fix jitcode name
......@@ -63,7 +63,6 @@ class VActFunc : public JitCode {
public:
explicit VActFunc(size_t code_size, void* code_ptr)
: JitCode(code_size, code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
protected:
......@@ -269,7 +268,7 @@ class VActJitCode : public VActFunc {
this->genCode();
}
const char* name() const override {
std::string name() const override {
std::string base = "VActJitCode";
switch (type_) {
case operand_type::RELU:
......@@ -293,7 +292,7 @@ class VActJitCode : public VActFunc {
default:
break;
}
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -41,7 +41,7 @@ class VXXJitCode : public JitCode {
this->genCode();
}
virtual const char* name() const {
std::string name() const override {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
......@@ -62,7 +62,7 @@ class VXXJitCode : public JitCode {
}
base += (with_relu_ ? "_Relu" : "");
base += "_D" + std::to_string(num_);
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -49,7 +49,7 @@ class GRUJitCode : public VActFunc {
this->genCode();
}
const char* name() const override {
std::string name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
......@@ -81,7 +81,7 @@ class GRUJitCode : public VActFunc {
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -35,14 +35,14 @@ class HOPVJitCode : public JitCode {
this->genCode();
}
virtual const char* name() const {
std::string name() const override {
std::string base = "VXXJitCode";
if (type_ == operand_type::MAX) {
base += "_MAX";
} else {
base += "_SUM";
}
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -14,6 +14,7 @@
#pragma once
#include <string>
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -59,7 +60,7 @@ typedef enum {
} operand_type;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
std::string name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
......@@ -68,7 +69,6 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
......
......@@ -53,7 +53,7 @@ class LSTMJitCode : public VActFunc {
this->genCode();
}
const char* name() const override {
std::string name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
......@@ -85,7 +85,7 @@ class LSTMJitCode : public VActFunc {
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -36,11 +36,11 @@ class MatMulJitCode : public JitCode {
this->genCode();
}
virtual const char* name() const {
std::string name() const override {
std::string base = "MatMulJitCode";
base = base + "_M" + std::to_string(m_) + "_N" + std::to_string(n_) + "_K" +
std::to_string(k_);
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -38,7 +38,7 @@ class SeqPoolJitCode : public JitCode {
this->genCode();
}
virtual const char* name() const {
std::string name() const override {
std::string base = "SeqPoolJitCode";
if (type_ == SeqPoolType::kSum) {
base += "_Sum";
......@@ -48,7 +48,7 @@ class SeqPoolJitCode : public JitCode {
base += "_Sqrt";
}
base += ("_W" + std::to_string(w_));
return base.c_str();
return base;
}
void genCode() override;
......
......@@ -16,6 +16,7 @@
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h"
......@@ -28,7 +29,7 @@ namespace jit {
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual std::string name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename Func>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册