diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 01b8b22d1616635283cf234f3b2afaee3dc16214..878e8b12bfd8dd95438829adeb4aa7bab7c6b9c2 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, // Look up the Kernel registered for this node def. const KernelDef* kdef = nullptr; - status = FindKernelDef(device_type, ndef, &kdef); + status = + FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); if (!status.ok() || HasTypeList(*op_def)) { // When there is no kernel def for this op or the op's arg is a diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c984c8dc933d3212618476fb8c5c5ce8563bf42a..f3d60cc535f68820ef576e95e8cbd237a8813dfc 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, // OpKernel registration ------------------------------------------------------ struct KernelRegistration { - KernelRegistration(const KernelDef& d, + KernelRegistration(const KernelDef& d, StringPiece c, kernel_factory::OpKernelRegistrar::Factory f) - : def(d), factory(f) {} + : def(d), kernel_class_name(c.ToString()), factory(f) {} const KernelDef def; + const string kernel_class_name; const kernel_factory::OpKernelRegistrar::Factory factory; }; @@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type, namespace kernel_factory { void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, + StringPiece kernel_class_name, Factory factory) { const string key = Key(kernel_def->op(), DeviceType(kernel_def->device_type()), kernel_def->label()); - GlobalKernelRegistryTyped()->insert( - std::make_pair(key, KernelRegistration(*kernel_def, factory))); + GlobalKernelRegistryTyped()->insert(std::make_pair( + key, KernelRegistration(*kernel_def, kernel_class_name, factory))); delete kernel_def; } @@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, } // namespace Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, - const KernelDef** def) { + const KernelDef** def, string* kernel_class_name) { const KernelRegistration* reg = nullptr; TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, ®)); if (reg == nullptr) { @@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, " devices compatible with node ", SummarizeNodeDef(node_def)); } - *def = ®->def; + if (def != nullptr) *def = ®->def; + if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; return Status::OK(); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 6cbd3069010cc8387a55f78168c1b18391d86990..61d15edf7b10173f0d3e5b50335f23058ed65cc6 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1025,7 +1025,6 @@ namespace register_kernel { typedef ::tensorflow::KernelDefBuilder Name; } // namespace register_kernel - #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) @@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name; #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ static ::tensorflow::kernel_factory::OpKernelRegistrar \ registrar__body__##ctr##__object( \ - SHOULD_REGISTER_OP_KERNEL(__FILE__) \ + SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \ ? ::tensorflow::register_kernel::kernel_builder.Build() \ : nullptr, \ + #__VA_ARGS__, \ [](::tensorflow::OpKernelConstruction* context) \ -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) void* GlobalKernelRegistry(); // If node_def has a corresponding kernel registered on device_type, -// returns OK and fill in the kernel def. +// returns OK and fill in the kernel def and kernel_class_name. and +// may be null. Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, - const KernelDef** def); + const KernelDef** def, string* kernel_class_name); // Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k' // registered with the current library's global kernel registry (obtained by @@ -1058,16 +1059,19 @@ namespace kernel_factory { class OpKernelRegistrar { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); - OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) { + + OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, + Factory factory) { // Perform the check in the header to allow compile-time optimization // to a no-op, allowing the linker to remove the kernel symbols. if (kernel_def != nullptr) { - InitInternal(kernel_def, factory); + InitInternal(kernel_def, kernel_class_name, factory); } } private: - void InitInternal(const KernelDef* kernel_def, Factory factory); + void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, + Factory factory); }; } // namespace kernel_factory diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 4275ad0f73aef728a5f8d0feb0f4dabef52094fa..6d3cfb0c920478bf12ce54f350c13e6b202d05ba 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test { } } } + + string GetKernelClassName(const string& op_type, DeviceType device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + const KernelDef* kernel_def = nullptr; + string kernel_class_name; + const Status status = + FindKernelDef(device_type, def, &kernel_def, &kernel_class_name); + if (status.ok()) { + return kernel_class_name; + } else if (errors::IsNotFound(status)) { + return "not found"; + } else { + return status.ToString(); + } + } }; REGISTER_OP("BuildCPU"); @@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); TEST_F(OpKernelBuilderTest, BuilderCPU) { ExpectSuccess("BuildCPU", DEVICE_CPU, {}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {})); ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); + EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {})); } REGISTER_OP("BuildGPU"); @@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[]"})); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[]"})); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL, DT_BOOL]"}); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, error::NOT_FOUND); + EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_FLOAT]"})); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + EXPECT_TRUE( + StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {})) + .contains("Invalid argument: ")); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, error::INVALID_ARGUMENT); } @@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) { ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); auto* get_labeled_kernel = static_cast(op_kernel.get()); EXPECT_EQ(0, get_labeled_kernel->Which()); + + EXPECT_EQ("LabeledKernel<0>", + GetKernelClassName("LabeledKernel", DEVICE_CPU, {})); } TEST_F(LabelTest, Specified) { @@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) { ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); auto* get_labeled_kernel = static_cast(op_kernel.get()); EXPECT_EQ(1, get_labeled_kernel->Which()); + EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU, + {"_kernel|string|'one'"})); } TEST_F(LabelTest, Duplicate) { diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h index 78fe5b7d9f88b84ce0fda0e39a70128dd707532d..c0c37fc987799250caf632977a541fa877a1ad97 100644 --- a/tensorflow/core/framework/selective_registration.h +++ b/tensorflow/core/framework/selective_registration.h @@ -34,13 +34,12 @@ limitations under the License. // out. #include "ops_to_register.h" -// Files which are not included in the whitelist provided by this -// graph-specific header file will not be allowed to register their -// operator kernels. -#define SHOULD_REGISTER_OP_KERNEL(filename) \ - (strstr(kNecessaryOpFiles, filename) != nullptr) +// Op kernel classes for which ShouldRegisterOpKernel returns false will not be +// registered. +#define SHOULD_REGISTER_OP_KERNEL(clz) \ + (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr) -// Ops for which ShouldRegisterOp return false will no be registered. +// Ops for which ShouldRegisterOp returns false will not be registered. #define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) // If kRequiresSymbolicGradients is false, then no gradient ops are registered.