diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 5bf70d1126b892d6fd46450aec3b84c1f3b8493b..215c81a00e89a8a267b203faad30fcbd9cce00ab 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -14,11 +14,15 @@ limitations under the License. */ #include #include +#include #include +#include #include #include #include // NOLINT // for call_once #include +#include +#include #include #include #include @@ -189,6 +193,64 @@ bool SupportsBfloat16FastPerformance() { #endif } +// According to the input `place` and `dtype`, this function returns a tuple +// consists of three sets: +// 1) All operators registered in the Paddle framework. +// 2) All operators supported for `place` and `dtype`. +// 3) All operators unsupported for `place` and `dtype`. +// The input `place` is a type of string, which can only be `GPU` or `CPU`. +// The input `dtype` is a type of paddle::framework::proto::VarType::Type, +// which can be paddle::framework::proto::VarType::FP16, +// paddle::framework::proto::VarType::FP32 and so on. +std::tuple, std::unordered_set, + std::unordered_set> +OpSupportedInfos(const std::string &place, + framework::proto::VarType::Type dtype) { + std::string query_place; + std::transform(place.begin(), place.end(), std::back_inserter(query_place), + [](unsigned char c) { return std::toupper(c); }); + using fn_type = std::add_pointer::type; + std::unordered_map is_target_place{ + {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place}, + }; + PADDLE_ENFORCE_NE( + is_target_place.count(query_place), 0, + platform::errors::InvalidArgument( + "The argument `place` should be 'GPU' or 'CPU', but get '%s'.", + place)); + + std::unordered_set all_ops; + const auto &op_info = framework::OpInfoMap::Instance().map(); + for (auto it = op_info.begin(); it != op_info.end(); it++) { + all_ops.emplace(it->first); + } + + std::unordered_set supported_ops; + auto &all_kernels = framework::OperatorWithKernel::AllOpKernels(); + for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) { + for (auto &kernel_type : it->second) { + if (is_target_place[query_place](kernel_type.first.place_) && + kernel_type.first.data_type_ == dtype) { + supported_ops.emplace(it->first); + } + } + } + + std::unordered_set unsupported_ops; + for (auto &op : all_ops) { + if (!supported_ops.count(op)) { + unsupported_ops.emplace(op); + } + } + + VLOG(4) << "-- The size of all_ops: " << all_ops.size() << " --"; + VLOG(4) << "-- The size of supported_ops: " << supported_ops.size() << " --"; + VLOG(4) << "-- The size of unsupported_ops: " << unsupported_ops.size() + << " --"; + return std::make_tuple(std::move(all_ops), std::move(supported_ops), + std::move(unsupported_ops)); +} + bool IsCompiledWithBrpc() { #ifndef PADDLE_WITH_DISTRIBUTE return false; @@ -1770,6 +1832,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("supports_bfloat16", SupportsBfloat16); m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance); + m.def("op_supported_infos", OpSupportedInfos); m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_dist", IsCompiledWithDIST); m.def("_cuda_synchronize", [](const platform::CUDAPlace &place) { diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 6a524af4ee240fa4a64bcf203278e45d38994393..f940f6a3143a09fa82d4e10fba38f7d86b9c025d 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +from ... import core __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"] @@ -147,147 +148,10 @@ gray_list = { } # The set of ops that don't support fp16 calculation -unsupported_fp16_list = { - # from python/paddle/fluid/layers/io.py - 'send', - 'send_barrier', - 'recv', - 'fetch_barrier', - 'create_py_reader', - 'create_double_buffer_reader', - 'read', - 'load', - - # from python/paddle/fluid/control_flow.py - 'increment', - 'less_than', - 'less_equal', - 'greater_than', - 'greater_equal', - 'equal', - 'not_equal', - 'read_from_array', - 'shrink_rnn_memory', - 'lod_array_length', - 'logical_and', - 'logical_or', - 'logical_xor', - 'logical_not', - 'print', - 'conditional_block', - 'while', - 'ifelse', - 'is_empty', - 'lstm', - 'cudnn_lstm', - 'lstmp', - 'gru', - 'gru_unit', - 'linear_chain_crf', - 'crf_decoding', - 'bpr_loss', - 'chunk_eval', - 'sequence_conv', - 'sequence_softmax', - # Depthwise conv2d isn't fast and safe currently. - # ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79 - 'depthwise_conv2d', - # Tensor Core kernels are not available for 3D convolutions currently. - 'conv3d', - 'sequence_pool', - 'sequence_concat', - 'sequence_slice', - 'data_norm', - 'group_norm', - 'spectral_norm', - 'depthwise_conv2d_transpose', - 'sequence_expand', - 'conv_transposed2d', - 'conv_transposed3d', - 'sequence_expand_as', - 'sequence_pad', - 'sequence_unpad', - 'sequence_erase', - 'beam_search', - 'beam_search_decode', - 'lstm_unit', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'reduce_all', - 'reduce_any', - 'split', - 'edit_distance', - 'ctc_align', - 'warpctc', - 'sequence_reshape', - 'nce', - 'hierarchical_sigmoid', - 'im2sequence', - 'row_conv', - 'multiplex', - 'sample_logits', - 'one_hot', - 'smooth_l1_loss', - 'squeeze2', - 'unsqueeze2', - 'lod_reset', - 'lrn', - 'pad', - 'pad_constant_like', - 'label_smooth', - 'scatter', - 'sequence_scatter', - 'random_crop', - 'mean_iou', - 'selu', - 'crop', - 'affine_grid', - 'rank_loss', - 'margin_rank_loss', - 'pad2d', - 'elu', - 'pow', - 'stanh', - 'hard_sigmoid', - 'swish', - 'prelu', - 'brelu', - 'sequence_enumerate', - 'sequence_mask', - 'expand', - 'sampling_id', - 'maxout', - 'space_to_depth', - 'sequence_reverse', - 'similarity_focus', - 'hash', - 'grid_sampler', - 'log_loss', - 'teacher_student_sigmoid_loss', - 'add_position_encoding', - 'bilinear_tensor_product', - 'shuffle_channel', - 'temporal_shift', - 'psroi_pool', - 'huber_loss', - 'kldiv_loss', - 'tree_conv', - 'pixel_shuffle', - 'fsp', - 'cvm', - 'affine_channel', - 'roi_pool', - 'roi_align', - 'anchor_generator', - 'generate_proposals', - 'generate_proposal_labels', - 'generate_mask_labels', - # fp16 is slower than fp32, though fp16 is supported. - 'lookup_table', - 'lookup_table_v2', -} +# lookup_table fp16 is slower than fp32, though fp16 is supported. +_, _, _sys_unsupported_fp16_list = core.op_supported_infos( + 'GPU', core.VarDesc.VarType.FP16) +unsupported_fp16_list = {'lookup_table', + 'lookup_table_v2'} | _sys_unsupported_fp16_list CustomOpLists = AutoMixedPrecisionLists diff --git a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py index b190a5d02efc4ce34a7062f1bf3e2ad1939c9399..850b267411ed5d98d21f8dd0cc14ad76fd9b641c 100644 --- a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py +++ b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py @@ -258,7 +258,8 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): cast_model_to_fp16(main_prog, use_fp16_guard=False) def test_non_iterable_dataloader(self): - self.decorate_with_data_loader() + if fluid.core.is_compiled_with_cuda(): + self.decorate_with_data_loader() if __name__ == '__main__':