From 35c1683e803462d1ae78c49a8c8fb392ff6e2d32 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 27 Dec 2017 15:23:14 +0800 Subject: [PATCH] "refine kernel registrar" (#6998) * "refine kernel registrar" * "refine registrar with multikey" * "fix register" * "refine multikernel register" * "fix CI" * "fix CI" * "fix registry" * "swtich GPU to CUDA" * "add register macro test case" * "fix CI" --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/library_type.h | 27 +++++++- paddle/framework/op_kernel_type_test.cc | 2 +- paddle/framework/op_registry.h | 16 ++--- paddle/framework/op_registry_test.cc | 82 +++++++++++++++++++++++++ paddle/operators/conv_cudnn_op.cu.cc | 4 ++ 6 files changed, 122 insertions(+), 11 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 7436e8c228..738684795d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -37,7 +37,7 @@ cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) -cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h index 6baae6c2bb..7707799cae 100644 --- a/paddle/framework/library_type.h +++ b/paddle/framework/library_type.h @@ -20,7 +20,11 @@ namespace framework { // For more details about the design of LibraryType, Please refer to // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library -enum class LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 }; +enum class LibraryType { + kPlain = 0, + kMKLDNN = 1, + kCUDNN = 2, +}; inline std::string LibraryTypeToString(const LibraryType& library_type) { switch (library_type) { @@ -31,7 +35,26 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) { case LibraryType::kCUDNN: return "CUDNN"; default: - PADDLE_THROW("unknown LibraryType %d", library_type); + PADDLE_THROW("unknown LibraryType %d", static_cast(library_type)); + } +} + +inline LibraryType StringToLibraryType(const char* ctype) { + std::string s(ctype); + if (s == std::string("PLAIN")) { + return LibraryType::kPlain; + } else if (s == std::string("MKLDNN")) { + return LibraryType::kMKLDNN; + } else if (s == std::string("CUDNN")) { + return LibraryType::kCUDNN; + // To be compatible with register macro. + // CPU, CUDA, PLAIN are same library type. + } else if (s == std::string("CPU")) { + return LibraryType::kPlain; + } else if (s == std::string("CUDA")) { + return LibraryType::kPlain; + } else { + PADDLE_THROW("Unknown LibraryType %s", s.c_str()); } } diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc index 8753d7cc37..dd04840500 100644 --- a/paddle/framework/op_kernel_type_test.cc +++ b/paddle/framework/op_kernel_type_test.cc @@ -48,4 +48,4 @@ TEST(OpKernelType, Hash) { OpKernelType::Hash hasher; ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2)); -} \ No newline at end of file +} diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 9bb2a3b5c2..bdaa259181 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor { using KERNEL_TYPE = typename std::tuple_element>::type; - void operator()(const char* op_type) const { + void operator()(const char* op_type, const char* library_type) const { using T = typename KERNEL_TYPE::ELEMENT_TYPE; - OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType()); + OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), + DataLayout::kAnyLayout, StringToLibraryType(library_type)); OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor func; - func(op_type); + func(op_type, library_type); } }; template struct OpKernelRegistrarFunctor { - void operator()(const char* op_type) const {} + void operator()(const char* op_type, const char* library_type) const {} }; // User can register many kernel in one place. The data type could be different. template class OpKernelRegistrar : public Registrar { public: - explicit OpKernelRegistrar(const char* op_type) { + explicit OpKernelRegistrar(const char* op_type, const char* library_type) { OpKernelRegistrarFunctor func; - func(op_type); + func(op_type, library_type); } }; @@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar { __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be called in global namespace"); \ static ::paddle::framework::OpKernelRegistrar \ - __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \ + __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \ + #DEVICE_TYPE); \ int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \ return 0; \ diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 4cdf6e0865..cef530c6e6 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/framework/op_registry.h" #include @@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) { using namespace paddle::framework; OperatorRegistrar reg("cos"); } + +namespace paddle { +namespace framework { + +class OpKernelTestMaker : public OpProtoAndCheckerMaker { + public: + OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddComment("NoGradOp, same input output. no Grad"); + } +}; + +class OpWithKernelTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(InferShapeContext* ctx) const override {} + + framework::OpKernelType GetActualKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(proto::DataType::FP32, ctx.device_context()); + } +}; + +template +class OpKernelTest : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const {} +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel, + paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestMaker); +REGISTER_OP_CPU_KERNEL( + op_with_kernel, + paddle::framework::OpKernelTest); + +REGISTER_OP_CUDA_KERNEL(op_with_kernel, + paddle::framework::OpKernelTest< + paddle::platform::CUDADeviceContext, float>); + +TEST(OperatorRegistrar, CPU) { + paddle::framework::proto::OpDesc op_desc; + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + op_desc.set_type("op_with_kernel"); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + + op->Run(scope, cpu_place); +} + +#ifdef PADDLE_WITH_CUDA +TEST(OperatorRegistrar, CUDA) { + paddle::framework::proto::OpDesc op_desc; + paddle::platform::CUDAPlace cuda_place(0); + paddle::framework::Scope scope; + + op_desc.set_type("op_with_kernel"); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + + op->Run(scope, cuda_place); +} +#endif diff --git a/paddle/operators/conv_cudnn_op.cu.cc b/paddle/operators/conv_cudnn_op.cu.cc index 08ff0db086..0aa7dd48ca 100644 --- a/paddle/operators/conv_cudnn_op.cu.cc +++ b/paddle/operators/conv_cudnn_op.cu.cc @@ -315,6 +315,10 @@ class CudnnConvGradOpKernel : public framework::OpKernel { } // namespace operators } // namespace paddle +REGISTER_OP_KERNEL(conv2d, CUDNN, paddle::platform::CUDAPlace, + paddle::operators::CudnnConvOpKernel, + paddle::operators::CudnnConvOpKernel); + REGISTER_OP_CUDA_KERNEL(conv2d_cudnn, paddle::operators::CudnnConvOpKernel, paddle::operators::CudnnConvOpKernel); -- GitLab