提交 e30d80d4 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Change selective registration of op kernels to work off of the class name

instead of the filename.

Change FindKernelDef to also return the class name, to help a tool that
generates ops_to_register.h to find the set of class names.
Change: 118381472
上级 69eaa5d7
......@@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
// Look up the Kernel registered for this node def.
const KernelDef* kdef = nullptr;
status = FindKernelDef(device_type, ndef, &kdef);
status =
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
if (!status.ok() || HasTypeList(*op_def)) {
// When there is no kernel def for this op or the op's arg is a
......
......@@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
// OpKernel registration ------------------------------------------------------
struct KernelRegistration {
KernelRegistration(const KernelDef& d,
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
: def(d), factory(f) {}
: def(d), kernel_class_name(c.ToString()), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
};
......@@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type,
namespace kernel_factory {
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
Factory factory) {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
GlobalKernelRegistryTyped()->insert(
std::make_pair(key, KernelRegistration(*kernel_def, factory)));
GlobalKernelRegistryTyped()->insert(std::make_pair(
key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
delete kernel_def;
}
......@@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
} // namespace
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def) {
const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr;
TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg));
if (reg == nullptr) {
......@@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
" devices compatible with node ",
SummarizeNodeDef(node_def));
}
*def = &reg->def;
if (def != nullptr) *def = &reg->def;
if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
return Status::OK();
}
......
......@@ -1025,7 +1025,6 @@ namespace register_kernel {
typedef ::tensorflow::KernelDefBuilder Name;
} // namespace register_kernel
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
......@@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name;
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
SHOULD_REGISTER_OP_KERNEL(__FILE__) \
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
void* GlobalKernelRegistry();
// If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def.
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def);
const KernelDef** def, string* kernel_class_name);
// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
// registered with the current library's global kernel registry (obtained by
......@@ -1058,16 +1059,19 @@ namespace kernel_factory {
class OpKernelRegistrar {
public:
typedef OpKernel* (*Factory)(OpKernelConstruction*);
OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) {
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory) {
// Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) {
InitInternal(kernel_def, factory);
InitInternal(kernel_def, kernel_class_name, factory);
}
}
private:
void InitInternal(const KernelDef* kernel_def, Factory factory);
void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory);
};
} // namespace kernel_factory
......
......@@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test {
}
}
}
// Looks up the kernel registered for a node built from `op_type` and `attrs`
// on `device_type` and reports the outcome as a string:
//   - the registered kernel class name when FindKernelDef succeeds,
//   - "not found" when the lookup status is NotFound,
//   - the status text (status.ToString()) for any other error.
// One "a:0" input is added to the node for each entry in `input_types`; only
// the number of entries is used here.
string GetKernelClassName(const string& op_type, DeviceType device_type,
                          const std::vector<string>& attrs,
                          DataTypeSlice input_types = {}) {
  NodeDef node = CreateNodeDef(op_type, attrs);
  for (size_t remaining = input_types.size(); remaining > 0; --remaining) {
    node.add_input("a:0");
  }

  const KernelDef* found_def = nullptr;
  string class_name;
  const Status lookup =
      FindKernelDef(device_type, node, &found_def, &class_name);

  // Guard-clause form: success first, then the NotFound special case.
  if (lookup.ok()) return class_name;
  if (errors::IsNotFound(lookup)) return "not found";
  return lookup.ToString();
}
};
REGISTER_OP("BuildCPU");
......@@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
// "BuildCPU" is expected to resolve to DummyKernel on CPU, and the same
// lookup on GPU is expected to fail with NOT_FOUND.
TEST_F(OpKernelBuilderTest, BuilderCPU) {
  static const char kOpName[] = "BuildCPU";

  // CPU: kernel creation succeeds and the class name is reported.
  ExpectSuccess(kOpName, DEVICE_CPU, {});
  const string cpu_class_name = GetKernelClassName(kOpName, DEVICE_CPU, {});
  EXPECT_EQ("DummyKernel", cpu_class_name);

  // GPU: creation fails and the class-name helper reports "not found".
  ExpectFailure(kOpName, DEVICE_GPU, {}, error::NOT_FOUND);
  const string gpu_class_name = GetKernelClassName(kOpName, DEVICE_GPU, {});
  EXPECT_EQ("not found", gpu_class_name);
}
REGISTER_OP("BuildGPU");
......@@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
// Exercises kernel lookup for "BuildTypeListAttr" on CPU with several values
// of the list(type) attr "T". Each ExpectSuccess/ExpectFailure is paired with
// a GetKernelClassName query using the same attrs, so the class-name lookup
// is checked against the same node definition.
TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                              {"T|list(type)|[]"}));

  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
  // Query the same [DT_BOOL] attr that was just validated; repeating the
  // empty-list attr here would leave this case unchecked.
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                              {"T|list(type)|[DT_BOOL]"}));

  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
                {"T|list(type)|[DT_BOOL, DT_BOOL]"});

  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
                error::NOT_FOUND);
  EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                            {"T|list(type)|[DT_FLOAT]"}));

  // A missing attr is not a NotFound error, so the helper returns the status
  // text; only its "Invalid argument: " fragment is checked.
  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
  EXPECT_TRUE(
      StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
          .contains("Invalid argument: "));

  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
                error::INVALID_ARGUMENT);
}
......@@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(0, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<0>",
GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
}
TEST_F(LabelTest, Specified) {
......@@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(1, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
{"_kernel|string|'one'"}));
}
TEST_F(LabelTest, Duplicate) {
......
......@@ -34,13 +34,12 @@ limitations under the License.
// out.
#include "ops_to_register.h"
// Files which are not included in the whitelist provided by this
// graph-specific header file will not be allowed to register their
// operator kernels.
#define SHOULD_REGISTER_OP_KERNEL(filename) \
(strstr(kNecessaryOpFiles, filename) != nullptr)
// Op kernel classes for which SHOULD_REGISTER_OP_KERNEL evaluates to false
// will not be registered.
#define SHOULD_REGISTER_OP_KERNEL(clz) \
(strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)
// Ops for which ShouldRegisterOp return false will no be registered.
// Ops for which ShouldRegisterOp returns false will not be registered.
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
// If kRequiresSymbolicGradients is false, then no gradient ops are registered.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册