diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index be9c01fb04f4428b5754c3d963b079ca347c45ee..5f826aeb8371b944bdb8db8c842f5555e41c9549 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -61,3 +61,5 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) + +cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context) diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h index 7429de7ee39297c26360984809e2451100f7b3ff..7d7a444cf02ba0da88178a34c98f5ef7a95c3852 100644 --- a/paddle/framework/data_layout.h +++ b/paddle/framework/data_layout.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include +#include "paddle/platform/enforce.h" + namespace paddle { namespace framework { @@ -33,5 +36,23 @@ inline DataLayout StringToDataLayout(const std::string& str) { } } +inline std::string DataLayoutToString(const DataLayout& data_layout) { + switch (data_layout) { + case kNHWC: + return "NHWC"; + case kNCHW: + return "NCHW"; + case kAnyLayout: + return "ANY_LAYOUT"; + default: + PADDLE_THROW("unknown DataLayou %d", data_layout); + } +} + +inline std::ostream& operator<<(std::ostream& out, DataLayout l) { + out << DataLayoutToString(l); + return out; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h index 49b273656bf57f183209e3d0996358da28ec0e7a..aa66cf00f3be14c4ac9496326190428688fc3496 100644 --- a/paddle/framework/library_type.h +++ b/paddle/framework/library_type.h @@ -22,5 +22,23 @@ namespace framework { enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 }; +inline std::string LibraryTypeToString(const LibraryType& library_type) { + switch (library_type) { + case kPlain: + return "PLAIN"; + case kMKLDNN: + return "MKLDNN"; + case kCUDNN: + return "CUDNN"; + default: + PADDLE_THROW("unknown LibraryType %d", library_type); + } +} + +inline std::ostream& operator<<(std::ostream& out, LibraryType l) { + out << LibraryTypeToString(l); + return out; +} + } // namespace } // framework diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h index a1dea0d9d864881ef1f60b117dfaa02da3aa4275..e9c45b958cd0a65bca62099324e951f298d9ecb1 100644 --- a/paddle/framework/op_kernel_type.h +++ b/paddle/framework/op_kernel_type.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/data_layout.h" #include "paddle/framework/data_type.h" #include "paddle/framework/library_type.h" +#include "paddle/platform/device_context.h" #include "paddle/platform/place.h" namespace paddle { @@ -68,5 +69,13 @@ struct OpKernelType { } }; +inline std::ostream& operator<<(std::ostream& os, + const OpKernelType& kernel_key) { + os << "data_type[" << kernel_key.data_type_ << "]:data_layout[" + << kernel_key.data_layout_ << "]:place[" << kernel_key.place_ + << "]:library_type[" << kernel_key.library_type_ << "]"; + return os; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..899676b5c1a3799427d6ef03bae47284550c781b --- /dev/null +++ b/paddle/framework/op_kernel_type_test.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_kernel_type.h" +#include +#include + +TEST(OpKernelType, ToString) { + using OpKernelType = paddle::framework::OpKernelType; + using DataType = paddle::framework::proto::DataType; + using CPUPlace = paddle::platform::CPUPlace; + using DataLayout = paddle::framework::DataLayout; + using LibraryType = paddle::framework::LibraryType; + + OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW, + LibraryType::kCUDNN); + + std::ostringstream stream; + stream << op_kernel_type; + ASSERT_EQ( + stream.str(), + "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]"); +} + +TEST(OpKernelType, Hash) { + using OpKernelType = paddle::framework::OpKernelType; + using DataType = paddle::framework::proto::DataType; + using CPUPlace = paddle::platform::CPUPlace; + using GPUPlace = paddle::platform::GPUPlace; + using DataLayout = paddle::framework::DataLayout; + using LibraryType = paddle::framework::LibraryType; + + OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, + LibraryType::kCUDNN); + OpKernelType op_kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, + LibraryType::kCUDNN); + + OpKernelType::Hash hasher; + ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2)); +} \ No newline at end of file diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 06184f6ba968c438f6baa571d7a5c12a69109c84..f147cc5a6e233ff88208f4f966ab93e8737b6f13 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -242,13 +242,6 @@ std::vector ExecutionContext::MultiOutput( return res; } -std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) { - os << "data_type[" << kernel_key.data_type_ << "]:data_layout[" - << kernel_key.data_layout_ << "]:place[" << kernel_key.place_ - << "]:library_type[" << kernel_key.library_type_ << "]"; - return os; -} - bool OpSupportGPU(const std::string& op_type) { auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto it = all_kernels.find(op_type); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index aba34c5bcb81c85db21e9d82894fc0b937c3c060..b592eea1b96113a8cbcb0e137890927cfcc22670 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -381,8 +381,6 @@ class OperatorWithKernel : public OperatorBase { proto::DataType IndicateDataType(const ExecutionContext& ctx) const; }; -std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key); - extern bool OpSupportGPU(const std::string& op_type); } // namespace framework