未验证 提交 8198cad7 编写于 作者: Y YuanRisheng 提交者: GitHub

remove KernelName (#38082)

上级 4c1e27cc
......@@ -1275,7 +1275,7 @@ void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->name);
auto pt_kernel_name = pt_kernel_signature_->name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
......
......@@ -165,7 +165,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
VLOG(6) << framework::KernelSignatureToString(pt_kernel_signature);
auto pt_kernel_name = pten::KernelName(pt_kernel_signature.name);
auto pt_kernel_name = pt_kernel_signature.name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
auto pt_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
// TODO(chenweihang): split KernelName, Key, Kernel, Factory into diff files
// TODO(chenweihang): split Key, Kernel, Factory into diff files
#include "paddle/pten/core/kernel_factory.h"
// See Note [ Why still include the fluid headers? ]
......
......@@ -37,7 +37,7 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory;
}
Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
......@@ -51,7 +51,7 @@ Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const KernelName& kernel_name, const KernelKey& kernel_key) const {
const std::string& kernel_name, const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(iter,
kernels_.end(),
......@@ -78,7 +78,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const KernelName& kernel_name,
const std::string& kernel_name,
Backend backend,
DataLayout layout,
DataType dtype) const {
......
......@@ -51,61 +51,6 @@ class KernelContext;
using KernelFn = void (*)(KernelContext* ctx);
class KernelName final {
public:
KernelName(std::string name, std::string overload_name)
: name_(std::move(name)), overload_name_(std::move(overload_name)) {}
KernelName(const std::string& kernel_name) {
ParseNameAndOverloadNameFromString(kernel_name);
}
KernelName(const char* kernel_name) {
std::string kernel_name_str(kernel_name);
ParseNameAndOverloadNameFromString(kernel_name_str);
}
const std::string& name() const { return name_; }
const std::string& overload_name() const { return overload_name_; }
struct Hash {
size_t operator()(const KernelName& kernel_name) const {
return std::hash<std::string>()(kernel_name.name()) ^
(std::hash<std::string>()(kernel_name.overload_name()) << 1);
}
};
size_t hash_value() const { return Hash()(*this); }
bool operator<(const KernelName& kernel_name) const {
return hash_value() < kernel_name.hash_value();
}
bool operator==(const KernelName& kernel_name) const {
return hash_value() == kernel_name.hash_value();
}
bool operator!=(const KernelName& kernel_name) const {
return hash_value() != kernel_name.hash_value();
}
private:
void ParseNameAndOverloadNameFromString(const std::string& kernel_name) {
size_t pos = kernel_name.find_first_of('.');
if (pos == std::string::npos) {
name_ = kernel_name;
overload_name_ = "";
} else {
name_ = kernel_name.substr(0, pos);
overload_name_ = kernel_name.substr(pos + 1, kernel_name.size());
}
}
// TODO(chenweihang): use string_view to improve performance later
std::string name_;
std::string overload_name_;
};
class KernelKey {
public:
KernelKey() = default;
......@@ -265,9 +210,8 @@ class KernelFactory {
public:
// replaced by paddle::flat_hash_map later
using KernelMap = paddle::flat_hash_map<
KernelName,
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>,
KernelName::Hash>;
std::string,
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>>;
static KernelFactory& Instance();
......@@ -277,15 +221,15 @@ class KernelFactory {
return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end();
}
const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const;
const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
Backend backend,
DataLayout layout,
DataType dtype) const;
Kernel SelectKernel(const KernelName& kernel_name,
Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
private:
......@@ -294,18 +238,6 @@ class KernelFactory {
KernelMap kernels_;
};
/** operator << overload **/
inline std::ostream& operator<<(std::ostream& os,
const KernelName& kernel_name) {
if (kernel_name.overload_name().empty()) {
os << kernel_name.name();
} else {
os << kernel_name.name() << "." << kernel_name.overload_name();
}
return os;
}
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
os << "(" << kernel_key.backend() << ", " << kernel_key.layout() << ", "
<< kernel_key.dtype() << ")";
......
......@@ -143,7 +143,7 @@ struct KernelRegistrar {
KernelArgsDefFn args_def_fn,
KernelFn kernel_fn,
void* variadic_kernel_fn) {
KernelName kernel_name(kernel_name_cstr);
std::string kernel_name(kernel_name_cstr);
KernelKey kernel_key(backend, layout, dtype);
Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def());
......
......@@ -24,18 +24,6 @@ namespace tests {
// TODO(chenweihang): add more unittests later
TEST(KernelName, ConstructAndOStream) {
std::ostringstream oss;
oss << pten::KernelName("scale", "host");
EXPECT_EQ(oss.str(), "scale.host");
pten::KernelName kernel_name1("scale.host");
EXPECT_EQ(kernel_name1.name(), "scale");
EXPECT_EQ(kernel_name1.overload_name(), "host");
pten::KernelName kernel_name2("scale.host");
EXPECT_EQ(kernel_name2.name(), "scale");
EXPECT_EQ(kernel_name2.overload_name(), "host");
}
TEST(KernelKey, ConstructAndOStream) {
pten::KernelKey key(
pten::Backend::CPU, pten::DataLayout::NCHW, pten::DataType::FLOAT32);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册