提交 e30d80d4 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Change selective registration of op kernels to work off of the class name

instead of the filename.

Change FindKernelDef to also return the class name, to help a tool that
generates ops_to_register.h to find the set of class names.
Change: 118381472
上级 69eaa5d7
...@@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, ...@@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
// Look up the Kernel registered for this node def. // Look up the Kernel registered for this node def.
const KernelDef* kdef = nullptr; const KernelDef* kdef = nullptr;
status = FindKernelDef(device_type, ndef, &kdef); status =
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
if (!status.ok() || HasTypeList(*op_def)) { if (!status.ok() || HasTypeList(*op_def)) {
// When there is no kernel def for this op or the op's arg is a // When there is no kernel def for this op or the op's arg is a
......
...@@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, ...@@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
// OpKernel registration ------------------------------------------------------ // OpKernel registration ------------------------------------------------------
struct KernelRegistration { struct KernelRegistration {
KernelRegistration(const KernelDef& d, KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f) kernel_factory::OpKernelRegistrar::Factory f)
: def(d), factory(f) {} : def(d), kernel_class_name(c.ToString()), factory(f) {}
const KernelDef def; const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory; const kernel_factory::OpKernelRegistrar::Factory factory;
}; };
...@@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type, ...@@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type,
namespace kernel_factory { namespace kernel_factory {
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
Factory factory) { Factory factory) {
const string key = const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()), Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label()); kernel_def->label());
GlobalKernelRegistryTyped()->insert( GlobalKernelRegistryTyped()->insert(std::make_pair(
std::make_pair(key, KernelRegistration(*kernel_def, factory))); key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
delete kernel_def; delete kernel_def;
} }
...@@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, ...@@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
} // namespace } // namespace
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def) { const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr; const KernelRegistration* reg = nullptr;
TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg)); TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg));
if (reg == nullptr) { if (reg == nullptr) {
...@@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, ...@@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
" devices compatible with node ", " devices compatible with node ",
SummarizeNodeDef(node_def)); SummarizeNodeDef(node_def));
} }
*def = &reg->def; if (def != nullptr) *def = &reg->def;
if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
return Status::OK(); return Status::OK();
} }
......
...@@ -1025,7 +1025,6 @@ namespace register_kernel { ...@@ -1025,7 +1025,6 @@ namespace register_kernel {
typedef ::tensorflow::KernelDefBuilder Name; typedef ::tensorflow::KernelDefBuilder Name;
} // namespace register_kernel } // namespace register_kernel
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
...@@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name; ...@@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name;
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
static ::tensorflow::kernel_factory::OpKernelRegistrar \ static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \ registrar__body__##ctr##__object( \
SHOULD_REGISTER_OP_KERNEL(__FILE__) \ SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
? ::tensorflow::register_kernel::kernel_builder.Build() \ ? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \ : nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \ [](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
void* GlobalKernelRegistry(); void* GlobalKernelRegistry();
// If node_def has a corresponding kernel registered on device_type, // If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def. // returns OK and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def); const KernelDef** def, string* kernel_class_name);
// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k' // Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
// registered with the current library's global kernel registry (obtained by // registered with the current library's global kernel registry (obtained by
...@@ -1058,16 +1059,19 @@ namespace kernel_factory { ...@@ -1058,16 +1059,19 @@ namespace kernel_factory {
class OpKernelRegistrar { class OpKernelRegistrar {
public: public:
typedef OpKernel* (*Factory)(OpKernelConstruction*); typedef OpKernel* (*Factory)(OpKernelConstruction*);
OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) {
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory) {
// Perform the check in the header to allow compile-time optimization // Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols. // to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) { if (kernel_def != nullptr) {
InitInternal(kernel_def, factory); InitInternal(kernel_def, kernel_class_name, factory);
} }
} }
private: private:
void InitInternal(const KernelDef* kernel_def, Factory factory); void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory);
}; };
} // namespace kernel_factory } // namespace kernel_factory
......
...@@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test { ...@@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test {
} }
} }
} }
string GetKernelClassName(const string& op_type, DeviceType device_type,
const std::vector<string>& attrs,
DataTypeSlice input_types = {}) {
NodeDef def = CreateNodeDef(op_type, attrs);
for (size_t i = 0; i < input_types.size(); ++i) {
def.add_input("a:0");
}
const KernelDef* kernel_def = nullptr;
string kernel_class_name;
const Status status =
FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
if (status.ok()) {
return kernel_class_name;
} else if (errors::IsNotFound(status)) {
return "not found";
} else {
return status.ToString();
}
}
}; };
REGISTER_OP("BuildCPU"); REGISTER_OP("BuildCPU");
...@@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); ...@@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
TEST_F(OpKernelBuilderTest, BuilderCPU) { TEST_F(OpKernelBuilderTest, BuilderCPU) {
ExpectSuccess("BuildCPU", DEVICE_CPU, {}); ExpectSuccess("BuildCPU", DEVICE_CPU, {});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
} }
REGISTER_OP("BuildGPU"); REGISTER_OP("BuildGPU");
...@@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") ...@@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[]"}));
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[]"}));
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[DT_BOOL, DT_BOOL]"}); {"T|list(type)|[DT_BOOL, DT_BOOL]"});
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
error::NOT_FOUND); error::NOT_FOUND);
EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[DT_FLOAT]"}));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
EXPECT_TRUE(
StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
.contains("Invalid argument: "));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
error::INVALID_ARGUMENT); error::INVALID_ARGUMENT);
} }
...@@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) { ...@@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(0, get_labeled_kernel->Which()); EXPECT_EQ(0, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<0>",
GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
} }
TEST_F(LabelTest, Specified) { TEST_F(LabelTest, Specified) {
...@@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) { ...@@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(1, get_labeled_kernel->Which()); EXPECT_EQ(1, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
{"_kernel|string|'one'"}));
} }
TEST_F(LabelTest, Duplicate) { TEST_F(LabelTest, Duplicate) {
......
...@@ -34,13 +34,12 @@ limitations under the License. ...@@ -34,13 +34,12 @@ limitations under the License.
// out. // out.
#include "ops_to_register.h" #include "ops_to_register.h"
// Files which are not included in the whitelist provided by this // Op kernel classes for which ShouldRegisterOpKernel returns false will not be
// graph-specific header file will not be allowed to register their // registered.
// operator kernels. #define SHOULD_REGISTER_OP_KERNEL(clz) \
#define SHOULD_REGISTER_OP_KERNEL(filename) \ (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)
(strstr(kNecessaryOpFiles, filename) != nullptr)
// Ops for which ShouldRegisterOp return false will no be registered. // Ops for which ShouldRegisterOp returns false will not be registered.
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) #define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
// If kRequiresSymbolicGradients is false, then no gradient ops are registered. // If kRequiresSymbolicGradients is false, then no gradient ops are registered.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册