From 4383494f012b6613aa65496e7892ae3f0052ddd9 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Wed, 4 Jan 2023 12:36:26 +0800
Subject: [PATCH] [Unify KernelKey] change OpKernelType->KernelKey (#49138)

* execute use kernel_key first
* change OpKernelType->KernelKey
* fix py3 compile error, remove redundant header files
* fix build_strategy_test
* fix DataType::RAW
* fix custom_type test: operator_test.cc
* fix transform place
* fix backends_are_same_class
* try fix place TransDataDevice
* support all KernelKey
* fix TransformData
* fix place_are_same_class
* fix merge
* fix test_params_no_grad
* fix specific place of GetExpectedKernelType
* fix specific place of GetExpectedKernelType
* fix GetKernelTypeForVar
* fix dtype error
* fix fetch_v2
* change GetKernelTypeForVar
* fix interpreter
* fix typo error
* polish codes
* polish codes
* polish codes
* fix conflict
---
 paddle/fluid/framework/custom_operator.cc | 14 +-
 .../framework/data_device_transform_test.cu | 10 +-
 .../fluid/framework/data_layout_transform.cc | 24 +--
 .../fluid/framework/data_layout_transform.h | 7 +-
 .../framework/data_layout_transform_test.cc | 14 +-
 paddle/fluid/framework/data_transform.cc | 38 ++--
 paddle/fluid/framework/data_transform.h | 7 +-
 paddle/fluid/framework/data_type.h | 6 +
 paddle/fluid/framework/data_type_transform.cc | 21 +-
 paddle/fluid/framework/data_type_transform.h | 4 +-
 .../framework/data_type_transform_test.cc | 49 ++---
 .../framework/data_type_transform_test.cu | 53 ++---
 .../framework/details/build_strategy_test.cc | 6 +-
 .../fluid/framework/infershape_utils_test.cc | 4 +-
 .../new_executor/interpreter/data_transfer.cc | 52 +++--
 .../new_executor/interpreter/data_transfer.h | 31 +--
 .../interpreter/interpreter_util.cc | 8 +-
 paddle/fluid/framework/op_kernel_type.h | 32 ++-
 paddle/fluid/framework/op_registry_test.cc | 14 +-
 paddle/fluid/framework/operator.cc | 151 +++++++------
 paddle/fluid/framework/operator.h | 22 +-
 paddle/fluid/framework/operator_test.cc | 38 ++--
 .../fluid/framework/transfer_scope_cache.cc | 7 +-
 paddle/fluid/framework/transfer_scope_cache.h | 4 +-
 paddle/fluid/imperative/CMakeLists.txt | 6 +-
 paddle/fluid/imperative/infer_shape_context.h | 11 +-
 paddle/fluid/imperative/layer.cc | 4 +-
 paddle/fluid/imperative/prepared_operator.cc | 200 +++++++++---------
 paddle/fluid/imperative/prepared_operator.h | 25 ++-
 paddle/fluid/imperative/tests/test_eager.cc | 18 +-
 .../fluid/imperative/tests/test_prepare_op.cc | 6 +-
 paddle/fluid/imperative/var_helper.cc | 38 ++--
 paddle/fluid/imperative/var_helper.h | 10 +-
 paddle/fluid/imperative/variable_wrapper.h | 10 +-
 paddle/fluid/operators/abs_op.cc | 21 +-
 paddle/fluid/operators/activation_op.cc | 50 +++--
 .../operators/add_position_encoding_op.cc | 15 +-
 paddle/fluid/operators/affine_channel_op.cc | 8 +-
 paddle/fluid/operators/affine_grid_op.cc | 8 +-
 paddle/fluid/operators/allclose_op.cc | 7 +-
 .../operators/amp/alloc_float_status_op.cc | 5 +-
 .../amp/check_finite_and_unscale_op.cc | 4 +-
 .../operators/amp/clear_float_status_op.cc | 5 +-
 .../operators/amp/get_float_status_op.cc | 5 +-
 .../operators/amp/update_loss_scaling_op.cc | 12 +-
 paddle/fluid/operators/arg_min_max_op_base.h | 4 +-
 paddle/fluid/operators/ascend_trigger_op.cc | 6 +-
 paddle/fluid/operators/assign_op.cc | 21 +-
 paddle/fluid/operators/assign_pos_op.cc | 4 +-
 paddle/fluid/operators/assign_value_op.cc | 4 +-
 paddle/fluid/operators/attention_lstm_op.cc | 6 +-
 paddle/fluid/operators/attention_lstm_op.h | 2 +-
 .../fluid/operators/average_accumulates_op.cc | 6 +-
 paddle/fluid/operators/batch_fc_op.cc | 15 +-
 paddle/fluid/operators/batch_norm_op.cc | 40 ++--
 paddle/fluid/operators/batch_norm_op.h | 14 +-
 paddle/fluid/operators/bce_loss_op.cc | 14 +-
 paddle/fluid/operators/beam_search_op.cc | 6 +-
 paddle/fluid/operators/bilateral_slice_op.cc | 14 +-
 paddle/fluid/operators/bincount_op.cc | 6 +-
 paddle/fluid/operators/bpr_loss_op.cc | 14 +-
 .../fluid/operators/broadcast_tensors_op.cc | 14 +-
 paddle/fluid/operators/cast_op.cc | 20 +-
 paddle/fluid/operators/center_loss_op.cc | 13 +-
 paddle/fluid/operators/chunk_eval_op.cc | 6 +-
 .../operators/cinn/cinn_instruction_run_op.cc | 5 +-
 paddle/fluid/operators/cinn/cinn_launch_op.cc | 5 +-
 .../fluid/operators/class_center_sample_op.cc | 7 +-
 paddle/fluid/operators/clip_op.cc | 8 +-
 paddle/fluid/operators/coalesce_tensor_op.cc | 14 +-
 .../operators/collective/allreduce_op.cc | 6 +-
 .../fluid/operators/collective/alltoall_op.cc | 6 +-
 .../operators/collective/c_allreduce_op.h | 18 +-
 .../operators/collective/c_broadcast_op.cc | 6 +-
 .../fluid/operators/collective/c_concat_op.cc | 6 +-
 .../operators/collective/c_embedding_op.cc | 8 +-
 .../operators/collective/c_identity_op.cc | 6 +-
 .../fluid/operators/collective/c_reduce_op.h | 6 +-
 .../operators/collective/c_scatter_op.cc | 6 +-
 .../c_softmax_with_cross_entropy_op.cc | 15 +-
 .../fluid/operators/collective/c_split_op.cc | 6 +-
 .../collective/c_sync_calc_stream_op.h | 8 +-
 .../collective/c_sync_comm_stream_op.cc | 5 +-
 .../operators/collective/global_gather_op.cc | 6 +-
 .../operators/collective/global_scatter_op.cc | 6 +-
 .../operators/collective/partial_recv_op.cc | 4 +-
 .../operators/collective/partial_send_op.cc | 6 +-
 .../fluid/operators/collective/recv_v2_op.cc | 4 +-
 .../fluid/operators/collective/send_v2_op.cc | 9 +-
 paddle/fluid/operators/concat_op.cc | 32 +--
 .../fluid/operators/controlflow/bitwise_op.cc | 14 +-
 .../fluid/operators/controlflow/compare_op.cc | 11 +-
 .../operators/controlflow/fetch_v2_op.cc | 35 +--
 .../fluid/operators/controlflow/logical_op.cc | 7 +-
 paddle/fluid/operators/conv_op.cc | 38 ++--
 paddle/fluid/operators/conv_op.h | 14 +-
 paddle/fluid/operators/conv_transpose_op.cc | 25 ++-
 paddle/fluid/operators/conv_transpose_op.h | 10 +-
 paddle/fluid/operators/correlation_op.cc | 8 +-
 paddle/fluid/operators/crf_decoding_op.cc | 4 +-
 paddle/fluid/operators/crop_op.cc | 15 +-
 paddle/fluid/operators/cross_entropy_op.cc | 15 +-
 paddle/fluid/operators/ctc_align_op.cc | 7 +-
 paddle/fluid/operators/cudnn_lstm_op.cc | 15 +-
 paddle/fluid/operators/cum_op.cc | 4 +-
 paddle/fluid/operators/cvm_op.cc | 15 +-
 paddle/fluid/operators/data_norm_op.cc | 8 +-
 paddle/fluid/operators/decode_jpeg_op.cc | 19 +-
 paddle/fluid/operators/deformable_conv_op.cc | 14 +-
 .../fluid/operators/deformable_conv_v1_op.cc | 14 +-
 .../operators/deformable_psroi_pooling_op.cc | 14 +-
 .../fluid/operators/dequantize_abs_max_op.cc | 5 +-
 paddle/fluid/operators/dequantize_log_op.cc | 5 +-
 paddle/fluid/operators/dequantize_op.cc | 9 +-
 paddle/fluid/operators/dequantize_op.h | 4 +-
 .../detection/anchor_generator_op.cc | 7 +-
 .../operators/detection/bipartite_match_op.cc | 4 +-
 .../detection/collect_fpn_proposals_op.cc | 4 +-
 .../detection/density_prior_box_op.cc | 6 +-
 .../detection/distribute_fpn_proposals_op.cc | 4 +-
 .../detection/generate_mask_labels_op.cc | 4 +-
 .../detection/generate_proposal_labels_op.cc | 4 +-
 .../detection/generate_proposals_op.cc | 6 +-
 .../detection/generate_proposals_v2_op.cc | 6 +-
 .../detection/locality_aware_nms_op.cc | 4 +-
 .../operators/detection/matrix_nms_op.cc | 4 +-
 .../detection/mine_hard_examples_op.cc | 4 +-
 .../operators/detection/multiclass_nms_op.cc | 4 +-
 paddle/fluid/operators/detection/nms_op.cc | 6 +-
 .../fluid/operators/detection/prior_box_op.cc | 4 +-
 .../retinanet_detection_output_op.cc | 6 +-
 .../detection/roi_perspective_transform_op.cc | 14 +-
 .../detection/rpn_target_assign_op.cc | 8 +-
 .../detection/sigmoid_focal_loss_op.cc | 14 +-
 .../operators/detection/target_assign_op.cc | 7 +-
 .../fluid/operators/detection/yolo_box_op.cc | 6 +-
 .../operators/detection/yolov3_loss_op.cc | 14 +-
 paddle/fluid/operators/detection_map_op.cc | 4 +-
 paddle/fluid/operators/determinant_op.cc | 8 +-
 paddle/fluid/operators/dgc_clip_by_norm_op.cc | 8 +-
 paddle/fluid/operators/dgc_op.cc | 8 +-
 paddle/fluid/operators/dropout_op.cc | 26 +--
 paddle/fluid/operators/edit_distance_op.cc | 6 +-
 paddle/fluid/operators/eigvalsh_op.cc | 6 +-
 paddle/fluid/operators/einsum_op.cc | 4 +-
 .../elementwise/elementwise_div_op.h | 19 +-
 .../elementwise/elementwise_mul_op.h | 19 +-
 .../operators/elementwise/elementwise_op.h | 103 ++++-----
 paddle/fluid/operators/empty_op.cc | 16 +-
 paddle/fluid/operators/expand_as_op.cc | 8 +-
 paddle/fluid/operators/expand_as_v2_op.cc | 15 +-
 paddle/fluid/operators/expand_op.cc | 39 ++--
 paddle/fluid/operators/expand_v2_op.cc | 32 +--
 paddle/fluid/operators/exponential_op.cc | 6 +-
 paddle/fluid/operators/eye_op.cc | 4 +-
 paddle/fluid/operators/fake_quantize_op.cc | 43 ++--
 paddle/fluid/operators/fc_op.cc | 4 +-
 paddle/fluid/operators/fill_any_like_op.cc | 17 +-
 .../fill_constant_batch_size_like_op.cc | 8 +-
 paddle/fluid/operators/fill_constant_op.cc | 27 +--
 paddle/fluid/operators/fill_diagonal_op.cc | 10 +-
 paddle/fluid/operators/fill_op.cc | 4 +-
 paddle/fluid/operators/fill_zeros_like_op.cc | 4 +-
 paddle/fluid/operators/filter_by_instag_op.cc | 8 +-
 paddle/fluid/operators/flatten_op.cc | 24 +--
 paddle/fluid/operators/fsp_op.cc | 19 +-
 .../operators/fused/fused_attention_op.cc | 8 +-
 ...sed_bias_dropout_residual_layer_norm_op.cc | 8 +-
 .../operators/fused/fused_bn_activation_op.cc | 20 +-
 .../operators/fused/fused_bn_activation_op.h | 4 +-
 .../fused/fused_bn_add_activation_op.cc | 20 +-
 .../fused/fused_bn_add_activation_op.h | 4 +-
 .../fused/fused_elemwise_activation_op.cc | 14 +-
 .../fused_embedding_eltwise_layernorm_op.cc | 4 +-
 .../fused/fused_embedding_fc_lstm_op.cc | 6 +-
 .../fused/fused_embedding_fc_lstm_op.h | 2 +-
 .../fused/fused_embedding_seq_pool_op.cc | 8 +-
 .../operators/fused/fused_feedforward_op.cc | 10 +-
 .../operators/fused/fused_gemm_epilogue_op.cc | 12 +-
 .../fused/fused_multi_transformer_int8_op.cc | 18 +-
 .../fused/fused_multi_transformer_op.cc | 18 +-
 .../operators/fused/fused_seqpool_cvm_op.cc | 16 +-
 .../fused/fusion_conv_inception_op.cc | 7 +-
 .../fluid/operators/fused/fusion_group_op.cc | 6 +-
 paddle/fluid/operators/fused/fusion_gru_op.cc | 4 +-
 paddle/fluid/operators/fused/fusion_gru_op.h | 2 +-
 .../fluid/operators/fused/fusion_lstm_op.cc | 4 +-
 paddle/fluid/operators/fused/fusion_lstm_op.h | 2 +-
 .../fused/fusion_repeated_fc_relu_op.cc | 6 +-
 .../fused/fusion_repeated_fc_relu_op.h | 2 +-
 .../fused/fusion_seqconv_eltadd_relu_op.cc | 6 +-
 .../fused/fusion_seqconv_eltadd_relu_op.h | 2 +-
 .../fused/fusion_seqexpand_concat_fc_op.cc | 6 +-
 .../fused/fusion_seqexpand_concat_fc_op.h | 2 +-
 .../fused/fusion_seqpool_concat_op.cc | 6 +-
 .../fused/fusion_seqpool_concat_op.h | 2 +-
 .../fused/fusion_seqpool_cvm_concat_op.cc | 6 +-
 .../fused/fusion_seqpool_cvm_concat_op.h | 2 +-
 .../fused/fusion_squared_mat_sub_op.cc | 6 +-
 .../fused/fusion_squared_mat_sub_op.h | 2 +-
 paddle/fluid/operators/fused/multi_gru_op.cc | 11 +-
 paddle/fluid/operators/fused/multi_gru_op.h | 2 +-
 .../operators/fused/resnet_basic_block_op.cc | 19 +-
 .../fluid/operators/fused/resnet_unit_op.cc | 19 +-
 paddle/fluid/operators/gather_op.cc | 39 ++--
 .../gaussian_random_batch_size_like_op.cc | 4 +-
 paddle/fluid/operators/gaussian_random_op.cc | 16 +-
 .../generator/templates/operator_utils.c.j2 | 4 +-
 .../get_tensor_from_selected_rows_op.cc | 7 +-
 .../fluid/operators/graph_khop_sampler_op.cc | 7 +-
 paddle/fluid/operators/graph_reindex_op.cc | 7 +-
 .../operators/graph_sample_neighbors_op.cc | 7 +-
 paddle/fluid/operators/graph_send_recv_op.cc | 15 +-
 .../fluid/operators/graph_send_ue_recv_op.cc | 15 +-
 paddle/fluid/operators/group_norm_op.cc | 6 +-
 paddle/fluid/operators/gru_op.cc | 8 +-
 paddle/fluid/operators/gru_unit_op.cc | 8 +-
 .../operators/hierarchical_sigmoid_op.cc | 12 +-
 paddle/fluid/operators/identity_loss_op.cc | 11 +-
 paddle/fluid/operators/imag_op.cc | 4 +-
 paddle/fluid/operators/increment_op.cc | 7 +-
 paddle/fluid/operators/index_add_op.cc | 14 +-
 paddle/fluid/operators/inplace_abn_op.cc | 15 +-
 paddle/fluid/operators/instance_norm_op.cc | 16 +-
 paddle/fluid/operators/instance_norm_op.h | 6 +-
 paddle/fluid/operators/interpolate_op.cc | 41 ++--
 paddle/fluid/operators/interpolate_v2_op.cc | 41 ++--
 paddle/fluid/operators/isfinite_op.cc | 6 +-
 paddle/fluid/operators/kldiv_loss_op.cc | 14 +-
 paddle/fluid/operators/kron_op.cc | 38 ++--
 paddle/fluid/operators/layer_norm_op.cc | 16 +-
 .../fluid/operators/limit_by_capacity_op.cc | 4 +-
 paddle/fluid/operators/linear_chain_crf_op.cc | 13 +-
 paddle/fluid/operators/linspace_op.cc | 16 +-
 paddle/fluid/operators/load_combine_op.cc | 6 +-
 paddle/fluid/operators/load_combine_op.h | 11 +-
 paddle/fluid/operators/load_op.cc | 6 +-
 paddle/fluid/operators/load_op_npu.cc | 10 +-
 paddle/fluid/operators/lod_reset_op.cc | 25 ++-
 paddle/fluid/operators/log_softmax_op.cc | 12 +-
 paddle/fluid/operators/logspace_op.cc | 4 +-
 .../operators/lookup_table_dequant_op.cc | 4 +-
 paddle/fluid/operators/lookup_table_op.cc | 8 +-
 paddle/fluid/operators/lookup_table_v2_op.cc | 8 +-
 paddle/fluid/operators/lrn_op.cc | 34 ++-
 paddle/fluid/operators/lstm_op.cc | 14 +-
 paddle/fluid/operators/lstmp_op.cc | 13 +-
 paddle/fluid/operators/lstsq_op.cc | 4 +-
 paddle/fluid/operators/lu_op.cc | 10 +-
 .../operators/margin_cross_entropy_op.cc | 14 +-
 paddle/fluid/operators/marker_op.cc | 5 +-
 paddle/fluid/operators/matmul_op.cc | 31 ++-
 paddle/fluid/operators/matmul_v2_op.cc | 46 ++--
 paddle/fluid/operators/matrix_rank_op.cc | 6 +-
 paddle/fluid/operators/mean_iou_op.cc | 4 +-
 paddle/fluid/operators/mean_op.cc | 4 +-
 paddle/fluid/operators/memcpy_d2h_op.cc | 17 +-
 paddle/fluid/operators/memcpy_h2d_op.cc | 17 +-
 paddle/fluid/operators/memcpy_op.cc | 17 +-
 paddle/fluid/operators/meshgrid_op.cc | 12 +-
 paddle/fluid/operators/metrics/accuracy_op.cc | 6 +-
 paddle/fluid/operators/metrics/auc_op.cc | 6 +-
 .../operators/metrics/precision_recall_op.cc | 6 +-
 paddle/fluid/operators/moe_op.cc | 4 +-
 paddle/fluid/operators/mul_op.cc | 10 +-
 paddle/fluid/operators/multiplex_op.cc | 15 +-
 paddle/fluid/operators/nanmedian_op.cc | 14 +-
 paddle/fluid/operators/nce_op.cc | 14 +-
 paddle/fluid/operators/nop_op.cc | 5 +-
 paddle/fluid/operators/number_count_op.cc | 4 +-
 paddle/fluid/operators/one_hot_op.cc | 19 +-
 paddle/fluid/operators/one_hot_v2_op.cc | 19 +-
 .../fluid/operators/optimizers/adadelta_op.cc | 6 +-
 .../fluid/operators/optimizers/adagrad_op.cc | 6 +-
 paddle/fluid/operators/optimizers/adam_op.h | 16 +-
 .../fluid/operators/optimizers/adamax_op.cc | 6 +-
 .../optimizers/decayed_adagrad_op.cc | 6 +-
 .../operators/optimizers/dgc_momentum_op.cc | 8 +-
 .../distributed_fused_lamb_init_op.cc | 4 +-
 .../optimizers/distributed_fused_lamb_op.cc | 12 +-
 paddle/fluid/operators/optimizers/dpsgd_op.cc | 6 +-
 paddle/fluid/operators/optimizers/ftrl_op.cc | 4 +-
 paddle/fluid/operators/optimizers/lamb_op.cc | 16 +-
 .../operators/optimizers/lars_momentum_op.cc | 4 +-
 .../operators/optimizers/merged_adam_op.cc | 16 +-
 .../optimizers/merged_momentum_op.cc | 4 +-
 .../fluid/operators/optimizers/momentum_op.h | 4 +-
 .../pow2_decay_with_linear_warmup_op.cc | 4 +-
 .../optimizers/proximal_adagrad_op.cc | 6 +-
 .../operators/optimizers/proximal_gd_op.cc | 6 +-
 paddle/fluid/operators/optimizers/sgd_op.cc | 17 +-
 .../operators/optimizers/sparse_momentum_op.h | 4 +-
 paddle/fluid/operators/pad2d_op.cc | 35 ++-
 paddle/fluid/operators/pad3d_op.cc | 35 ++-
 .../fluid/operators/pad_constant_like_op.cc | 14 +-
 paddle/fluid/operators/pad_op.cc | 14 +-
 paddle/fluid/operators/partial_concat_op.cc | 12 +-
 paddle/fluid/operators/partial_sum_op.cc | 12 +-
 paddle/fluid/operators/pool_op.cc | 37 ++--
 paddle/fluid/operators/pool_op.h | 12 +-
 paddle/fluid/operators/pool_with_index_op.cc | 15 +-
 .../operators/positive_negative_pair_op.cc | 7 +-
 paddle/fluid/operators/prelu_op.cc | 8 +-
 paddle/fluid/operators/prroi_pool_op.cc | 14 +-
 .../operators/prune_gate_by_capacity_op.cc | 4 +-
 .../pscore/distributed_lookup_table_op.cc | 4 +-
 .../pscore/distributed_push_sparse_op.cc | 4 +-
 .../operators/pscore/send_and_recv_op.cc | 4 +-
 paddle/fluid/operators/psroi_pool_op.cc | 14 +-
 .../operators/pull_box_extended_sparse_op.cc | 13 +-
 paddle/fluid/operators/pull_box_sparse_op.cc | 13 +-
 .../fluid/operators/pull_gpups_sparse_op.cc | 13 +-
 paddle/fluid/operators/pull_sparse_op.cc | 13 +-
 paddle/fluid/operators/pull_sparse_v2_op.cc | 13 +-
 paddle/fluid/operators/push_dense_op.cc | 5 +-
 paddle/fluid/operators/pyramid_hash_op.cc | 12 +-
 paddle/fluid/operators/quantize_linear_op.cc | 6 +-
 paddle/fluid/operators/quantize_op.cc | 10 +-
 paddle/fluid/operators/quantize_op.h | 4 +-
 paddle/fluid/operators/randint_op.cc | 4 +-
 paddle/fluid/operators/random_crop_op.cc | 7 +-
 paddle/fluid/operators/random_routing_op.cc | 4 +-
 paddle/fluid/operators/randperm_op.cc | 4 +-
 paddle/fluid/operators/range_op.cc | 12 +-
 paddle/fluid/operators/rank_attention_op.cc | 15 +-
 paddle/fluid/operators/read_file_op.cc | 6 +-
 paddle/fluid/operators/real_op.cc | 4 +-
 paddle/fluid/operators/reduce_ops/reduce_op.h | 27 +--
 .../operators/reduce_ops/reduce_sum_op.cc | 11 +-
 .../operators/reduce_ops/reduce_sum_op.h | 10 +-
 .../fluid/operators/repeat_interleave_op.cc | 12 +-
 paddle/fluid/operators/requantize_op.cc | 10 +-
 paddle/fluid/operators/requantize_op.h | 4 +-
 paddle/fluid/operators/reshape_op.cc | 56 ++---
 paddle/fluid/operators/reverse_op.cc | 4 +-
 paddle/fluid/operators/rnn_op.cc | 15 +-
 paddle/fluid/operators/roi_align_op.cc | 14 +-
 paddle/fluid/operators/roi_pool_op.cc | 14 +-
 paddle/fluid/operators/rrelu_op.cc | 6 +-
 paddle/fluid/operators/run_program_op.cc | 26 +--
 paddle/fluid/operators/sample_logits_op.cc | 12 +-
 paddle/fluid/operators/save_combine_op.cc | 14 +-
 paddle/fluid/operators/save_combine_op.h | 10 +-
 paddle/fluid/operators/save_op.cc | 4 +-
 paddle/fluid/operators/save_op.h | 10 +-
 paddle/fluid/operators/scale_op.cc | 4 +-
 paddle/fluid/operators/seed_op.cc | 5 +-
 paddle/fluid/operators/segment_pool_op.cc | 15 +-
 .../sequence_ops/sequence_concat_op.cc | 8 +-
 .../sequence_ops/sequence_expand_as_op.cc | 14 +-
 .../sequence_ops/sequence_expand_op.cc | 14 +-
 .../sequence_ops/sequence_mask_op.cc | 19 +-
 .../operators/sequence_ops/sequence_pad_op.cc | 8 +-
 .../sequence_ops/sequence_pool_op.cc | 8 +-
 .../sequence_ops/sequence_scatter_op.cc | 15 +-
 .../sequence_ops/sequence_slice_op.cc | 15 +-
 .../sequence_ops/sequence_softmax_op.cc | 10 +-
 .../sequence_topk_avg_pooling_op.cc | 4 +-
 .../sequence_ops/sequence_unpad_op.cc | 8 +-
 paddle/fluid/operators/set_value_op.cc | 38 ++--
 paddle/fluid/operators/shape_op.cc | 14 +-
 paddle/fluid/operators/shuffle_batch_op.cc | 16 +-
 paddle/fluid/operators/shuffle_channel_op.cc | 12 +-
 paddle/fluid/operators/similarity_focus_op.cc | 7 +-
 paddle/fluid/operators/size_op.cc | 12 +-
 paddle/fluid/operators/slice_op.cc | 71 ++++---
 paddle/fluid/operators/softmax_op.cc | 10 +-
 .../softmax_with_cross_entropy_op.cc | 15 +-
 paddle/fluid/operators/space_to_depth_op.cc | 8 +-
 paddle/fluid/operators/sparse_attention_op.cc | 12 +-
 paddle/fluid/operators/spectral_norm_op.cc | 8 +-
 paddle/fluid/operators/split_op.cc | 26 +--
 .../fluid/operators/squared_l2_distance_op.cc | 4 +-
 paddle/fluid/operators/squeeze_op.cc | 8 +-
 paddle/fluid/operators/stack_op.cc | 4 +-
 paddle/fluid/operators/stft_op.cc | 8 +-
 paddle/fluid/operators/strided_slice_op.cc | 56 ++---
 .../operators/string/faster_tokenizer_op.cc | 16 +-
 paddle/fluid/operators/sum_op.cc | 17 +-
 paddle/fluid/operators/tdm_child_op.cc | 4 +-
 paddle/fluid/operators/tdm_sampler_op.cc | 4 +-
 .../teacher_student_sigmoid_loss_op.cc | 14 +-
 paddle/fluid/operators/temporal_shift_op.cc | 14 +-
 paddle/fluid/operators/tile_op.cc | 39 ++--
 paddle/fluid/operators/top_k_op.cc | 15 +-
 paddle/fluid/operators/transfer_layout_op.cc | 12 +-
 paddle/fluid/operators/transpose_op.cc | 20 +-
 paddle/fluid/operators/tree_conv_op.cc | 12 +-
 paddle/fluid/operators/triangular_solve_op.cc | 6 +-
 paddle/fluid/operators/tril_indices_op.cc | 4 +-
 paddle/fluid/operators/triu_indices_op.cc | 4 +-
 .../operators/truncated_gaussian_random_op.cc | 10 +-
 .../uniform_random_batch_size_like_op.cc | 4 +-
 .../operators/uniform_random_inplace_op.cc | 6 +-
 paddle/fluid/operators/uniform_random_op.cc | 16 +-
 .../fluid/operators/unique_consecutive_op.cc | 6 +-
 paddle/fluid/operators/unique_op.cc | 11 +-
 .../fluid/operators/unique_with_counts_op.cc | 7 +-
 paddle/fluid/operators/unpool_op.cc | 28 ++-
 paddle/fluid/operators/unsqueeze_op.cc | 29 +--
 paddle/fluid/operators/warpctc_op.cc | 19 +-
 paddle/fluid/operators/where_index_op.cc | 4 +-
 paddle/phi/core/kernel_factory.h | 18 ++
 paddle/phi/core/utils/data_type.h | 5 +
 .../kernels/impl/searchsorted_kernel_impl.h | 6 +-
 405 files changed, 2610 insertions(+), 2699 deletions(-)

diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index c34e727486..e0a6fbd37d 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -422,9 +422,9 @@ class CustomOperator : public OperatorWithKernel {
    * The RAW type is used here as the data type, indicating that
    * it can only be determined at runtime.
*/ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace()); + return phi::KernelKey(ctx.GetPlace()); } /** @@ -432,13 +432,13 @@ class CustomOperator : public OperatorWithKernel { * Because the kernel data type is RAW, we should skip the cast for * data type difference when PrepareData. */ - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const OpKernelType& expected_kernel_type) const override { - return OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index 777d7a6770..9723fde1cc 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ b/paddle/fluid/framework/data_device_transform_test.cu @@ -47,15 +47,17 @@ class TestOpWithKernel : public OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override {} - OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext& ctx) const override { if (Attr("use_gpu")) { VLOG(3) << "force use gpu kernel"; - return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0)); + return phi::KernelKey(phi::Backend::GPU, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32); } else { VLOG(3) << "use default kernel"; - return OpKernelType(proto::VarType::FP32, - ctx.Input("input")->place()); + return phi::KernelKey(proto::VarType::FP32, + ctx.Input("input")->place()); } } }; diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 3c0e8d4f0e..73ce635f57 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -50,13 +50,14 @@ void CastDataLayout::apply() { } } -void TransDataLayout(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_type, +void TransDataLayout(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, const phi::DenseTensor& in, - phi::DenseTensor* out) { + phi::DenseTensor* out, + const phi::Place& place) { PADDLE_ENFORCE( - platform::places_are_same_class(kernel_type_for_var.place_, - expected_kernel_type.place_), + backends_are_same_class(kernel_type_for_var.backend(), + expected_kernel_type.backend()), platform::errors::PreconditionNotMet( "TransDataLayout only support DataLayout transform on same place.")); @@ -72,21 +73,20 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, auto src_dim = in.dims(); std::vector dst_dim; - auto axis = GetAxis(kernel_type_for_var.data_layout_, - expected_kernel_type.data_layout_); + auto axis = + GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout()); dst_dim.resize(axis.size()); for (size_t i = 0; i < axis.size(); i++) { dst_dim[i] = src_dim[axis[i]]; } out->Resize(phi::make_ddim(dst_dim)); - out->mutable_data(expected_kernel_type.place_, in.dtype()); + out->mutable_data(place, in.dtype()); - framework::VisitDataType( - framework::TransToProtoVarType(in.dtype()), - CastDataLayout(pool.Get(expected_kernel_type.place_), axis, 
in, out)); + framework::VisitDataType(framework::TransToProtoVarType(in.dtype()), + CastDataLayout(pool.Get(place), axis, in, out)); - out->set_layout(expected_kernel_type.data_layout_); + out->set_layout(expected_kernel_type.layout()); } } // namespace framework diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index bad13e7e90..3bc55b8ad8 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -54,10 +54,11 @@ struct CastDataLayout { std::vector GetAxis(const DataLayout& from, const DataLayout& to); -void TransDataLayout(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_type, +void TransDataLayout(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, const phi::DenseTensor& in, - phi::DenseTensor* out); + phi::DenseTensor* out, + const phi::Place& place); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/paddle/fluid/framework/data_layout_transform_test.cc index 9b314fbb2c..880fa5b057 100644 --- a/paddle/fluid/framework/data_layout_transform_test.cc +++ b/paddle/fluid/framework/data_layout_transform_test.cc @@ -24,22 +24,16 @@ TEST(DataTransform, DataLayoutFunction) { in.set_layout(phi::DataLayout::kNHWC); auto kernel_nhwc = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, - place, - phi::DataLayout::kNHWC, - paddle::framework::LibraryType::kPlain); + phi::KernelKey(place, phi::DataLayout::kNHWC, phi::DataType::FLOAT32); auto kernel_ncwh = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, - place, - phi::DataLayout::kNCHW, - paddle::framework::LibraryType::kPlain); + phi::KernelKey(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32); - paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); + paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out, place); EXPECT_TRUE(out.layout() == phi::DataLayout::kNCHW); EXPECT_TRUE(out.dims() == phi::make_ddim({2, 2, 3, 1})); - TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out); + paddle::framework::TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out, place); EXPECT_TRUE(in.layout() == phi::DataLayout::kNHWC); EXPECT_TRUE(in.dims() == phi::make_ddim({2, 3, 1, 2})); diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index fff4f6acb3..38e1ce1c31 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -36,16 +36,17 @@ static void PassTensorData(phi::DenseTensor *from, phi::DenseTensor *to) { *from = phi::DenseTensor(); } -void TransformData(const OpKernelType &expected_kernel_type, - const OpKernelType &kernel_type_for_var, +void TransformData(const phi::KernelKey &expected_kernel_type, + const phi::KernelKey &kernel_type_for_var, const phi::DenseTensor &input_tensor, - phi::DenseTensor *output_tensor) { + phi::DenseTensor *output_tensor, + const phi::Place &place) { bool transformed = false; phi::DenseTensor in; in.ShareDataWith(input_tensor); phi::DenseTensor out; - const DataLayout lin = kernel_type_for_var.data_layout_; - const DataLayout lout = expected_kernel_type.data_layout_; + const DataLayout lin = kernel_type_for_var.layout(); + const DataLayout lout = expected_kernel_type.layout(); // do layout transform if (NeedTransformLayout(lout, lin)) { #ifdef PADDLE_WITH_MKLDNN @@ -79,43 +80,42 @@ void TransformData(const 
OpKernelType &expected_kernel_type, } else { // Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel // Do transform via ONEDNN lib - PADDLE_ENFORCE( - kernel_type_for_var.data_layout_ == DataLayout::ONEDNN && - expected_kernel_type.data_layout_ != DataLayout::ONEDNN, - platform::errors::InvalidArgument( - "TransDataLayoutFromOneDNN only supports " - "transform from ONEDNN to non-ONEDNN")); + PADDLE_ENFORCE(lin == DataLayout::ONEDNN && lout != DataLayout::ONEDNN, + platform::errors::InvalidArgument( + "TransDataLayoutFromOneDNN only supports " + "transform from ONEDNN to non-ONEDNN")); phi::funcs::TransDataLayoutFromOneDNN( - kernel_type_for_var.data_layout_, + lin, phi::OneDNNContext::tls().get_cur_paddle_data_layout(), in, &out, - expected_kernel_type.place_); + place); } } else { // Case3 - transfrom between Non-ONEDNN OPKernels - TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); + TransDataLayout( + kernel_type_for_var, expected_kernel_type, in, &out, place); } #else // Case3 - transfrom between Non-ONEDNN OPKernels - TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); + TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out, place); #endif transformed = true; PassTensorData(&out, &in); } // do data type transform - if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) { + if (NeedTransformDataType(expected_kernel_type, kernel_type_for_var)) { TransDataType(kernel_type_for_var, expected_kernel_type, in, &out); transformed = true; PassTensorData(&out, &in); } // do device transform - if (!platform::is_same_place(kernel_type_for_var.place_, - expected_kernel_type.place_)) { - TransDataDevice(in, expected_kernel_type.place_, &out); + if (kernel_type_for_var.backend() != phi::Backend::ALL_BACKEND && + !platform::is_same_place(in.place(), place)) { + TransDataDevice(in, place, &out); transformed = true; PassTensorData(&out, &in); } diff --git a/paddle/fluid/framework/data_transform.h b/paddle/fluid/framework/data_transform.h index 2fcea7803e..27bc0086c2 100644 --- a/paddle/fluid/framework/data_transform.h +++ b/paddle/fluid/framework/data_transform.h @@ -33,10 +33,11 @@ namespace framework { class OpKernelType; class Variable; -void TransformData(const OpKernelType &expected_kernel_type, - const OpKernelType &kernel_type_for_var, +void TransformData(const phi::KernelKey &expected_kernel_type, + const phi::KernelKey &kernel_type_for_var, const phi::DenseTensor &input_tensor, - phi::DenseTensor *out); + phi::DenseTensor *out, + const phi::Place &place); /** * Set OutVar from InVar, except the tensor is shared with `tensor` diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index fd1c06fc64..a05f2858c0 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace framework { @@ -226,6 +227,11 @@ extern inline bool IsComplexType(const proto::VarType::Type& type) { type == proto::VarType::COMPLEX128); } +extern inline bool IsComplexType(const phi::DataType& type) { + return (type == phi::DataType::COMPLEX64 || + type == phi::DataType::COMPLEX128); +} + extern proto::VarType::Type PromoteTypesIfComplexExists( const proto::VarType::Type type_a, const proto::VarType::Type type_b); diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 0768d2d82f..0f2e244af0 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -129,19 +129,18 @@ struct CastDataType { } }; -void TransDataType(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_type, +void TransDataType(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, const phi::DenseTensor& in, phi::DenseTensor* out) { - PADDLE_ENFORCE_EQ( - framework::TransToProtoVarType(in.dtype()), - kernel_type_for_var.data_type_, - platform::errors::InvalidArgument( - "The src dtype(%s) of input tensor and kernel_type(%s) " - "are not conststent.", - DataTypeToString(framework::TransToProtoVarType(in.dtype())), - DataTypeToString(kernel_type_for_var.data_type_))); - auto dst_type = expected_kernel_type.data_type_; + PADDLE_ENFORCE_EQ(in.dtype(), + kernel_type_for_var.dtype(), + platform::errors::InvalidArgument( + "The src dtype(%s) of input tensor and kernel_type(%s) " + "are not conststent.", + DataTypeToString(in.dtype()), + DataTypeToString(kernel_type_for_var.dtype()))); + auto dst_type = framework::TransToProtoVarType(expected_kernel_type.dtype()); TransDataType(in, dst_type, out); } diff --git a/paddle/fluid/framework/data_type_transform.h b/paddle/fluid/framework/data_type_transform.h index 619e15b604..2ec193b675 100644 --- a/paddle/fluid/framework/data_type_transform.h +++ b/paddle/fluid/framework/data_type_transform.h @@ -28,8 +28,8 @@ class OpKernelType; using KernelTypePair = std::pair; -void TransDataType(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_type, +void TransDataType(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, const phi::DenseTensor& in, phi::DenseTensor* out); void TransDataType(const phi::DenseTensor& in, diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc index e57d9d6d26..44ebdc96e6 100644 --- a/paddle/fluid/framework/data_type_transform_test.cc +++ b/paddle/fluid/framework/data_type_transform_test.cc @@ -19,47 +19,26 @@ limitations under the License. 
*/ TEST(DataTypeTransform, CPUTransform) { auto place = paddle::platform::CPUPlace(); - auto kernel_fp16 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_bf16 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::BF16, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_fp32 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_fp64 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); + auto kernel_fp16 = phi::KernelKey( + place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16); + + auto kernel_bf16 = phi::KernelKey( + place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BFLOAT16); + + auto kernel_fp32 = phi::KernelKey( + place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32); + + auto kernel_fp64 = phi::KernelKey( + place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64); auto kernel_int32 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32); auto kernel_int64 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64); auto kernel_bool = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL, - place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL); // data type transform from float32 { diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu index 6e047bbbf1..f9394bea7f 100644 --- a/paddle/fluid/framework/data_type_transform_test.cu +++ b/paddle/fluid/framework/data_type_transform_test.cu @@ -24,41 +24,24 @@ TEST(DataTypeTransform, GPUTransform) { .GetAllocator(gpu_place, context.stream()) .get()); context.PartialInitWithAllocator(); - auto kernel_fp16 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_fp32 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_fp64 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_int32 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_int64 = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); - - auto kernel_bool = - paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL, - gpu_place, - phi::DataLayout::kAnyLayout, - paddle::framework::LibraryType::kPlain); + + auto 
kernel_fp16 = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16); + + auto kernel_fp32 = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32); + + auto kernel_fp64 = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64); + + auto kernel_int32 = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32); + + auto kernel_int64 = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64); + + auto kernel_bool = phi::KernelKey( + gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL); // data type transform from float32 { diff --git a/paddle/fluid/framework/details/build_strategy_test.cc b/paddle/fluid/framework/details/build_strategy_test.cc index c39388fa5b..7ec7d93ee6 100644 --- a/paddle/fluid/framework/details/build_strategy_test.cc +++ b/paddle/fluid/framework/details/build_strategy_test.cc @@ -50,10 +50,10 @@ class SumOpWithKernel : public OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override {} - OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext &ctx) const override { - return OpKernelType(proto::VarType::FP32, - ctx.Input("X")->place()); + return phi::KernelKey(proto::VarType::FP32, + ctx.Input("X")->place()); } }; diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/paddle/fluid/framework/infershape_utils_test.cc index 6aef5b7a89..43fbb7d550 100644 --- a/paddle/fluid/framework/infershape_utils_test.cc +++ b/paddle/fluid/framework/infershape_utils_test.cc @@ -84,9 +84,9 @@ class InferShapeUtilsTestOp : public OperatorWithKernel { public: using OperatorWithKernel::OperatorWithKernel; - OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext& ctx) const override { - return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); + return phi::KernelKey(proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index d41a1dca44..7e97d82c78 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -27,22 +27,26 @@ namespace paddle { namespace framework { namespace interpreter { -bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_key, - const std::string& var_name, - std::string* new_var_name, - std::vector* op_func_nodes, - bool use_local_scope, - bool is_fetch_v2, - bool skip_run) { +bool DataTranferHelper::apply( + const phi::KernelKey& kernel_type_for_var, + const framework::OpKernelType& expected_kernel_key, + const phi::DenseTensor* tensor, + const std::string& var_name, + std::string* new_var_name, + std::vector* op_func_nodes, + bool use_local_scope, + bool is_fetch_v2, + bool skip_run) { bool is_transferred = false; auto* src_var_name = &var_name; // 1. 
layout transform - if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) { + if (need_layout_transform( + kernel_type_for_var, + TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) { auto op = TransferLayout(*src_var_name, new_var_name, - kernel_type_for_var.data_layout_, + kernel_type_for_var.layout(), expected_kernel_key.data_layout_, var_scope_, scope_, @@ -56,13 +60,16 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, is_transferred = true; } // 2. dype transform - if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) { - auto op = TransferDtype(*src_var_name, - new_var_name, - kernel_type_for_var.data_type_, - expected_kernel_key.data_type_, - var_scope_, - scope_); + if (need_dtype_transform( + kernel_type_for_var, + TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) { + auto op = TransferDtype( + *src_var_name, + new_var_name, + framework::TransToProtoVarType(kernel_type_for_var.dtype()), + expected_kernel_key.data_type_, + var_scope_, + scope_); if (op) { RunAndConstructOpFuncNode( op, *src_var_name, *new_var_name, op_func_nodes, skip_run); @@ -72,8 +79,9 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, is_transferred = true; } // 3. device transform - if (need_device_transform(kernel_type_for_var, expected_kernel_key)) { - auto src_place = kernel_type_for_var.place_; + if (need_device_transform( + kernel_type_for_var, tensor, expected_kernel_key.place_)) { + auto src_place = tensor->place(); auto dst_place = expected_kernel_key.place_; auto op = TransferDevice( @@ -526,11 +534,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, auto kernel_type_for_var = static_cast(op_base) ->GetKernelTypeForVar( - var_name_item.first, *tensor_in, expected_kernel_key); + var_name_item.first, + *tensor_in, + framework::TransOpKernelTypeToPhiKernelKey( + expected_kernel_key)); // apply data transform is_transferred = data_transfer_helper.apply(kernel_type_for_var, expected_kernel_key, + tensor_in, var_name, &new_var_name, new_op_func_nodes, diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h index e74fe8066e..604f120380 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h @@ -34,8 +34,9 @@ class DataTranferHelper { Scope* local_scope) : place_(place), var_scope_(var_scope), scope_(local_scope) {} - bool apply(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_key, + bool apply(const phi::KernelKey& kernel_type_for_var, + const framework::OpKernelType& expected_kernel_key, + const phi::DenseTensor* tensor, const std::string& var_name, std::string* new_var_name, std::vector* new_op_func_nodes, @@ -79,28 +80,28 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, framework::Scope* local_scope, bool skip_run = false); -inline bool need_device_transform(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_key) { - auto& src_place = kernel_type_for_var.place_; - auto& dst_place = expected_kernel_key.place_; - if (platform::is_same_place(src_place, dst_place) || - (platform::is_cuda_pinned_place(src_place) && - platform::is_cpu_place(dst_place))) { +inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var, + const phi::DenseTensor* tensor, + const phi::Place& expected_place) { + if (kernel_type_for_var.backend() == 
phi::Backend::ALL_BACKEND || + platform::is_same_place(tensor->place(), expected_place) || + (platform::is_cuda_pinned_place(tensor->place()) && + platform::is_cpu_place(expected_place))) { return false; } return true; } -inline bool need_dtype_transform(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_key) { +inline bool need_dtype_transform(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_key) { return framework::NeedTransformDataType(kernel_type_for_var, expected_kernel_key); } -inline bool need_layout_transform(const OpKernelType& kernel_type_for_var, - const OpKernelType& expected_kernel_key) { - return framework::NeedTransformLayout(kernel_type_for_var.data_layout_, - expected_kernel_key.data_layout_); +inline bool need_layout_transform(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_key) { + return framework::NeedTransformLayout(kernel_type_for_var.layout(), + expected_kernel_key.layout()); } std::shared_ptr TransferLayout(const std::string& var_name, diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index f5b430e829..f98acfdccf 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -730,8 +730,8 @@ bool BuildOpFuncList(const platform::Place& place, auto* dev_ctx = pool.Get(place); auto exec_ctx = ExecutionContext( *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); - auto expected_kernel_key = - op_with_kernel->GetExpectedKernelType(exec_ctx); + auto expected_kernel_key = framework::TransPhiKernelKeyToOpKernelType( + op_with_kernel->GetExpectedKernelType(exec_ctx)); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (op_with_kernel->CanCUDNNBeUsed(exec_ctx, expected_kernel_key.data_type_)) { @@ -741,6 +741,10 @@ bool BuildOpFuncList(const platform::Place& place, VLOG(4) << "expected_kernel_key : " << expected_kernel_key; // change device by the device_guard() ApplyDeviceGuard(op, place, &expected_kernel_key); + if (platform::places_are_same_class(exec_ctx.GetPlace(), + expected_kernel_key.place_)) { + expected_kernel_key.place_ = exec_ctx.GetPlace(); + } // step 2. select op kernel auto run_phi_kernel = false; diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h index a609313e84..eb969a94d8 100644 --- a/paddle/fluid/framework/op_kernel_type.h +++ b/paddle/fluid/framework/op_kernel_type.h @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace framework { @@ -108,15 +109,32 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) { return ret; } -inline bool NeedTransformDataType(const OpKernelType& l, - const OpKernelType& r) { - return (l.data_type_ != r.data_type_); +inline bool NeedTransformDataType(const phi::KernelKey& l, + const phi::KernelKey& r) { + return l.dtype() != phi::DataType::ALL_DTYPE && + r.dtype() != phi::DataType::ALL_DTYPE && l.dtype() != r.dtype(); } -inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) { - return (!platform::places_are_same_class(l.place_, r.place_)) || - (l.data_type_ != r.data_type_) || - NeedTransformLayout(l.data_layout_, r.data_layout_); +inline bool backends_are_same_class(const phi::Backend& l, + const phi::Backend& r) { + if (l == phi::Backend::ALL_BACKEND || r == phi::Backend::ALL_BACKEND) { + return true; + } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + size_t num_backends = static_cast(phi::Backend::NUM_BACKENDS); + if (static_cast(l) > num_backends && + static_cast(r) > num_backends) { + return phi::TransToPhiPlace(l).GetDeviceType() == + phi::TransToPhiPlace(r).GetDeviceType(); + } +#endif + return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r); +} + +inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) { + return !backends_are_same_class(l.backend(), r.backend()) || + NeedTransformDataType(l, r) || + NeedTransformLayout(l.layout(), r.layout()); } } // namespace framework diff --git a/paddle/fluid/framework/op_registry_test.cc b/paddle/fluid/framework/op_registry_test.cc index 9ef577f628..5a40a4df00 100644 --- a/paddle/fluid/framework/op_registry_test.cc +++ b/paddle/fluid/framework/op_registry_test.cc @@ -214,9 +214,10 @@ class OpWithKernelTest : public OperatorWithKernel { protected: void InferShape(InferShapeContext* ctx) const override {} - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(proto::VarType::FP32, ctx.device_context()); + return phi::KernelKey(proto::VarType::FP32, + ctx.device_context().GetPlace()); } }; @@ -275,12 +276,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel { protected: void InferShape(InferShapeContext* ctx) const override {} - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(proto::VarType::FP32, - platform::CUDAPlace(0), - DataLayout::kAnyLayout, - framework::LibraryType::kCUDNN); + return phi::KernelKey(phi::Backend::GPUDNN, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32); } }; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index dcb822afb4..5e8d0b1b87 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1380,8 +1380,7 @@ bool OperatorWithKernel::SupportXPU() const { #endif } -bool OperatorWithKernel::SupportsMKLDNN( - const proto::VarType::Type data_type) const { +bool OperatorWithKernel::SupportsMKLDNN(const phi::DataType data_type) const { auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( phi::TransToPhiKernelName(type_)); auto has_phi_kernel = @@ -1389,8 +1388,7 @@ bool OperatorWithKernel::SupportsMKLDNN( 
phi_kernels.end(), [data_type](phi::KernelKeyMap::const_reference kern_pair) { return kern_pair.first.backend() == phi::Backend::ONEDNN && - kern_pair.first.dtype() == - framework::TransToPhiDataType(data_type); + kern_pair.first.dtype() == data_type; }); if (has_phi_kernel) { return true; @@ -1406,25 +1404,22 @@ bool OperatorWithKernel::SupportsMKLDNN( [data_type](OpKernelMap::const_reference kern_pair) { return platform::is_cpu_place(kern_pair.first.place_) && kern_pair.first.library_type_ == LibraryType::kMKLDNN && - kern_pair.first.data_type_ == data_type; + kern_pair.first.data_type_ == TransToProtoVarType(data_type); }); } } } -bool OperatorWithKernel::SupportsCUDNN( - const proto::VarType::Type data_type) const { +bool OperatorWithKernel::SupportsCUDNN(const phi::DataType data_type) const { auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( phi::TransToPhiKernelName(type_)); - paddle::experimental::DataType phi_data_type = - framework::TransToPhiDataType(data_type); - auto has_phi_kernel = std::any_of( - phi_kernels.begin(), - phi_kernels.end(), - [phi_data_type](phi::KernelKeyMap::const_reference kern_pair) { - return kern_pair.first.backend() == phi::Backend::GPUDNN && - kern_pair.first.dtype() == phi_data_type; - }); + auto has_phi_kernel = + std::any_of(phi_kernels.begin(), + phi_kernels.end(), + [data_type](phi::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == phi::Backend::GPUDNN && + kern_pair.first.dtype() == data_type; + }); if (has_phi_kernel) { return true; } else { @@ -1433,13 +1428,15 @@ bool OperatorWithKernel::SupportsCUDNN( return false; } else { auto& op_kernels = op_kernel_iter->second; + proto::VarType::Type fluid_data_type = + framework::TransToProtoVarType(data_type); return std::any_of( op_kernels.begin(), op_kernels.end(), - [data_type](OpKernelMap::const_reference kern_pair) { + [fluid_data_type](OpKernelMap::const_reference kern_pair) { return platform::is_gpu_place(kern_pair.first.place_) && kern_pair.first.library_type_ == LibraryType::kCUDNN && - kern_pair.first.data_type_ == data_type; + kern_pair.first.data_type_ == fluid_data_type; }); } } @@ -1509,14 +1506,19 @@ bool OperatorWithKernel::SupportsKernelType( } bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, - proto::VarType::Type data_type) const { + phi::DataType data_type) const { return ctx.HasAttr("use_mkldnn") && ctx.Attr("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace()) && this->SupportsMKLDNN(data_type); } +bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const { + return this->CanMKLDNNBeUsed(ctx, phi::TransToPhiDataType(data_type)); +} + bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, - proto::VarType::Type data_type) const { + phi::DataType data_type) const { bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn") && paddle::platform::is_gpu_place(ctx.GetPlace()); @@ -1528,7 +1530,7 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP #if defined(PADDLE_WITH_CUDA) - if (use_cudnn && data_type == framework::proto::VarType::BF16) { + if (use_cudnn && data_type == phi::DataType::BFLOAT16) { PADDLE_ENFORCE_GE( platform::DnnVersion(), 8100, @@ -1540,6 +1542,11 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, return use_cudnn && this->SupportsCUDNN(data_type); } +bool 
OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const { + return this->CanCUDNNBeUsed(ctx, phi::TransToPhiDataType(data_type)); +} + void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { PADDLE_THROW(platform::errors::PermissionDenied( "The default InferShape function of OperatorWithKernel is not allowed to " @@ -1839,8 +1846,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, 1, platform::EventRole::kInnerOp); if (need_prepare_data_) { - transfer_scope = PrepareData( - scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx); + transfer_scope = + PrepareData(scope, + framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_), + &transfered_inplace_vars, + runtime_ctx, + dev_ctx->GetPlace()); } } // exec scope is the scope that kernel actually executed on. @@ -1960,7 +1971,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( const ExecutionContext& ctx) const { - auto expected_kernel_key = this->GetExpectedKernelType(ctx); + phi::KernelKey phi_kernel_key = this->GetExpectedKernelType(ctx); + auto expected_kernel_key = + framework::TransPhiKernelKeyToOpKernelType(phi_kernel_key); // NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function // GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and @@ -2063,6 +2076,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( } } } + + if (platform::places_are_same_class(expected_kernel_key.place_, + ctx.GetPlace())) { + expected_kernel_key.place_ = ctx.GetPlace(); + } + VLOG(3) << "op type:" << type_ << ", expected_kernel_key:" << expected_kernel_key; return expected_kernel_key; @@ -2333,9 +2352,10 @@ void OperatorWithKernel::HandleComplexGradToRealGrad( Scope* OperatorWithKernel::PrepareData( const Scope& scope, - const OpKernelType& expected_kernel_key, + const phi::KernelKey& expected_kernel_key, std::vector* transfered_inplace_vars, - RuntimeContext* ctx) const { + RuntimeContext* ctx, + const phi::Place& place) const { Scope* new_scope = nullptr; const std::unordered_set* no_buffer_ins = nullptr; @@ -2378,7 +2398,7 @@ Scope* OperatorWithKernel::PrepareData( // has to be created and registered if ((tensor_in->layout() == DataLayout::ONEDNN) && (var->IsType() == true) && - (expected_kernel_key.data_layout_ != DataLayout::ONEDNN) && + (expected_kernel_key.layout() != DataLayout::ONEDNN) && (phi::OneDNNContext::tls().get_cur_paddle_data_layout() == DataLayout::kNHWC) && (tensor_in->dims().size() >= 3)) { @@ -2411,35 +2431,33 @@ Scope* OperatorWithKernel::PrepareData( auto kernel_type_for_var = GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key); bool need_trans_dtype = - kernel_type_for_var.data_type_ != expected_kernel_key.data_type_; + NeedTransformDataType(expected_kernel_key, kernel_type_for_var); bool need_trans_layout = NeedTransformLayout( - kernel_type_for_var.data_layout_, expected_kernel_key.data_layout_); + kernel_type_for_var.layout(), expected_kernel_key.layout()); if (!need_trans_dtype && !need_trans_layout) { if (!run_phi_kernel_ && - platform::places_are_same_class(kernel_type_for_var.place_, - expected_kernel_key.place_)) { + backends_are_same_class(kernel_type_for_var.backend(), + expected_kernel_key.backend())) { continue; } } - std::unique_ptr new_expected_kernel_key = nullptr; + std::unique_ptr new_expected_kernel_key = nullptr; if (run_phi_kernel_ && in_def != nullptr && in_def->backend != phi::Backend::ALL_BACKEND) { auto 
tensor_backend = phi::TransToPhiBackend(tensor_in->place()); if ((in_def->backend != tensor_backend && - (in_def->backend != phi::Backend::GPUDNN || - tensor_backend != phi::Backend::GPU) && - (in_def->backend != phi::Backend::KPS || - tensor_backend != phi::Backend::XPU) && - (in_def->backend != phi::Backend::ONEDNN || - tensor_backend != phi::Backend::CPU)) || + !(in_def->backend == phi::Backend::GPUDNN && + tensor_backend == phi::Backend::GPU) && + !(in_def->backend == phi::Backend::KPS && + tensor_backend == phi::Backend::XPU) && + !(in_def->backend == phi::Backend::ONEDNN && + tensor_backend == phi::Backend::CPU)) || tensor_in->place().GetType() == AllocationType::GPUPINNED) { - new_expected_kernel_key = std::make_unique( - expected_kernel_key.data_type_, - phi::TransToPhiPlace(in_def->backend), - expected_kernel_key.data_layout_, - expected_kernel_key.library_type_, - expected_kernel_key.customized_type_value_); + new_expected_kernel_key = + std::make_unique(in_def->backend, + expected_kernel_key.layout(), + expected_kernel_key.dtype()); } } @@ -2474,14 +2492,18 @@ Scope* OperatorWithKernel::PrepareData( enable_cache_transfer_scope_ = false; if (!run_by_executor_) { if (new_expected_kernel_key) { - if ((platform::is_gpu_place(kernel_type_for_var.place_) || - platform::is_gpu_place(new_expected_kernel_key->place_))) { + if (kernel_type_for_var.backend() == phi::Backend::GPU || + kernel_type_for_var.backend() == phi::Backend::GPUDNN || + new_expected_kernel_key->backend() == phi::Backend::GPU || + new_expected_kernel_key->backend() == phi::Backend::GPUDNN) { new_scope = TryCreateTransferScope( kernel_type_for_var, *new_expected_kernel_key, &scope); enable_cache_transfer_scope_ = true; } - } else if ((platform::is_gpu_place(kernel_type_for_var.place_) || - platform::is_gpu_place(expected_kernel_key.place_))) { + } else if (kernel_type_for_var.backend() == phi::Backend::GPU || + kernel_type_for_var.backend() == phi::Backend::GPUDNN || + expected_kernel_key.backend() == phi::Backend::GPU || + expected_kernel_key.backend() == phi::Backend::GPUDNN) { new_scope = TryCreateTransferScope( kernel_type_for_var, expected_kernel_key, &scope); enable_cache_transfer_scope_ = true; @@ -2523,11 +2545,15 @@ Scope* OperatorWithKernel::PrepareData( // Do transfer phi::DenseTensor out; - TransformData(new_expected_kernel_key ? *new_expected_kernel_key - : expected_kernel_key, - kernel_type_for_var, - *tensor_in, - &out); + TransformData( + new_expected_kernel_key ? *new_expected_kernel_key + : expected_kernel_key, + kernel_type_for_var, + *tensor_in, + &out, + new_expected_kernel_key + ? 
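The De Morgan rewrite above is easier to audit as a whitelist of backend pairs that may share a tensor without a transfer. A standalone sketch of that rule (the helper name is hypothetical; the three pairs are exactly the ones special-cased in this hunk):

    // Can a tensor that actually lives on `actual` feed a kernel input
    // declared on `declared` without a device transfer?
    static bool backends_compatible(phi::Backend declared, phi::Backend actual) {
      if (declared == actual) return true;
      return (declared == phi::Backend::GPUDNN && actual == phi::Backend::GPU) ||
             (declared == phi::Backend::KPS && actual == phi::Backend::XPU) ||
             (declared == phi::Backend::ONEDNN && actual == phi::Backend::CPU);
    }
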
phi::TransToPhiPlace(new_expected_kernel_key->backend()) + : place); SetTensorToVariable(*var, out, trans_var); } }; @@ -2818,30 +2844,29 @@ proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes( return target_type; } -OpKernelType OperatorWithKernel::GetExpectedKernelType( +phi::KernelKey OperatorWithKernel::GetExpectedKernelType( const ExecutionContext& ctx) const { - return OpKernelType(IndicateDataType(ctx), ctx.GetPlace()); + return phi::KernelKey(IndicateDataType(ctx), ctx.GetPlace()); } -OpKernelType OperatorWithKernel::GetKernelTypeForVar( +phi::KernelKey OperatorWithKernel::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // When the op is first oneDNN op (there was some non oneDNN op // previously) // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) && phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); + return phi::KernelKey( + tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype()); } #endif - return OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 07e1a26c7c..b4e0c94c20 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -638,16 +638,22 @@ class OperatorWithKernel : public OperatorBase { bool SupportXPU() const override; - bool SupportsMKLDNN(proto::VarType::Type data_type) const; + bool SupportsMKLDNN(phi::DataType data_type) const; - bool SupportsCUDNN(proto::VarType::Type data_type) const; + bool SupportsCUDNN(phi::DataType data_type) const; bool SupportsKernelType(const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const; + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, + phi::DataType data_type) const; + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const; + bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx, + phi::DataType data_type) const; + bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const; @@ -665,14 +671,15 @@ class OperatorWithKernel : public OperatorBase { const std::string& name1, const std::string& name2) const; - virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; + virtual phi::KernelKey GetExpectedKernelType( + const ExecutionContext& ctx) const; // change this to public so that in dygraph mode we can call it to check if we // need transform data - virtual OpKernelType GetKernelTypeForVar( + virtual phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const OpKernelType& expected_kernel_type) const; + const phi::KernelKey& expected_kernel_type) const; platform::Place GetExecutionPlace( const platform::Place& platform) const override { @@ -734,9 +741,10 @@ class OperatorWithKernel : public OperatorBase { * transfered_inplace_vars is 
an output vector.
   */
  Scope* PrepareData(const Scope& scope,
-                    const OpKernelType& expected_kernel_key,
+                    const phi::KernelKey& expected_kernel_key,
                     std::vector<std::string>* transfered_inplace_vars,
-                    RuntimeContext* ctx) const;
+                    RuntimeContext* ctx,
+                    const phi::Place& place) const;
 
  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc
index a7b597fdb3..1d57efd875 100644
--- a/paddle/fluid/framework/operator_test.cc
+++ b/paddle/fluid/framework/operator_test.cc
@@ -127,14 +127,10 @@ class OpWithKernelTest : public OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
-    int sub_type = ctx.Attr<int>("kernel_sub_type");
-    return OpKernelType(proto::VarType::FP32,
-                        ctx.GetPlace(),
-                        phi::DataLayout::kAnyLayout,
-                        framework::LibraryType::kPlain,
-                        sub_type);
+    return phi::KernelKey(
+        ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
  }
};
@@ -256,16 +252,6 @@ TEST(OpKernel, all) {
  // kernel_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
-
-  attr = op_desc.mutable_attrs()->Add();
-  attr->set_name("kernel_sub_type");
-  attr->set_type(paddle::framework::proto::AttrType::INT);
-  attr->set_i(1);
-  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
-  op2->Run(scope, cpu_place);
-  // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
-  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
-  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
}
REGISTER_OP_WITHOUT_GRADIENT(
@@ -339,11 +325,11 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "phi::DenseTensor");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
  }
};
@@ -361,11 +347,11 @@ class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
  }
};
class IndicateSelectedRowsDataTypeTestProtoMaker
@@ -383,10 +369,10 @@ class IndicateOtherDataTypeTest : public OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
  }
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
@@ -597,10 +583,10 @@ class OpUnusedVarTest : public OperatorWithKernel {
 protected:
  void
InferShape(framework::InferShapeContext* ctx) const override {} - OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext& ctx) const override { - return OpKernelType( - proto::VarType::FP32, ctx.GetPlace(), phi::DataLayout::kAnyLayout); + return phi::KernelKey( + ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32); } }; diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc index c812a6dc95..60c2516c00 100644 --- a/paddle/fluid/framework/transfer_scope_cache.cc +++ b/paddle/fluid/framework/transfer_scope_cache.cc @@ -34,12 +34,13 @@ global_transfer_scope_key() { return *x; } -Scope* TryCreateTransferScope(OpKernelType type0, - OpKernelType type1, +Scope* TryCreateTransferScope(const phi::KernelKey& type0, + const phi::KernelKey& type1, const Scope* scope) { Scope* new_scope{nullptr}; size_t infer_cache_key = - CombineHash(OpKernelType::Hash()(type0), OpKernelType::Hash()(type1)); + CombineHash(static_cast(phi::KernelKey::Hash()(type0)), + static_cast(phi::KernelKey::Hash()(type1))); infer_cache_key = CombineHash(infer_cache_key, std::hash()(scope)); diff --git a/paddle/fluid/framework/transfer_scope_cache.h b/paddle/fluid/framework/transfer_scope_cache.h index da2e319d5b..58707f501a 100644 --- a/paddle/fluid/framework/transfer_scope_cache.h +++ b/paddle/fluid/framework/transfer_scope_cache.h @@ -39,8 +39,8 @@ static size_t CombineHash(size_t seed, size_t a) { return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } -Scope* TryCreateTransferScope(OpKernelType type0, - OpKernelType type1, +Scope* TryCreateTransferScope(const phi::KernelKey& type0, + const phi::KernelKey& type1, const Scope* scope); } // namespace framework diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index cc3ed77c39..43c83a7237 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -23,7 +23,8 @@ if(WITH_XPU) scalar int_array var_helper - profiler) + profiler + place) else() cc_library( prepared_operator @@ -40,7 +41,8 @@ else() scalar int_array var_helper - profiler) + profiler + place) endif() cc_library( layer diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index d3f163da2a..6f1f54de8a 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -24,6 +24,7 @@ #include "paddle/fluid/imperative/var_helper.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace imperative { @@ -39,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const framework::AttributeMap* attr, const framework::AttributeMap* default_attr, const std::string op_type, - const framework::OpKernelType* op_kernel_type = nullptr, + const phi::KernelKey* op_kernel_key = nullptr, const phi::ArgumentMappingFn* arg_map_fn = nullptr, const phi::KernelSignature* default_kernel_signature = nullptr) : var_map_in_(in), @@ -47,7 +48,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { attrs_(attr), default_attrs_(default_attr), op_type_(op_type), - op_kernel_type_(op_kernel_type), + op_kernel_key_(op_kernel_key), arg_map_fn_(arg_map_fn), default_kernel_signature_(default_kernel_signature) {} @@ -250,8 +251,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { bool 
IsRuntime() const override { return true; } bool IsRunMKLDNNKernel() const override { - return (op_kernel_type_ && - (op_kernel_type_->data_layout_ == phi::DataLayout::ONEDNN)); + return (op_kernel_key_ && + (op_kernel_key_->layout() == phi::DataLayout::ONEDNN)); } paddle::small_vector @@ -497,7 +498,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const framework::AttributeMap* attrs_; const framework::AttributeMap* default_attrs_; const std::string op_type_; - const framework::OpKernelType* op_kernel_type_; + const phi::KernelKey* op_kernel_key_; // arg_map_fn_ and default_kernel_signature_ may be nullptr const phi::ArgumentMappingFn* arg_map_fn_; const phi::KernelSignature* default_kernel_signature_; diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 89398c5246..2ac43c39d7 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -519,8 +519,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, */ auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs); - auto tmp_ins_ptr = - PrepareData(*op_kernel, ins, prepared_op.kernel_type()); + auto tmp_ins_ptr = PrepareData( + *op_kernel, ins, prepared_op.kernel_key(), prepared_op.place()); if (tmp_ins_ptr == nullptr) { prepared_op.Run(ins, outs, attrs, default_attrs); } else { diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 5eb045a0c5..32a4515624 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -30,6 +30,7 @@ #endif #include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h" @@ -116,14 +117,14 @@ void TestHandleComplexGradToRealGradEager( PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), - kernel_type_(kernel_type), + kernel_key_(kernel_key), func_(func), dev_ctx_(dev_ctx), arg_map_fn_(arg_map_fn), @@ -132,7 +133,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, phi::KernelSignature&& kernel_signature, @@ -140,7 +141,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), - kernel_type_(kernel_type), + kernel_key_(kernel_key), func_(nullptr), dev_ctx_(dev_ctx), run_phi_kernel_(true), @@ -228,7 +229,6 @@ PreparedOp PrepareImpl( const phi::KernelSignature* default_kernel_signature = nullptr; phi::KernelSignature kernel_signature; - phi::KernelKey phi_kernel_key; std::string phi_kernel_name; // NOTE(jiahongyu): The registered MKLDNN kernel have library_type = @@ -240,29 +240,27 @@ PreparedOp PrepareImpl( // 3. Whether mkldnn kernel can be used. 
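For reference while reading the dygraph hunks below, this is the field-to-accessor mapping the whole patch applies; every pair is taken verbatim from hunks in this diff:

    // OpKernelType (fluid)                  ->  phi::KernelKey
    // expected_kernel_key.data_type_        ->  expected_kernel_key.dtype()
    // expected_kernel_key.data_layout_      ->  expected_kernel_key.layout()
    // expected_kernel_key.place_            ->  expected_kernel_key.backend()
    // library_type_ = LibraryType::kMKLDNN  ->  set_backend(phi::Backend::ONEDNN)
    // library_type_ = LibraryType::kCUDNN   ->  set_backend(phi::Backend::GPUDNN)
    // library_type_ = LibraryType::kKP      ->  set_backend(phi::Backend::KPS)
    // bridges: TransOpKernelTypeToPhiKernelKey / TransPhiKernelKeyToOpKernelType
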
#ifdef PADDLE_WITH_MKLDNN
  if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
-      op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
-    expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
-    expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN;
+      op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
+    expected_kernel_key.set_backend(phi::Backend::ONEDNN);
+    expected_kernel_key.set_layout(phi::DataLayout::ONEDNN);
  }
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
-    expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
+  if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
+    expected_kernel_key.set_backend(phi::Backend::GPUDNN);
  }
#endif
#if defined(PADDLE_WITH_XPU)
-  bool is_xpu_unsupport =
-      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
-      !paddle::platform::is_xpu_support_op(
-          op.Type(),
-          framework::TransToPhiDataType(expected_kernel_key.data_type_));
+  bool is_xpu_unsupport = expected_kernel_key.backend() == phi::Backend::XPU &&
+                          !paddle::platform::is_xpu_support_op(
+                              op.Type(), expected_kernel_key.dtype());
#endif
#ifdef PADDLE_WITH_MLU
  if (is_in_mlu_black_list(op.Type())) {
-    expected_kernel_key.place_ = platform::CPUPlace();
+    expected_kernel_key.set_backend(phi::Backend::CPU);
  }
#endif
@@ -290,12 +288,10 @@ PreparedOp PrepareImpl(
  // But the default library_type is Plain, so we need to modify the
  // library_type here, otherwise it can't work.
#ifdef PADDLE_WITH_XPU_KP
-  if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
+  if (expected_kernel_key.backend() == phi::Backend::XPU) {
    bool use_xpu_kp_kernel_rt =
-        FLAGS_run_kp_kernel &&
-        paddle::platform::is_xpu_support_op(
-            op.Type(),
-            framework::TransToPhiDataType(expected_kernel_key.data_type_));
+        FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op(
+                                   op.Type(), expected_kernel_key.dtype());
    bool use_xpu_kp_kernel_debug =
        paddle::platform::is_in_xpu_kpwhite_list(op.Type());
    if (use_xpu_kp_kernel_rt) {
@@ -307,17 +303,14 @@ PreparedOp PrepareImpl(
    bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
    if (is_xpu_kp_support) {
-      auto expected_kernel_key_library_type =
-          expected_kernel_key.library_type_;
-      expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
+      auto expected_kernel_key_backend = expected_kernel_key.backend();
+      expected_kernel_key.set_backend(phi::Backend::KPS);
      VLOG(3) << "modifying XPU KP kernel: " << phi_kernel_name
              << ", using_kernel_key:" << expected_kernel_key;
-      phi::KernelKey try_phi_kernel_key =
-          TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
      if (!phi_kernel_factory.HasKernel(phi_kernel_name,
-                                        try_phi_kernel_key)) {
-        expected_kernel_key.library_type_ = expected_kernel_key_library_type;
+                                        expected_kernel_key)) {
+        expected_kernel_key.set_backend(expected_kernel_key_backend);
        VLOG(3) << "modifying XPU KP kernel: " << phi_kernel_name
                << " in dynamic graph failed " << expected_kernel_key;
      } else {
@@ -328,9 +321,8 @@ PreparedOp PrepareImpl(
  }
#endif
-  phi_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
  auto& phi_kernel =
-      phi_kernel_factory.SelectKernel(phi_kernel_name, phi_kernel_key);
+      phi_kernel_factory.SelectKernel(phi_kernel_name, expected_kernel_key);
  if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
      && !is_xpu_unsupport
#endif
     ) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name - << " | kernel key: " << phi_kernel_key + << " | kernel key: " << expected_kernel_key << " | kernel: " << phi_kernel; - if (expected_kernel_key.place_ != place) { - dev_ctx = pool.Get(expected_kernel_key.place_); + if (!framework::backends_are_same_class( + expected_kernel_key.backend(), + phi::TransToPhiBackend(dev_ctx->GetPlace()))) { + dev_ctx = pool.Get(phi::TransToPhiPlace(expected_kernel_key.backend())); } - return PreparedOp(op, empty_ctx, expected_kernel_key, @@ -368,22 +361,23 @@ PreparedOp PrepareImpl( // registered in KP use library_type[KP], we need to modify it. #ifdef PADDLE_WITH_XPU_KP bool use_xpu_kp_kernel_rt = - paddle::platform::is_xpu_place(expected_kernel_key.place_) && + expected_kernel_key.backend() == phi::Backend::XPU && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( - op.Type(), - framework::TransToPhiDataType(expected_kernel_key.data_type_)); + paddle::platform::is_xpu_support_op(op.Type(), + expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = - paddle::platform::is_xpu_place(expected_kernel_key.place_) && + expected_kernel_key.backend() == phi::Backend::XPU && paddle::platform::is_in_xpu_kpwhite_list(op.Type()); bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + expected_kernel_key.set_backend(phi::Backend::KPS); } #endif + paddle::framework::OpKernelType fluid_kernel_type = + paddle::framework::TransPhiKernelKeyToOpKernelType(expected_kernel_key); if ((kernels_iter == all_op_kernels.end() || - kernels_iter->second.find(expected_kernel_key) == + kernels_iter->second.find(fluid_kernel_type) == kernels_iter->second.end()) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || is_xpu_unsupport @@ -393,7 +387,7 @@ PreparedOp PrepareImpl( #endif ) { if (has_phi_kernel) { - auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op); + auto phi_cpu_kernel_key = FallBackToCpu(expected_kernel_key, op); auto& phi_cpu_kernel = phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key); if (phi_cpu_kernel.IsValid()) { @@ -401,15 +395,14 @@ PreparedOp PrepareImpl( << " | kernel key: " << phi_cpu_kernel_key << " | kernel: " << phi_cpu_kernel; auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp( - op, - empty_ctx, - framework::TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key), - arg_map_fn, - default_kernel_signature, - std::move(kernel_signature), - phi_cpu_kernel, - cpu_ctx); + return PreparedOp(op, + empty_ctx, + phi_cpu_kernel_key, + arg_map_fn, + default_kernel_signature, + std::move(kernel_signature), + phi_cpu_kernel, + cpu_ctx); } } } @@ -422,21 +415,21 @@ PreparedOp PrepareImpl( op.Type())); auto& kernels = kernels_iter->second; - auto kernel_iter = kernels.find(expected_kernel_key); + auto kernel_iter = kernels.find(fluid_kernel_type); #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) - if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && + if (paddle::platform::is_xpu_place(fluid_kernel_type.place_) && (kernel_iter == kernels.end() || is_xpu_unsupport)) { VLOG(3) << "fluid missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = 
platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_XPU_KP - if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + if (paddle::platform::is_xpu_place(fluid_kernel_type.place_)) { if (use_xpu_kp_kernel_rt) { VLOG(3) << "fluid xpu_kp using rt mode "; } @@ -444,60 +437,60 @@ PreparedOp PrepareImpl( VLOG(3) << "fluid xpu_kp using debug mode "; } if (is_xpu_kp_support) { - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.library_type_ = paddle::framework::LibraryType::kKP; + kernel_iter = kernels.find(fluid_kernel_type); VLOG(3) << "using fluid XPU KP kernel: " << op.Type() - << ", using_kernel_key:" << expected_kernel_key; + << ", using_kernel_key:" << fluid_kernel_type; } if (!is_xpu_kp_support && (kernel_iter == kernels.end() || is_xpu_unsupport)) { VLOG(3) << "fluid missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } } #endif #ifdef PADDLE_WITH_ASCEND_CL if (kernel_iter == kernels.end() && - paddle::platform::is_npu_place(expected_kernel_key.place_)) { + paddle::platform::is_npu_place(fluid_kernel_type.place_)) { VLOG(3) << "missing NPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_IPU if (kernel_iter == kernels.end() && - paddle::platform::is_ipu_place(expected_kernel_key.place_)) { + paddle::platform::is_ipu_place(fluid_kernel_type.place_)) { VLOG(3) << "missing IPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_MLU if (kernel_iter == kernels.end() && - paddle::platform::is_mlu_place(expected_kernel_key.place_)) { + paddle::platform::is_mlu_place(fluid_kernel_type.place_)) { VLOG(3) << "missing MLU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE if (kernel_iter == kernels.end() && - paddle::platform::is_custom_place(expected_kernel_key.place_)) { + paddle::platform::is_custom_place(fluid_kernel_type.place_)) { VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type() << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + 
fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif // TODO(jiabin): Add operator.cc's line 1000 part back when we need that @@ -507,19 +500,20 @@ PreparedOp PrepareImpl( kernels.end(), platform::errors::NotFound("Operator %s does not have kernel for %s.", op.Type(), - KernelTypeToString(expected_kernel_key))); - - if (!(expected_kernel_key.place_ == place)) { - dev_ctx = pool.Get(expected_kernel_key.place_); - } - - return PreparedOp(op, - empty_ctx, - expected_kernel_key, - kernel_iter->second, - arg_map_fn, - default_kernel_signature, - dev_ctx); + KernelTypeToString(fluid_kernel_type))); + + if (!platform::places_are_same_class(fluid_kernel_type.place_, + dev_ctx->GetPlace())) { + dev_ctx = pool.Get(fluid_kernel_type.place_); + } + return PreparedOp( + op, + empty_ctx, + framework::TransOpKernelTypeToPhiKernelKey(fluid_kernel_type), + kernel_iter->second, + arg_map_fn, + default_kernel_signature, + dev_ctx); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, @@ -576,7 +570,7 @@ template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, @@ -597,7 +591,7 @@ static void PreparedOpRunImpl( &attrs, &default_attrs, op.Type(), - &kernel_type, + &kernel_key, arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); @@ -641,7 +635,7 @@ static void PreparedOpRunImpl( * grad op kernel executed, we need to recognize this situation and * convert dx to float32 type. HandleComplexGradToRealGrad does this thing. 
*/ - if (framework::IsComplexType(kernel_type.data_type_)) { + if (framework::IsComplexType(kernel_key.dtype())) { HandleComplexGradToRealGrad(outs); } } @@ -649,7 +643,7 @@ static void PreparedOpRunImpl( template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature& kernel_signature, @@ -669,7 +663,7 @@ static void PreparedOpRunPtImpl( &attrs, &default_attrs, op.Type(), - &kernel_type, + &kernel_key, arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); @@ -712,7 +706,7 @@ static void PreparedOpRunPtImpl( #endif } - if (framework::IsComplexType(kernel_type.data_type_)) { + if (framework::IsComplexType(kernel_key.dtype())) { HandleComplexGradToRealGrad(outs); } } @@ -723,7 +717,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -736,7 +730,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, @@ -754,7 +748,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -767,7 +761,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, @@ -785,7 +779,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -798,7 +792,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index a3d90939fa..fb36a03e01 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -29,6 +29,7 @@ #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/var_helper.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" @@ -75,7 +76,8 @@ template std::shared_ptr> PrepareData( const framework::OperatorWithKernel& op, const NameVarMap& ins, - const framework::OpKernelType& expected_kernel_key) { + const phi::KernelKey& expected_kernel_key, + const phi::Place& place) { std::shared_ptr> tmp_ins_ptr = nullptr; for (const auto& name_pair : ins) { for (size_t i = 0; i < name_pair.second.size(); ++i) { @@ -85,7 +87,8 @@ std::shared_ptr> PrepareData( if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) { auto kernel_type_for_var = op.GetKernelTypeForVar( name_pair.first, *tensor, expected_kernel_key); - if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { + if (!framework::NeedTransform(kernel_type_for_var, + expected_kernel_key)) { continue; } else { VLOG(3) << 
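The complex-grad fallback mentioned in the comment above now keys off phi::DataType directly. A sketch of what the check reduces to (assuming IsComplexType matches both complex precisions, which is how the fluid helper behaves):

    if (kernel_key.dtype() == phi::DataType::COMPLEX64 ||
        kernel_key.dtype() == phi::DataType::COMPLEX128) {
      // dL/dx was computed as complex; cast it back to the real input type.
      HandleComplexGradToRealGrad<VarType>(outs);
    }
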
"Transform Variable " << GetNameFromVar(template_var) @@ -111,10 +114,10 @@ std::shared_ptr> PrepareData( (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; } else { phi::DenseTensor out; - TransformData( - expected_kernel_key, kernel_type_for_var, *tensor, &out); - if (NeedTransformDataType(kernel_type_for_var, - expected_kernel_key)) { + framework::TransformData( + expected_kernel_key, kernel_type_for_var, *tensor, &out, place); + if (framework::NeedTransformDataType(kernel_type_for_var, + expected_kernel_key)) { // To avoid NameVarMap copy construction overhead in general // scenarios, if inplace transformed, return original input // directly @@ -149,7 +152,7 @@ class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, @@ -157,7 +160,7 @@ class PreparedOp { PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, phi::KernelSignature&& kernel_signature, @@ -200,12 +203,14 @@ class PreparedOp { const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); - const framework::OpKernelType& kernel_type() const { return kernel_type_; } + const phi::KernelKey& kernel_key() const { return kernel_key_; } + + const phi::Place& place() const { return dev_ctx_->GetPlace(); } private: const framework::OperatorBase& op_; const framework::RuntimeContext& ctx_; - framework::OpKernelType kernel_type_; + phi::KernelKey kernel_key_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; // NOTE(chenweihang): Similar op members are used to adapt to diff --git a/paddle/fluid/imperative/tests/test_eager.cc b/paddle/fluid/imperative/tests/test_eager.cc index 3eec90462d..6c27dead27 100644 --- a/paddle/fluid/imperative/tests/test_eager.cc +++ b/paddle/fluid/imperative/tests/test_eager.cc @@ -92,15 +92,15 @@ TEST(test_var_helper, eager_var_helper) { ASSERT_TRUE(platform::is_cpu_place(GetPlace(egr_tensor))); ASSERT_TRUE(GetDataType(egr_tensor) == framework::proto::VarType::FP32); - GetCachedValue( - egr_tensor, - framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace())); - SetCachedValue( - egr_tensor, - framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()), - egr_tensor2); + GetCachedValue(egr_tensor, + phi::KernelKey(phi::Backend::CPU, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32)); + SetCachedValue(egr_tensor, + phi::KernelKey(phi::Backend::CPU, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32), + egr_tensor2); ASSERT_ANY_THROW(GetPlace(egr_tensor2)); ASSERT_ANY_THROW(SetType( egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY)); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index 613b729197..76510b39ce 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -172,7 +172,8 @@ TEST(test_prepare_op, test_prepare_data) { PrepareData( dynamic_cast(*op), ins, - prepared_op.kernel_type()); + prepared_op.kernel_key(), + gpu_place); for (const auto& name_pair : ins) { for 
(const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( @@ -229,7 +230,8 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { PrepareData( dynamic_cast(*op), ins, - prepared_op.kernel_type()); + prepared_op.kernel_key(), + cpu_place); for (const auto& name_pair : ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( diff --git a/paddle/fluid/imperative/var_helper.cc b/paddle/fluid/imperative/var_helper.cc index b5f1c8f1fd..bafea5a720 100644 --- a/paddle/fluid/imperative/var_helper.cc +++ b/paddle/fluid/imperative/var_helper.cc @@ -239,35 +239,31 @@ template void SetDataLayout( /* CheckCachedKey */ template -bool CheckCachedKey(std::shared_ptr var, - const paddle::framework::OpKernelType &key) { +bool CheckCachedKey(std::shared_ptr var, const phi::KernelKey &key) { return GetVariableWrapper(var)->hasCacheKey(key); } template <> bool CheckCachedKey( - std::shared_ptr tensor, - const paddle::framework::OpKernelType &key) { + std::shared_ptr tensor, const phi::KernelKey &key) { // TODO(jiabin): Support this later // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is // equal to self: " << key == key. return false; } -template bool CheckCachedKey( - std::shared_ptr var, const paddle::framework::OpKernelType &key); +template bool CheckCachedKey(std::shared_ptr var, + const phi::KernelKey &key); template bool CheckCachedKey( - std::shared_ptr var, - const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); /* GetCachedValue */ template -std::shared_ptr GetCachedValue( - std::shared_ptr var, const paddle::framework::OpKernelType &key) { +std::shared_ptr GetCachedValue(std::shared_ptr var, + const phi::KernelKey &key) { return GetVariableWrapper(var)->getCacheValue(key); } template <> std::shared_ptr GetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key) { + std::shared_ptr var, const phi::KernelKey &key) { // TODO(jiabin): Support this later // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this @@ -277,22 +273,21 @@ std::shared_ptr GetCachedValue( return std::make_shared(""); } template std::shared_ptr GetCachedValue( - std::shared_ptr var, const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); template std::shared_ptr GetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); /* SetCachedValue */ template void SetCachedValue(std::shared_ptr var, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res) { GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)); } template <> void SetCachedValue( std::shared_ptr tensor, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res) { // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this @@ -300,13 +295,12 @@ void SetCachedValue( // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key // is equal to self: " << key == key << " and res name is:" << res->Name(). 
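The cached-variable helpers above now key the per-variable cache by phi::KernelKey. A hypothetical round-trip, mirroring the test_eager.cc hunk earlier in this diff (variable names are illustrative):

    auto key = phi::KernelKey(
        phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
    if (!CheckCachedKey(var, key)) {
      SetCachedValue(var, key, transformed_var);  // populate on first miss
    }
    auto cached = GetCachedValue(var, key);  // later hits skip the transform
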
} -template void SetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key, - std::shared_ptr res); +template void SetCachedValue(std::shared_ptr var, + const phi::KernelKey &key, + std::shared_ptr res); template void SetCachedValue( std::shared_ptr var, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res); } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/var_helper.h b/paddle/fluid/imperative/var_helper.h index 5e96e86548..ebf3e49c51 100644 --- a/paddle/fluid/imperative/var_helper.h +++ b/paddle/fluid/imperative/var_helper.h @@ -43,16 +43,14 @@ template const std::string& GetNameFromVar(std::shared_ptr var); template -bool CheckCachedKey(std::shared_ptr tensor, - const paddle::framework::OpKernelType& key); +bool CheckCachedKey(std::shared_ptr tensor, const phi::KernelKey& key); template void SetCachedValue(std::shared_ptr tensor, - const paddle::framework::OpKernelType& key, + const phi::KernelKey& key, std::shared_ptr res); template -std::shared_ptr GetCachedValue( - std::shared_ptr tensor, - const paddle::framework::OpKernelType& key); +std::shared_ptr GetCachedValue(std::shared_ptr tensor, + const phi::KernelKey& key); template void SetType(std::shared_ptr var, diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index c1024b6d58..d4438e8b47 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -234,16 +234,15 @@ class VariableWrapper { } } - bool hasCacheKey(const paddle::framework::OpKernelType& key) { + bool hasCacheKey(const phi::KernelKey& key) { return var_cache.find(key) != var_cache.end(); } - std::shared_ptr getCacheValue( - const paddle::framework::OpKernelType& key) { + std::shared_ptr getCacheValue(const phi::KernelKey& key) { return var_cache[key]; } - void setCacheValue(const paddle::framework::OpKernelType& key, + void setCacheValue(const phi::KernelKey& key, std::shared_ptr val) { var_cache[key] = val; return; @@ -323,8 +322,7 @@ class VariableWrapper { // Used for cache the dtype promotioned variableWrapper in real and complex // compute of Paddle Quantum - std::map> - var_cache; + std::map> var_cache; // add this property for users may set stop_gradient themselves and this // should override the frameworks setting (-1) unset, (1) true, (0) false int overrided_stop_gradient_{-1}; diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc index 3310bdbbe8..0bf78f41d6 100644 --- a/paddle/fluid/operators/abs_op.cc +++ b/paddle/fluid/operators/abs_op.cc @@ -29,11 +29,11 @@ class AbsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -70,11 +70,11 @@ class AbsGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return 
phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -124,20 +124,17 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } }; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 53cd5c92cd..649382ffc9 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -80,9 +80,9 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker { } }; -framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, - const framework::OperatorWithKernel& oper, - const std::string& name) { +phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel& oper, + const std::string& name) { auto data_type = oper.IndicateVarDataType(ctx, name); // FIXME(liuwei1031) temporarily disable the code to unblock users // TODO(liuwei1031) figure out the reason behind @@ -94,7 +94,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, // library = framework::LibraryType::kCUDNN; // } // #endif - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } class ActivationOp : public framework::OperatorWithKernel { @@ -107,7 +107,7 @@ class ActivationOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } @@ -134,7 +134,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, framework::GradVarName("Out")); } @@ -341,7 +341,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -370,7 +370,7 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -411,7 +411,7 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -487,20 +487,22 @@ 
class PowOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "FactorTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -515,20 +517,22 @@ class PowOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, framework::GradVarName("Out")); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "FactorTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -537,7 +541,7 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } @@ -548,7 +552,7 @@ class PowOpTripleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index cd4a9fbdb3..0f52362c21 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -34,11 +34,10 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -54,11 +53,11 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return 
framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 408d1c565e..90d8c8b0ce 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -145,11 +145,11 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index b23d3670d5..a0cb5480d5 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -130,10 +130,10 @@ class AffineGridOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -241,11 +241,11 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Output")); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc index fa6bc1d6f7..ab876921f9 100644 --- a/paddle/fluid/operators/allclose_op.cc +++ b/paddle/fluid/operators/allclose_op.cc @@ -65,11 +65,10 @@ class AllcloseOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/alloc_float_status_op.cc b/paddle/fluid/operators/amp/alloc_float_status_op.cc index fc96dd52e5..24e9608677 100644 --- a/paddle/fluid/operators/amp/alloc_float_status_op.cc +++ b/paddle/fluid/operators/amp/alloc_float_status_op.cc @@ -34,10 +34,9 @@ class AllocFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; 
diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc index a8d1f36f11..c8faf2d655 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc @@ -29,13 +29,13 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; if (ctx.MultiInputVar("X").size() >= 1) { dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); } - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/clear_float_status_op.cc b/paddle/fluid/operators/amp/clear_float_status_op.cc index 7bfc2d34d2..06e4b986fa 100644 --- a/paddle/fluid/operators/amp/clear_float_status_op.cc +++ b/paddle/fluid/operators/amp/clear_float_status_op.cc @@ -34,10 +34,9 @@ class ClearFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/get_float_status_op.cc b/paddle/fluid/operators/amp/get_float_status_op.cc index 88a2affbca..d5a924b8d8 100644 --- a/paddle/fluid/operators/amp/get_float_status_op.cc +++ b/paddle/fluid/operators/amp/get_float_status_op.cc @@ -34,10 +34,9 @@ class GetFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc index f8ccac27c1..7f9b7da62f 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cc @@ -29,23 +29,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; if (ctx.MultiInputVar("X").size() >= 1) { dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); } - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifndef PADDLE_WITH_XPU if (var_name == "FoundInfinite" || var_name == "StopUpdate") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } 
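Returning phi::Backend::ALL_BACKEND from GetKernelTypeForVar is the new spelling of "never transfer this input"; the old code expressed it by returning expected_kernel_type unchanged. Condensed from the update_loss_scaling hunk above, with the default path inlined from operator.cc's base implementation (the Pow ops earlier in this patch use the same idiom):

    phi::KernelKey GetKernelTypeForVar(
        const std::string& var_name,
        const phi::DenseTensor& tensor,
        const phi::KernelKey& expected_kernel_type) const override {
      if (var_name == "FoundInfinite" || var_name == "StopUpdate") {
        // ALL_BACKEND: accept the tensor wherever it already lives.
        return phi::KernelKey(phi::Backend::ALL_BACKEND,
                              expected_kernel_type.layout(),
                              expected_kernel_type.dtype());
      }
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
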
#endif return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 0e44fd2fa2..090fdff31c 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -32,11 +32,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/ascend_trigger_op.cc b/paddle/fluid/operators/ascend_trigger_op.cc index abb39dce7a..b312f97d3f 100644 --- a/paddle/fluid/operators/ascend_trigger_op.cc +++ b/paddle/fluid/operators/ascend_trigger_op.cc @@ -23,10 +23,10 @@ class AscendTriggerOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 1af424fa77..244b3aec9c 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -41,16 +41,16 @@ class AssignOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const framework::Variable *var = ctx.InputVar("X"); if (var->IsType<framework::LoDTensorArray>()) { @@ -58,14 +58,13 @@ class AssignOp : public framework::OperatorWithKernel { // NOTE(liym27): Support an empty tensor array as Input. // And set the kernel type is float.
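/*
 * Another recurring change: phi::KernelKey is built from a dtype plus a
 * Place (or from backend/layout/dtype), not from a DeviceContext as
 * framework::OpKernelType was, so call sites that used to pass
 * ctx.device_context() now pass ctx.device_context().GetPlace(). A minimal
 * sketch ("MyOp" is illustrative, not part of this patch):
 *
 *   phi::KernelKey MyOp::GetExpectedKernelType(
 *       const framework::ExecutionContext& ctx) const {
 *     auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 *     // before: framework::OpKernelType(dtype, ctx.device_context());
 *     return phi::KernelKey(dtype, ctx.device_context().GetPlace());
 *   }
 *
 * ctx.GetPlace(), used interchangeably in this patch, resolves to the same
 * place as ctx.device_context().GetPlace().
 */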
if (t_arr.size() == 0) { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_pos_op.cc b/paddle/fluid/operators/assign_pos_op.cc index 80412c7d67..24fc4adc60 100644 --- a/paddle/fluid/operators/assign_pos_op.cc +++ b/paddle/fluid/operators/assign_pos_op.cc @@ -31,7 +31,7 @@ class AssignPosOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto cum_count_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "cum_count"); @@ -46,7 +46,7 @@ class AssignPosOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "The dtype of the cum_count_dtype, eff_num_len and " "X should be same as int64")); - return framework::OpKernelType(cum_count_dtype, ctx.device_context()); + return phi::KernelKey(cum_count_dtype, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_value_op.cc b/paddle/fluid/operators/assign_value_op.cc index b9806c7a69..766e55b031 100644 --- a/paddle/fluid/operators/assign_value_op.cc +++ b/paddle/fluid/operators/assign_value_op.cc @@ -44,9 +44,9 @@ class AssignValueOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr<int>("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 330b13ab8b..c461713855 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -198,10 +198,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Cell"); } -framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( +phi::KernelKey AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h index 0ce83be93c..391afc459f 100644 --- a/paddle/fluid/operators/attention_lstm_op.h +++ b/paddle/fluid/operators/attention_lstm_op.h @@ -25,7 +25,7 @@ class AttentionLSTMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 9f8f295c24..a59b78c3cd 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++
b/paddle/fluid/operators/average_accumulates_op.cc @@ -26,10 +26,10 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 38504e3ecd..9010cadd15 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -77,11 +77,10 @@ class BatchFCOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -106,11 +105,11 @@ class BatchFCGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 32cb10ec89..21a06e5257 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -171,7 +171,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { } } -framework::OpKernelType BatchNormOp::GetExpectedKernelType( +phi::KernelKey BatchNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -202,18 +202,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( platform::errors::InvalidArgument( "Variance input should be of float type")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } -framework::OpKernelType BatchNormOp::GetKernelTypeForVar( +phi::KernelKey BatchNormOp::GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if ((var_name == "X") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -222,13 +222,12 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void BatchNormOpMaker::Make() { @@ -373,7 +372,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { } } -framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( +phi::KernelKey BatchNormGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -392,18 +391,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( +phi::KernelKey BatchNormGradOp::GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -412,13 +411,12 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } template <typename T> @@ -515,7 +513,7 @@ void BatchNormDoubleGradOp::InferShape( } } -framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( +phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar("DY"); if (var == nullptr) { @@ -532,8 +530,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 0e579010a9..d6a1038c00 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -47,13 +47,13 @@ class BatchNormOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class BatchNormGradOp : public framework::OperatorWithKernel { @@ -62,13 +62,13 @@ class BatchNormGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class BatchNormDoubleGradOp : public framework::OperatorWithKernel { @@ -77,7 +77,7 @@ class BatchNormDoubleGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 3c775ced3f..d1be450e81 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -28,11 +28,10 @@ class BCELossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext&
ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -87,11 +86,10 @@ class BCELossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 49669f1b35..1e569c4bb2 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -108,7 +108,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *scores = ctx.Input<phi::DenseTensor>("scores"); size_t level = ctx.Attr<int>("level"); @@ -116,11 +116,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { // The current CUDA kernel only support cases with batch_size < 4. // Compute on CPU for cases with batch_size > 4. if (batch_size <= 4) { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), ctx.GetPlace()); } else { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/bilateral_slice_op.cc b/paddle/fluid/operators/bilateral_slice_op.cc index 8b7968d2a8..c824fd9e63 100644 --- a/paddle/fluid/operators/bilateral_slice_op.cc +++ b/paddle/fluid/operators/bilateral_slice_op.cc @@ -85,10 +85,10 @@ class BilateralSliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -147,11 +147,11 @@ class BilateralSliceOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/bincount_op.cc b/paddle/fluid/operators/bincount_op.cc index 5f5e19c585..484431eeef 100644 --- a/paddle/fluid/operators/bincount_op.cc +++ b/paddle/fluid/operators/bincount_op.cc @@ -24,19 +24,17 @@ limitations under the License.
*/ namespace paddle { namespace operators { -using framework::OpKernelType; - class BincountOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto data_type = ctx.HasInput("Weights") ? OperatorWithKernel::IndicateVarDataType(ctx, "Weights") : OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 20ea0b187f..47aea12443 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -56,11 +56,10 @@ class BprLossOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of Seq-bpr // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -119,11 +118,10 @@ class BprLossGradientOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/broadcast_tensors_op.cc b/paddle/fluid/operators/broadcast_tensors_op.cc index 34a76e86aa..6d92464419 100644 --- a/paddle/fluid/operators/broadcast_tensors_op.cc +++ b/paddle/fluid/operators/broadcast_tensors_op.cc @@ -27,14 +27,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // Broadcast semantics enforces all input variables having the same // DataType/VarType // This condition is also checked during VarType Inference // Here we simply copy input type to output - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -127,11 +127,11 @@ class BroadcastTensorsGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 
10b25fc478..192fe35a9b 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -75,7 +75,7 @@ class CastOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { // CastOp kernel's device type is decided by input tensor place auto *tensor = ctx.Input<phi::DenseTensor>("X"); @@ -86,9 +86,8 @@ class CastOp : public framework::OperatorWithKernel { auto &tensor_place = tensor->place(); // NOTE: cuda pinned tensor need to copy its data to target place if (platform::is_cuda_pinned_place(tensor_place)) { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + ctx.device_context().GetPlace()); } // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN @@ -108,20 +107,19 @@ class CastOp : public framework::OperatorWithKernel { auto src_type = static_cast<VarType::Type>(ctx.Attr<int>("in_dtype")); auto dst_type = static_cast<VarType::Type>(ctx.Attr<int>("out_dtype")); if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), tensor_place); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + tensor_place); } else { VLOG(3) << "MLU not support cast type: " << framework::DataTypeToString(src_type) << " to type: " << framework::DataTypeToString(dst_type) << ", fallbacking to CPU one!"; - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), - platform::CPUPlace()); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + platform::CPUPlace()); } #endif - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), tensor_place); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + tensor_place); } }; diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index f168eb10ae..4639e53350 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -53,11 +53,10 @@ class CenterLossOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -115,11 +114,11 @@ class CenterLossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), - ctx.device_context()); + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/chunk_eval_op.cc b/paddle/fluid/operators/chunk_eval_op.cc index 6ad9f6d491..71268eb12d 100644 --- a/paddle/fluid/operators/chunk_eval_op.cc +++ b/paddle/fluid/operators/chunk_eval_op.cc @@ -88,10 +88,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const
framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc index c0dafd8534..6e44a5ce2e 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc @@ -57,10 +57,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { * specified a data type here. * */ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc index 8147541cba..4ce45aeea9 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -117,10 +117,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel { * Of course, the data type here is also not important. */ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/class_center_sample_op.cc b/paddle/fluid/operators/class_center_sample_op.cc index cb766dae22..54f0e981ca 100644 --- a/paddle/fluid/operators/class_center_sample_op.cc +++ b/paddle/fluid/operators/class_center_sample_op.cc @@ -26,11 +26,10 @@ class ClassCenterSampleOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Label"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Label"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index 997c017d31..1fdc4e9a12 100644 --- a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -26,11 +26,11 @@ namespace operators { class ClipOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -85,11 +85,11 @@ class ClipOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + 
return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 75e6df4baf..e16950e31d 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -405,20 +405,20 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &context) const override { auto dtype = static_cast<framework::proto::VarType::Type>( context.Attr<int>("dtype")); - return framework::OpKernelType(dtype, context.GetPlace()); + return phi::KernelKey(dtype, context.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/collective/allreduce_op.cc b/paddle/fluid/operators/collective/allreduce_op.cc index b3351dc82b..e136d8ef6e 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cc @@ -27,10 +27,10 @@ class AllReduceOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/alltoall_op.cc b/paddle/fluid/operators/collective/alltoall_op.cc index b5512fdc52..e6fa37e0e4 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cc @@ -36,10 +36,10 @@ class AllToAllOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 268144f183..f63c4a9abc 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -71,21 +71,23 @@ class CAllReduceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const
framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { if (var_name == "Cond") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index 49b1bd5fd9..35c395681e 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -26,10 +26,10 @@ class CBroadcastOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_concat_op.cc b/paddle/fluid/operators/collective/c_concat_op.cc index 75e41dba92..ed29654048 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cc @@ -58,10 +58,10 @@ class CConcatOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_embedding_op.cc b/paddle/fluid/operators/collective/c_embedding_op.cc index caea70c223..aa16d8e182 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cc +++ b/paddle/fluid/operators/collective/c_embedding_op.cc @@ -65,10 +65,10 @@ class CEmbeddingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -149,11 +149,11 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc index 8d743139d0..55728b21fb 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -35,10 +35,10 @@ class CIdentityOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& 
ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 3e752011f1..680af73af6 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -66,10 +66,10 @@ class CReduceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/c_scatter_op.cc index d6d4cc03dc..d122ffbcd9 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cc @@ -52,10 +52,10 @@ class CScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc index a8c19b8638..97d72457a2 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc @@ -73,11 +73,10 @@ class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + return phi::KernelKey( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace()); } }; @@ -150,11 +149,11 @@ class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_split_op.cc b/paddle/fluid/operators/collective/c_split_op.cc index 5c6f126b78..52ce38cd17 100644 --- a/paddle/fluid/operators/collective/c_split_op.cc +++ b/paddle/fluid/operators/collective/c_split_op.cc @@ -66,10 +66,10 @@ class CSplitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), 
ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.h b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h index 5b26e47a8f..da3fdd3453 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.h +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h @@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + +#pragma once + #include <string> #include "paddle/fluid/framework/op_registry.h" @@ -25,10 +28,9 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 67fff76551..ff7fb09f7a 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -23,10 +23,9 @@ class CSyncCommStreamOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/global_gather_op.cc b/paddle/fluid/operators/collective/global_gather_op.cc index ee8cc39f44..f3380b4498 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cc @@ -49,10 +49,10 @@ class GlobalGatherOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/global_scatter_op.cc b/paddle/fluid/operators/collective/global_scatter_op.cc index 5d81acb226..d4469c5ead 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cc @@ -52,10 +52,10 @@ class GlobalScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/partial_recv_op.cc
b/paddle/fluid/operators/collective/partial_recv_op.cc index f0effde61b..37e060acc2 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cc @@ -80,12 +80,12 @@ class PartialRecvOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { int dtype = ctx.Attr<int>("dtype"); framework::proto::VarType::Type type = framework::proto::VarType::Type(dtype); - return framework::OpKernelType(type, ctx.GetPlace()); + return phi::KernelKey(type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/partial_send_op.cc b/paddle/fluid/operators/collective/partial_send_op.cc index f11973e20c..59ab1cfa6e 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cc @@ -51,10 +51,10 @@ class PartialSendOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc index a35e1a7dda..2b51e913fb 100644 --- a/paddle/fluid/operators/collective/recv_v2_op.cc +++ b/paddle/fluid/operators/collective/recv_v2_op.cc @@ -69,12 +69,12 @@ class RecvOpV2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { int dtype = ctx.Attr<int>("dtype"); framework::proto::VarType::Type type = framework::proto::VarType::Type(dtype); - return framework::OpKernelType(type, ctx.GetPlace()); + return phi::KernelKey(type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc index 8323b98415..5652f07990 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cc @@ -38,7 +38,7 @@ class SendOpV2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const framework::Variable* var = ctx.InputVar("X"); if (var->IsType<framework::LoDTensorArray>()) { @@ -46,12 +46,11 @@ class SendOpV2 : public framework::OperatorWithKernel { // NOTE(sandyhouse): Support an empty tensor array as Input. // And set the kernel type is float.
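/*
 * The ops below (bitwise, compare, logical) pin their kernel to the place of
 * input "X". phi::KernelKey stores a Backend rather than a Place, so the old
 * "kt.place_ = ..." assignment becomes set_backend() plus a Place-to-Backend
 * conversion via phi::TransToPhiBackend. A minimal sketch ("MyOp" is
 * illustrative, not part of this patch):
 *
 *   phi::KernelKey MyOp::GetExpectedKernelType(
 *       const framework::ExecutionContext& ctx) const {
 *     phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
 *     // follow the device of input "X" instead of the default place
 *     kt.set_backend(
 *         phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
 *     return kt;
 *   }
 */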
if (t_arr.size() == 0) { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 0c6e7b31c9..21e4bfcf70 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -32,7 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput<phi::DenseTensor>("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -48,18 +48,20 @@ class ConcatOp : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "All Inputs of Concat OP are Empty!")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -110,22 +112,24 @@ class ConcatOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cc b/paddle/fluid/operators/controlflow/bitwise_op.cc index 4b339f4bd5..0922c9f5d4 100644 --- a/paddle/fluid/operators/controlflow/bitwise_op.cc +++ b/paddle/fluid/operators/controlflow/bitwise_op.cc @@ -97,11 +97,12 @@ class UnaryBitwiseOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const
framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // BitwiseOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place())); return kt; } }; @@ -138,11 +139,12 @@ class BinaryBitwiseOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // BitwiseOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place())); return kt; } }; diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index ba580f4097..26d0dce91c 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -61,19 +61,20 @@ class CompareOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // CompareOp kernel's device type is decided by input tensor place bool force_cpu = ctx.Attr<bool>("force_cpu"); if (force_cpu) { - kt.place_ = platform::CPUPlace(); + kt.set_backend(phi::Backend::CPU); } else { if (ctx.Input<phi::DenseTensor>("X")->place().GetType() != phi::AllocationType::GPUPINNED) { - kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place())); } else { - kt.place_ = ctx.GetPlace(); + kt.set_backend(phi::TransToPhiBackend(ctx.GetPlace())); } } return kt; diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index b70211c1e1..5a99dd695c 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -72,48 +72,49 @@ class FetchV2Op : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (!tensor.IsInitialized()) { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *fetch_var = ctx.InputVar("X"); if (fetch_var == nullptr) { - return framework::OpKernelType(framework::proto::VarType::FP32, -
platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } if (fetch_var->IsType<phi::DenseTensor>()) { auto &src_item = fetch_var->Get<phi::DenseTensor>(); if (!src_item.IsInitialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } else if (fetch_var->IsType<phi::SelectedRows>()) { auto &src_item = fetch_var->Get<phi::SelectedRows>(); if (!src_item.initialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } else { auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); if (src_item.empty() || !src_item[0].IsInitialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/controlflow/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc index c6dde6f4ba..6a9fcaf852 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -69,11 +69,12 @@ class LogicalOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // LogicalOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place())); return kt; } }; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 0262c74923..e41270de65 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -186,7 +186,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape( return output_shape; } -framework::OpKernelType ConvOp::GetExpectedKernelType( +phi::KernelKey ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); // todo enable data layout when it's ready @@ -208,18 +208,18 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( paddle::framework::DataTypeToString(filter_data_type))); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } -framework::OpKernelType ConvOp::GetKernelTypeForVar( +phi::KernelKey ConvOp::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if ((var_name == "Input") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -228,13 +228,12 @@
framework::OpKernelType ConvOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for conv // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void Conv2DOpMaker::Make() { @@ -447,23 +446,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { } } -framework::OpKernelType ConvOpGrad::GetExpectedKernelType( +phi::KernelKey ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { // TODO(pzelazko-intel): enable MKLDNN layout when it's ready auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( +phi::KernelKey ConvOpGrad::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if (((var_name == "Input") || (var_name == framework::GradVarName("Output"))) && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -472,13 +471,12 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } template @@ -619,10 +617,10 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const { } } -framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( +phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 62bcfb545e..29345f1432 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -196,13 +196,13 @@ class ConvOp : public framework::OperatorWithKernel { std::vector ComputeOutputShape( framework::InferShapeContext* ctx) const; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvOpGrad : public framework::OperatorWithKernel { @@ -211,13 +211,13 @@ class ConvOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvOpDoubleGrad : public framework::OperatorWithKernel { @@ -226,7 +226,7 @@ class ConvOpDoubleGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index ebc9f8afdb..e5333c5ed5 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -33,21 +33,21 @@ namespace operators { using DataLayout = phi::DataLayout; -framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( +phi::KernelKey ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( +phi::KernelKey ConvTransposeOp::GetKernelTypeForVar( const std::string& 
var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if ((var_name == "Input") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -56,13 +56,12 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void Conv2DTransposeOpMaker::Make() { @@ -253,10 +252,10 @@ Example: )DOC"); } -framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( +phi::KernelKey ConvTransposeOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } template @@ -320,10 +319,10 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker { } }; -framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( +phi::KernelKey ConvTransposeOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index d47828e5bd..61860b6907 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -38,13 +38,13 @@ class ConvTransposeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvTransposeOpGrad : public framework::OperatorWithKernel { @@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; @@ -61,7 +61,7 @@ class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/correlation_op.cc b/paddle/fluid/operators/correlation_op.cc index 2b3450d031..c1b3fb25bc 100644 --- a/paddle/fluid/operators/correlation_op.cc +++ b/paddle/fluid/operators/correlation_op.cc @@ -109,7 +109,7 @@ class CorrelationOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input1"); @@ -118,7 +118,7 @@ class CorrelationOp : public framework::OperatorWithKernel { ctx.Input("Input2")->dtype()), platform::errors::InvalidArgument( "X and Y shoule have the same datatype")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -158,9 +158,9 @@ class CorrelationOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index 62bd73374b..5844beb9c0 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -202,9 +202,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 462764230f..b615fbd58f 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -58,11 +58,10 @@ class CropOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -182,11 +181,11 @@ class CropOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 99e67406c3..bdee50e773 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -126,11 +126,10 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". 
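The hunks just below repeat the single most common rewrite in this patch: phi::KernelKey is constructed from a place (stored internally as a backend), not from a DeviceContext, so every GetExpectedKernelType that used to pass ctx.device_context() now unwraps it with GetPlace(). A minimal sketch of the before/after shape; MyOp is an illustrative stand-in rather than an operator touched by this patch, and it assumes the fluid-compat KernelKey constructor (proto::VarType dtype plus place) that these hunks rely on:

    // Sketch only. "MyOp" is hypothetical; the rewrite pattern is the one
    // applied throughout this patch.
    #include "paddle/fluid/framework/op_registry.h"

    namespace paddle {
    namespace operators {

    class MyOp : public framework::OperatorWithKernel {
     public:
      using framework::OperatorWithKernel::OperatorWithKernel;

     protected:
      phi::KernelKey GetExpectedKernelType(
          const framework::ExecutionContext& ctx) const override {
        // Before: return framework::OpKernelType(
        //             OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        //             ctx.device_context());
        // KernelKey takes a place, so the DeviceContext is unwrapped here;
        // the dtype is still a fluid proto type, which the compat
        // constructor presumably converts internally.
        return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                              ctx.device_context().GetPlace());
      }
    };

    }  // namespace operators
    }  // namespace paddle

Note also the argument-order change visible in the surrounding hunks: OpKernelType took (dtype, place, layout), while the three-argument KernelKey takes (place, layout, dtype).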
- framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { @@ -192,11 +191,11 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Y")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context().GetPlace()); } virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index 7731b72071..1df3def180 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -35,11 +35,10 @@ class CTCAlignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index f5fd56edef..8a004ac4a2 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -95,11 +95,10 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -249,11 +248,11 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { SetOutGradDim("InitH"); SetOutGradDim("InitC"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 29bc83bd9a..6d1089ecf7 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -25,11 +25,11 @@ class CumOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 11af33df2f..54fa8e0031 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -48,11 +48,10 @@ class CVMOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // cvm // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -114,11 +113,11 @@ class CVMGradientOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // cvm // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Y")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 6770a7e31c..f4c850c423 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -159,7 +159,7 @@ class DataNormOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -195,7 +195,7 @@ class DataNormOp : public framework::OperatorWithKernel { "bias input should be of float type")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -475,7 +475,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -494,7 +494,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { } auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/decode_jpeg_op.cc b/paddle/fluid/operators/decode_jpeg_op.cc index acb94f57bf..521798e8dd 100644 --- a/paddle/fluid/operators/decode_jpeg_op.cc +++ b/paddle/fluid/operators/decode_jpeg_op.cc @@ -32,24 +32,23 @@ class DecodeJpegOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) 
const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "X") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } }; diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index b916d069d1..d6eff438e0 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -113,11 +113,10 @@ class DeformableConvOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -173,11 +172,10 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index ed70e54678..a597c1e003 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -118,11 +118,10 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -172,11 +171,10 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc 
b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index 5240116c6a..6e284d8e7b 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -290,11 +290,10 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -338,11 +337,10 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cc b/paddle/fluid/operators/dequantize_abs_max_op.cc index 99c4fad0fa..bf329324ff 100644 --- a/paddle/fluid/operators/dequantize_abs_max_op.cc +++ b/paddle/fluid/operators/dequantize_abs_max_op.cc @@ -68,11 +68,10 @@ class DequantizeMaxAbsOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - auto type = framework::OpKernelType(data_type, ctx.device_context()); - return type; + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_log_op.cc b/paddle/fluid/operators/dequantize_log_op.cc index 62359a2ce2..94299e153d 100644 --- a/paddle/fluid/operators/dequantize_log_op.cc +++ b/paddle/fluid/operators/dequantize_log_op.cc @@ -75,11 +75,10 @@ class DequantizeLogOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - auto type = framework::OpKernelType(data_type, ctx.device_context()); - return type; + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index c39f351fcb..9bb7b0eaa1 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -19,15 +19,14 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -framework::OpKernelType DeQuantOp::GetExpectedKernelType( +phi::KernelKey DeQuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/dequantize_op.h b/paddle/fluid/operators/dequantize_op.h index f319828a6b..4aee7502d6 100644 --- a/paddle/fluid/operators/dequantize_op.h +++ b/paddle/fluid/operators/dequantize_op.h @@ -22,7 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::OpKernelType; +using phi::KernelKey; class DeQuantOp : public framework::OperatorWithKernel { public: @@ -34,7 +34,7 @@ class DeQuantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 530b5a1ee1..7a1397ba08 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -61,11 +61,10 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index 583122b473..8bf542e17c 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -50,9 +50,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index e07e4034f3..8c607c98c1 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -87,11 +87,11 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index 
8b74f46cd3..def0f3f6d8 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -105,10 +105,10 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 20b8846bc4..9fa761abcf 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -27,10 +27,10 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 7ae5ba6ca8..6acc043176 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -110,10 +110,10 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois"); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return phi::KernelKey(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index b11030f1d0..dcffa170b6 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -160,10 +160,10 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois"); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return phi::KernelKey(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 030b99cd1d..d6987c7ba8 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -62,11 +62,11 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return 
phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc index 0445c21b1d..885a357566 100644 --- a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc @@ -34,11 +34,11 @@ class GenerateProposalsV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/locality_aware_nms_op.cc b/paddle/fluid/operators/detection/locality_aware_nms_op.cc index 1c5135fc4e..9a230dc322 100644 --- a/paddle/fluid/operators/detection/locality_aware_nms_op.cc +++ b/paddle/fluid/operators/detection/locality_aware_nms_op.cc @@ -79,9 +79,9 @@ class LocalityAwareNMSOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/matrix_nms_op.cc b/paddle/fluid/operators/detection/matrix_nms_op.cc index 1beeaf1ba3..8038e4a42c 100644 --- a/paddle/fluid/operators/detection/matrix_nms_op.cc +++ b/paddle/fluid/operators/detection/matrix_nms_op.cc @@ -25,9 +25,9 @@ class MatrixNMSOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index 28099630b8..a673d64c52 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -316,9 +316,9 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 79077b3086..9dc6a8cc1f 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -112,9 +112,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( 
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index 66682c6787..9171b9ab25 100644 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -69,10 +69,10 @@ class NMSOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 28251c32dd..be1e224cd3 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -26,11 +26,11 @@ class PriorBoxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_input_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_input_type, ctx.GetPlace()); + return phi::KernelKey(input_input_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc index d2654e086d..a36d6a9f6c 100644 --- a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -166,12 +166,12 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Scores"); - return framework::OpKernelType(input_data_type, - platform::CPUPlace()); // ctx.GetPlace()); + return phi::KernelKey(input_data_type, + platform::CPUPlace()); // ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 9ba51850eb..27442c5dad 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -559,11 +559,10 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -585,11 +584,10 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, 
"X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index ba7fe51383..531c823442 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -94,9 +94,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } @@ -851,9 +851,9 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc index 91479a78b6..ff27945d18 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -89,11 +89,10 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -180,11 +179,10 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index c3d79b0505..155ec31fa9 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -77,11 +77,10 @@ class TargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index fbf4b55dfe..a60f42de66 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -129,10 +129,10 @@ class YoloBoxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 0b8fc79826..21aca33f65 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -26,11 +26,10 @@ class Yolov3LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -179,11 +178,10 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index ada4d18eb0..2620fa3a8f 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -89,9 +89,9 @@ class DetectionMAPOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc index 56e39747af..62cecbd36a 100644 --- a/paddle/fluid/operators/determinant_op.cc +++ b/paddle/fluid/operators/determinant_op.cc @@ -70,11 +70,11 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.cc b/paddle/fluid/operators/dgc_clip_by_norm_op.cc index 7c75949039..2f8d7ca96f 100644 --- a/paddle/fluid/operators/dgc_clip_by_norm_op.cc +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.cc @@ -31,13 +31,15 @@ class DGCClipByNormOp : public ClipByNormOp { return ClipByNormOp::InferShape(ctx); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const 
override { if (var_name == "current_step") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/dgc_op.cc b/paddle/fluid/operators/dgc_op.cc index e247ab05eb..171dc84000 100644 --- a/paddle/fluid/operators/dgc_op.cc +++ b/paddle/fluid/operators/dgc_op.cc @@ -45,13 +45,15 @@ class DGCOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "current_step" || var_name == "k" || var_name == "nranks") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 804834a974..c6ee1180b5 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -27,24 +27,26 @@ class DropoutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Seed") { VLOG(10) << "var_name:" << var_name << " does not need to transform in dropout op"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -133,11 +135,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/edit_distance_op.cc b/paddle/fluid/operators/edit_distance_op.cc index c4c5db6b50..5eef3d72b3 100644 --- a/paddle/fluid/operators/edit_distance_op.cc +++ b/paddle/fluid/operators/edit_distance_op.cc @@ -24,10 +24,10 @@ class EditDistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; 
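The dgc and dropout hunks above also retire the old shortcut of returning expected_kernel_type verbatim from GetKernelTypeForVar, which meant "do not transform this variable". With KernelKey the same intent is spelled out through phi::Backend::ALL_BACKEND, pinning only layout and dtype. A sketch of the idiom; SkipTransformFor is a hypothetical helper, not code from this patch:

    // Sketch only. ALL_BACKEND marks the variable as acceptable on any
    // device, so the framework schedules no device transfer for it; layout
    // and dtype are carried over from the expected key unchanged.
    #include "paddle/phi/core/kernel_factory.h"  // phi::KernelKey (assumed path)

    phi::KernelKey SkipTransformFor(const phi::KernelKey& expected_kernel_type) {
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
    }

Spelling the backend out keeps the "skip the transfer" decision visible in the key itself rather than implied by returning the expected key by value.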
protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/eigvalsh_op.cc b/paddle/fluid/operators/eigvalsh_op.cc index 9d09b96280..27c70f1e9b 100644 --- a/paddle/fluid/operators/eigvalsh_op.cc +++ b/paddle/fluid/operators/eigvalsh_op.cc @@ -66,11 +66,11 @@ class EigvalshGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Eigenvectors"), - ctx.device_context()); + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 5f169e20e3..458fc7afb9 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -68,11 +68,11 @@ class EinsumGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 8c7aa350b4..4c1afa3f6c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -41,25 +41,22 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 7f233eba88..8d1b52325d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ 
b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -27,26 +27,23 @@ class ElementwiseMulOp : public ElementwiseOp { public: using ElementwiseOp::ElementwiseOp; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 1ed8f4eb01..7048cf5029 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -151,39 +151,36 @@ class ElementwiseOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { #ifdef PADDLE_WITH_MKLDNN // When elementwise is first oneDNN op (there was some non oneDNN op // previously) // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) && phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); + return phi::KernelKey(tensor.place(), + phi::DataLayout::kNHWC, + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), 
tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -300,26 +297,23 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -345,25 +339,22 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -380,7 +371,7 @@ class ElementwiseOpDoubleGradWithoutDXDY } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::proto::VarType::Type input_data_type; if (ctx.HasInput("DDX") == false) { @@ -399,22 +390,19 @@ class ElementwiseOpDoubleGradWithoutDXDY input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY"); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType 
&expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -446,26 +434,23 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::proto::VarType::Type input_data_type; input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index 47dc2eb383..a5c707d460 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -53,21 +53,23 @@ class EmptyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& context) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(context.Attr("dtype")), context.GetPlace()); } diff --git a/paddle/fluid/operators/expand_as_op.cc b/paddle/fluid/operators/expand_as_op.cc index b793d835fc..107fe9f617 100644 --- 
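
NOTE: empty_op above (and the expand/fill ops below) shows how "do not transform this input" is now spelled. Returning expected_kernel_type verbatim used to signal that a host-side shape input (ShapeTensor / ShapeTensorList) must not be copied to the kernel's device; the replacement is a key whose backend is phi::Backend::ALL_BACKEND, which the kernel-key comparison treats as compatible with any backend, so the data-transfer pass leaves the variable where it is. A sketch, with a hypothetical helper name:

    // "This variable is fine wherever it already lives": only the backend
    // is wildcarded; layout and dtype are carried over unchanged.
    phi::KernelKey SkipTransform(const phi::KernelKey& expected) {
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected.layout(),
                            expected.dtype());
    }
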
a/paddle/fluid/operators/expand_as_op.cc +++ b/paddle/fluid/operators/expand_as_op.cc @@ -106,11 +106,11 @@ class ExpandAsGradOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 09dc0f68cc..5e0f98c3ee 100644 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -25,11 +25,10 @@ class ExpandAsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -77,11 +76,11 @@ class ExpandAsV2GradOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 67b8102181..43fd505acd 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -77,22 +77,23 @@ class ExpandOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_times_tensor" || var_name == "ExpandTimes") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -206,22 +207,24 @@ class ExpandGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, 
framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_times_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 6bf40fd3bb..cbd322f387 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -33,22 +33,24 @@ class ExpandV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -150,22 +152,24 @@ class ExpandV2GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/exponential_op.cc 
b/paddle/fluid/operators/exponential_op.cc index 26e06e50a7..52ddd9ebfa 100644 --- a/paddle/fluid/operators/exponential_op.cc +++ b/paddle/fluid/operators/exponential_op.cc @@ -24,10 +24,10 @@ class ExponentialOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc index 629400a403..57582c694e 100644 --- a/paddle/fluid/operators/eye_op.cc +++ b/paddle/fluid/operators/eye_op.cc @@ -25,9 +25,9 @@ class EyeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index a5742af742..65e4b28326 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -405,11 +405,10 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -472,10 +471,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -553,10 +552,10 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -631,11 +630,10 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -711,11 +709,10 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( 
const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -791,10 +788,10 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -847,11 +844,11 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index d4d160d315..94b5ba1c5c 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -124,11 +124,11 @@ class FCOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc index bf79a98d21..2efe0eeb72 100644 --- a/paddle/fluid/operators/fill_any_like_op.cc +++ b/paddle/fluid/operators/fill_any_like_op.cc @@ -31,23 +31,24 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); const auto &data_type = ctx.Attr("dtype"); if (data_type >= 0) { - kt.data_type_ = static_cast(data_type); + kt.set_dtype(phi::TransToPhiDataType( + static_cast(data_type))); } return kt; } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc index 871a8314c5..66c0470ac0 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc @@ -23,13 +23,13 
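
NOTE: fill_any_like above is one of the few spots where a key is edited rather than rebuilt: the old code assigned kt.data_type_ directly, while the new code goes through the set_dtype() setter and converts the proto enum to phi::DataType on the way in. (In-tree the attribute read is ctx.Attr<int>("dtype") and the cast is static_cast<framework::proto::VarType::Type>(data_type); the angle-bracket template arguments did not survive in this copy of the patch.) A condensed sketch:

    phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    const int data_type = ctx.Attr<int>("dtype");
    if (data_type >= 0) {
      // proto::VarType::Type -> phi::DataType, then into the key.
      kt.set_dtype(phi::TransToPhiDataType(
          static_cast<framework::proto::VarType::Type>(data_type)));
    }
    return kt;
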
@@ namespace operators { class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp { protected: using BatchSizeLikeOp::BatchSizeLikeOp; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kernel_type = framework::OpKernelType( + phi::KernelKey kernel_type = phi::KernelKey( static_cast(ctx.Attr("dtype")), - ctx.device_context()); + ctx.GetPlace()); if (ctx.Attr("force_cpu")) { - kernel_type.place_ = platform::CPUPlace(); + kernel_type.set_backend(phi::Backend::CPU); } return kernel_type; } diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 82c6b89063..4b2ee8763c 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -56,46 +56,47 @@ class FillConstantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::proto::VarType::Type(ctx.Attr("dtype")); - framework::OpKernelType kt = - framework::OpKernelType(input_data_type, ctx.GetPlace()); + phi::KernelKey kt = phi::KernelKey(input_data_type, ctx.GetPlace()); // TODO(zyfncg) The force_cpu and place_type are conflicted, it's an issue // left before, and we may merge them in the future. // In order to invoke new fill_constant kernel, the place of OpKernelType // will be setted by force_cpu and place_type here. 
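
NOTE: the force_cpu/place_type block that follows is the other mutation site: it used to overwrite kernel_type.place_ with a concrete platform place (CPUPlace, CUDAPlace, XPUPlace, NPUPlace); with phi::KernelKey the same re-targeting goes through set_backend(). A hypothetical helper capturing the mapping in the switch below:

    phi::Backend BackendForPlaceType(int place_type) {
      // Mirrors the fill_constant switch below; -1 means "leave unchanged".
      switch (place_type) {
        case 0:  return phi::Backend::CPU;
        case 1:
        case 2:  return phi::Backend::GPU;
        case 3:  return phi::Backend::XPU;
        case 4:  return phi::Backend::NPU;
        default:
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported place_type: %d.", place_type));
      }
    }

With that helper, kt.set_backend(BackendForPlaceType(place_type)) would be equivalent to the in-line switch.
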
if (ctx.Attr("force_cpu")) { - kt.place_ = platform::CPUPlace(); + kt.set_backend(phi::Backend::CPU); } auto place_type = ctx.Attr("place_type"); if (place_type != -1) { switch (place_type) { case 0: - kt.place_ = platform::CPUPlace(); + kt.set_backend(phi::Backend::CPU); break; case 1: case 2: - kt.place_ = platform::CUDAPlace(); + kt.set_backend(phi::Backend::GPU); break; case 3: - kt.place_ = platform::XPUPlace(); + kt.set_backend(phi::Backend::XPU); break; case 4: - kt.place_ = platform::NPUPlace(); + kt.set_backend(phi::Backend::NPU); break; default: PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/operators/fill_diagonal_op.cc b/paddle/fluid/operators/fill_diagonal_op.cc index 8a7f5daa9f..373a63b7ff 100644 --- a/paddle/fluid/operators/fill_diagonal_op.cc +++ b/paddle/fluid/operators/fill_diagonal_op.cc @@ -50,10 +50,10 @@ class FillIDiagonalOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -71,12 +71,12 @@ class FillIDiagonalGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { // Note: don't get data type from ctx.Input("Input"); auto dtype = framework::TransToProtoVarType( ctx.Input(framework::GradVarName("Out"))->type()); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 8937676c34..bcb7081847 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -51,9 +51,9 @@ class FillOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/fill_zeros_like_op.cc b/paddle/fluid/operators/fill_zeros_like_op.cc index 8bd0e328c1..aff240ca4a 100644 --- a/paddle/fluid/operators/fill_zeros_like_op.cc +++ b/paddle/fluid/operators/fill_zeros_like_op.cc @@ -55,9 +55,9 @@ class FillZerosLikeOp2 : public FillZerosLikeOp { using FillZerosLikeOp::FillZerosLikeOp; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/filter_by_instag_op.cc b/paddle/fluid/operators/filter_by_instag_op.cc index 808792468f..3fe43017eb 100644 --- a/paddle/fluid/operators/filter_by_instag_op.cc +++ b/paddle/fluid/operators/filter_by_instag_op.cc @@ -59,10 +59,10 @@ class FilterByInstagOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) 
const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; @@ -126,11 +126,11 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 9b96e27ab7..54e35a6f03 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -81,11 +81,11 @@ class FlattenOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -153,11 +153,11 @@ class FlattenGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -217,11 +217,11 @@ class Flatten2Op : public framework::OperatorWithKernel { ctx->ShareLoD("X", "XShape"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -269,11 +269,11 @@ class Flatten2GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -387,11 +387,11 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"}); diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc index c1acc9a38b..4f59e88bd0 100644 --- 
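
NOTE: the fsp_op hunk below (and, later, the fused_bn_*, fused_gemm_epilogue and resnet hunks) deletes a recurring piece of boilerplate: local framework::LibraryType::kPlain and phi::DataLayout::kAnyLayout variables fed into the four-argument OpKernelType constructor. phi::KernelKey has no library field at all (library choice such as oneDNN or cuDNN is folded into the backend), and the two-argument form defaults the layout, so the whole block collapses:

    // Before:
    //   framework::LibraryType library_{framework::LibraryType::kPlain};
    //   phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
    //   return framework::OpKernelType(
    //       OperatorWithKernel::IndicateVarDataType(ctx, "X"),
    //       ctx.device_context(), layout_, library_);
    // After:
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.device_context().GetPlace());
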
a/paddle/fluid/operators/fsp_op.cc +++ b/paddle/fluid/operators/fsp_op.cc @@ -65,15 +65,10 @@ class FSPOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - phi::DataLayout layout_ = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context(), - layout_, - library_); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -131,11 +126,11 @@ class FSPOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index f25dc393d3..6b1f533b34 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -253,11 +253,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input = ctx.Input("X"); auto input_data_type = framework::TransToProtoVarType(input->dtype()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -588,11 +588,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input = ctx.Input("X"); auto input_data_type = framework::TransToProtoVarType(input->dtype()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc index 02494e33e1..a6fa80a493 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -60,11 +60,11 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input = ctx.Input("X"); auto input_data_type = framework::TransToProtoVarType(input->dtype()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -190,11 +190,11 @@ class FusedBiasDropoutResidualLnGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input = ctx.Input("X"); auto 
input_data_type = framework::TransToProtoVarType(input->dtype()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.cc b/paddle/fluid/operators/fused/fused_bn_activation_op.cc index e68be43eb7..88b11f1ef3 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.cc @@ -156,7 +156,7 @@ void FusedBatchNormActOp::InferShape(framework::InferShapeContext *ctx) const { ctx->ShareLoD("X", "Y"); } -framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType( +phi::KernelKey FusedBatchNormActOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -187,11 +187,7 @@ framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType( platform::errors::PreconditionNotMet( "Variance input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } void FusedBatchNormActOpMaker::Make() { @@ -297,7 +293,7 @@ void FusedBatchNormActGradOp::InferShape( ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } -framework::OpKernelType FusedBatchNormActGradOp::GetExpectedKernelType( +phi::KernelKey FusedBatchNormActGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -315,14 +311,8 @@ framework::OpKernelType FusedBatchNormActGradOp::GetExpectedKernelType( platform::errors::NotFound("Can not get the tensor value of Y@GRAD.")); } - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.h b/paddle/fluid/operators/fused/fused_bn_activation_op.h index b71812db9d..78ba849eaa 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.h @@ -33,7 +33,7 @@ class FusedBatchNormActOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; @@ -43,7 +43,7 @@ class FusedBatchNormActGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc index 08f7087b48..58a950f923 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc @@ -134,7 +134,7 @@ void 
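
NOTE: two idioms for choosing the key's dtype coexist in the fused-op hunks above. Most ops ask the framework, which also validates that the variable is initialized; fused_attention and fused_bias_dropout_residual_layer_norm read the input tensor directly and convert its phi::DataType back to the proto enum. Both feed the same constructor (a sketch; in-tree the input read is ctx.Input<phi::DenseTensor>("X")):

    // Framework-mediated, with initialization checks:
    auto dtype_a = OperatorWithKernel::IndicateVarDataType(ctx, "X");

    // Direct read, as in the fused_attention hunk:
    auto* input = ctx.Input<phi::DenseTensor>("X");
    auto dtype_b = framework::TransToProtoVarType(input->dtype());

    return phi::KernelKey(dtype_a, ctx.GetPlace());
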
FusedBatchNormAddActOp::InferShape( ctx->ShareLoD("X", "Y"); } -framework::OpKernelType FusedBatchNormAddActOp::GetExpectedKernelType( +phi::KernelKey FusedBatchNormAddActOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -152,11 +152,7 @@ framework::OpKernelType FusedBatchNormAddActOp::GetExpectedKernelType( ctx.Input("Bias")->dtype()), platform::errors::InvalidArgument("Bias input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } void FusedBatchNormAddActOpMaker::Make() { @@ -255,7 +251,7 @@ void FusedBatchNormAddActGradOp::InferShape( ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } -framework::OpKernelType FusedBatchNormAddActGradOp::GetExpectedKernelType( +phi::KernelKey FusedBatchNormAddActGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -273,14 +269,8 @@ framework::OpKernelType FusedBatchNormAddActGradOp::GetExpectedKernelType( platform::errors::NotFound("Can not get the tensor value of Y@GRAD.")); } - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h index bdb1f2f354..2d20a880e7 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h @@ -33,7 +33,7 @@ class FusedBatchNormAddActOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; @@ -43,7 +43,7 @@ class FusedBatchNormAddActGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 8c81a646fd..2e7152ecb2 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -172,14 +172,14 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE_EQ(ctx.Input("X")->dtype(), ctx.Input("Y")->dtype(), platform::errors::InvalidArgument( "The element's type of input should be the same.")); - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), 
ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -389,11 +389,11 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc index 4f8c4d12d6..232321c65b 100644 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc @@ -103,7 +103,7 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto inputs = ctx.MultiInput("Embs"); auto input_data_type = framework::proto::VarType::Type(0); @@ -119,7 +119,7 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::PreconditionNotMet( "All Inputs of fused_embedding_eltwise_layernorm OP are Empty!")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 11b9044bc5..bec18220e9 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -169,11 +169,11 @@ void FusedEmbeddingFCLSTMOp::InferShape( ctx->ShareLoD("Ids", "XX"); } -framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( +phi::KernelKey FusedEmbeddingFCLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"), - ctx.device_context()); + ctx.GetPlace()); } void FusedEmbeddingFCLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h index 19039ec559..29db2e9961 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h @@ -25,7 +25,7 @@ class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index bbb5ce50c9..a5f20ffadc 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -72,10 +72,10 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -141,10 +141,10 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc index 3bf039829a..b6edf5ed44 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cc +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc @@ -120,10 +120,10 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -344,11 +344,11 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input = ctx.Input("X"); auto input_data_type = framework::TransToProtoVarType(input->dtype()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index b4fc1b57d8..187eb4fc07 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -144,12 +144,10 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -318,12 +316,10 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc index e1be5afa0b..ca9edb682b 100644 --- 
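
NOTE: the fused_multi_transformer hunks below use the same ALL_BACKEND wildcard as empty_op earlier, this time for "TimeStep", which (judging by the surrounding code) is a host-side step counter that should not be copied to the device on every decode step. The full override shape, mirroring the hunks below:

    phi::KernelKey GetKernelTypeForVar(
        const std::string& var_name,
        const phi::DenseTensor& tensor,
        const phi::KernelKey& expected_kernel_type) const override {
      if (var_name == "TimeStep") {
        VLOG(10) << "var_name:" << var_name << " need not to transform";
        return phi::KernelKey(phi::Backend::ALL_BACKEND,
                              expected_kernel_type.layout(),
                              expected_kernel_type.dtype());
      }
      // Every other input follows the tensor's place/layout, expected dtype.
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
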
a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc @@ -166,22 +166,24 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "TimeStep") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index 89d2275e06..5448578c26 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -133,22 +133,24 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "TimeStep") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc index 95c82c72ef..79ad83ab14 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc @@ -93,7 +93,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -109,10 +109,10 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { 1, platform::errors::InvalidArgument( "All Inputs of 
fused_seqpool_cvm OP are Empty!")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); - // return framework::OpKernelType(framework::proto::VarType::FP32, + return phi::KernelKey(input_data_type, ctx.GetPlace()); + // return phi::KernelKey(framework::proto::VarType::FP32, // ctx.device_context()); - // return framework::OpKernelType( + // return phi::KernelKey( // OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -210,11 +210,11 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc index 9df2219910..7b737d6885 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc @@ -71,11 +71,10 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fusion_group_op.cc b/paddle/fluid/operators/fused/fusion_group_op.cc index 36b97ea7b1..362819d97f 100644 --- a/paddle/fluid/operators/fused/fusion_group_op.cc +++ b/paddle/fluid/operators/fused/fusion_group_op.cc @@ -76,10 +76,10 @@ class FusionGroupOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CUDAPlace(0)); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CUDAPlace(0)); }; }; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index fc7804f9c4..b7e4fa5493 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -147,10 +147,10 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "XX"); } -framework::OpKernelType FusionGRUOp::GetExpectedKernelType( +phi::KernelKey FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_gru_op.h b/paddle/fluid/operators/fused/fusion_gru_op.h index 94bf38068d..e811df6550 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.h +++ b/paddle/fluid/operators/fused/fusion_gru_op.h @@ -25,7 +25,7 @@ class FusionGRUOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - 
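
NOTE: FusedSeqpoolCVMOp above (like EmbeddingEltWiseLayerNormOp earlier) keys the kernel off a variadic input list: take the dtype of the first initialized, non-empty tensor and fail if there is none. Only the final constructor call changes in this patch, but the scan is worth seeing whole (a condensed sketch of the in-tree loop; the multi-input read is ctx.MultiInput<phi::DenseTensor>("X")):

    auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
    auto input_data_type = framework::proto::VarType::Type(0);
    bool found = false;
    for (auto* input : inputs) {
      if (input->IsInitialized() && input->numel() > 0) {
        input_data_type = framework::TransToProtoVarType(input->dtype());
        found = true;
        break;
      }
    }
    PADDLE_ENFORCE_EQ(found, true,
                      platform::errors::InvalidArgument(
                          "All Inputs of fused_seqpool_cvm OP are Empty!"));
    return phi::KernelKey(input_data_type, ctx.GetPlace());
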
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index c526fdc184..57d40f1cae 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -170,10 +170,10 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "XX"); } -framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( +phi::KernelKey FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.h b/paddle/fluid/operators/fused/fusion_lstm_op.h index 93f8eb981b..c62060d7c2 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.h +++ b/paddle/fluid/operators/fused/fusion_lstm_op.h @@ -25,7 +25,7 @@ class FusionLSTMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index bab06f55be..154b0366ee 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -99,10 +99,10 @@ void FusionRepeatedFCReluOp::InferShape( ctx->ShareLoD("X", /*->*/ "Out"); } -framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( +phi::KernelKey FusionRepeatedFCReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionRepeatedFCReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h index 16025bf518..62eae8f7c0 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h @@ -25,7 +25,7 @@ class FusionRepeatedFCReluOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index c916691963..e9428aea00 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -88,10 +88,10 @@ void FusionSeqConvEltAddReluOp::InferShape( ctx->ShareLoD("X", "Out"); } -framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( +phi::KernelKey FusionSeqConvEltAddReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, 
"X"), ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionSeqConvEltAddReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h index 96f231f9a3..42e0c57b11 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h @@ -25,7 +25,7 @@ class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index dd5b3c0073..86eb7053f8 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -102,10 +102,10 @@ void FusionSeqExpandConcatFCOp::InferShape( ctx->ShareLoD("X", "Out", 0); } -framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( +phi::KernelKey FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionSeqExpandConcatFCOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h index 495de5f233..7438b6c717 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h @@ -25,7 +25,7 @@ class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index f2f7801d7c..9fe789e310 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -68,10 +68,10 @@ void FusionSeqPoolConcatOp::InferShape( } } -framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType( +phi::KernelKey FusionSeqPoolConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionSeqPoolConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h index 2e2d6e07dc..5761330a76 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h @@ -25,7 +25,7 @@ class FusionSeqPoolConcatOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + 
phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc index e3953f9e6a..f9ee16eb81 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc @@ -67,10 +67,10 @@ void FusionSeqPoolCVMConcatOp::InferShape( ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast(n)}); } -framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( +phi::KernelKey FusionSeqPoolCVMConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionSeqPoolCVMConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h index b9d7d0dfc3..6d45ad4cb9 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h @@ -25,7 +25,7 @@ class FusionSeqPoolCVMConcatOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 8d7f792f3c..67fcc65274 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -63,10 +63,10 @@ void FusionSquaredMatSubOp::InferShape( ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]}); } -framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( +phi::KernelKey FusionSquaredMatSubOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } void FusionSquaredMatSubOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h index fc6a54fd9e..41bde97c4b 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h @@ -26,7 +26,7 @@ class FusionSquaredMatSubOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc index 0552c3ce9b..b66ea9b202 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.cc +++ b/paddle/fluid/operators/fused/multi_gru_op.cc @@ -138,13 +138,12 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Hidden"); } -framework::OpKernelType MultiGRUOp::GetExpectedKernelType( +phi::KernelKey MultiGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - 
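// multi_gru here shows the oneDNN form: where OpKernelType needed a fourth
// LibraryType::kMKLDNN argument, the key names the backend directly. Argument
// order is (backend, layout, dtype), and TransToPhiDataType converts the proto
// VarType into the phi::DataType the key stores. Sketch with a hypothetical
// MyOneDNNOp:
phi::KernelKey MyOneDNNOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return phi::KernelKey(phi::Backend::ONEDNN,
                        phi::DataLayout::ONEDNN,
                        phi::TransToPhiDataType(
                            OperatorWithKernel::IndicateVarDataType(ctx, "X")));
}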
return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"))); } void MultiGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/multi_gru_op.h b/paddle/fluid/operators/fused/multi_gru_op.h index 1846d81960..956fcce59c 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.h +++ b/paddle/fluid/operators/fused/multi_gru_op.h @@ -28,7 +28,7 @@ class MultiGRUOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op.cc b/paddle/fluid/operators/fused/resnet_basic_block_op.cc index b449ca3bbe..d17e6c9872 100644 --- a/paddle/fluid/operators/fused/resnet_basic_block_op.cc +++ b/paddle/fluid/operators/fused/resnet_basic_block_op.cc @@ -219,7 +219,7 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -247,10 +247,7 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Bias input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -545,21 +542,15 @@ class ResNetBasicBlockGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { PADDLE_ENFORCE_NOT_NULL( ctx.InputVar(framework::GradVarName("Y")), platform::errors::NotFound( "Can not find Y@GRAD in the execution context.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc index 4b46dc76b2..05aa019a5a 100644 --- a/paddle/fluid/operators/fused/resnet_unit_op.cc +++ b/paddle/fluid/operators/fused/resnet_unit_op.cc @@ -200,7 +200,7 @@ class ResNetUnitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -217,10 +217,7 @@ class ResNetUnitOp : public framework::OperatorWithKernel { ctx.Input("BiasX")->dtype()), platform::errors::InvalidArgument( "Bias input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = 
phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -392,21 +389,15 @@ class ResNetUnitGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { PADDLE_ENFORCE_NOT_NULL( ctx.InputVar(framework::GradVarName("Y")), platform::errors::NotFound( "Can not find Y@GRAD in the execution context.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 4907153a11..4b85dee9a2 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -32,21 +32,22 @@ class GatherOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Axis") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -55,21 +56,23 @@ class GatherGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Axis") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git 
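// The gather changes above fix the GetKernelTypeForVar contract: returning
// expected_kernel_type unchanged used to mean "do not transform this input";
// the equivalent key now carries phi::Backend::ALL_BACKEND, so host-side
// tensors such as "Axis" trigger no device transfer. Sketch of the recurring
// shape (MyOp is hypothetical; note the (place, layout, dtype) argument order,
// reversed from OpKernelType's (dtype, place, layout)):
phi::KernelKey MyOp::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const {
  if (var_name == "Axis") {
    // Host-side index input: accept as-is on any backend.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
  }
  // Default: keep the tensor's place and layout, follow the expected dtype.
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}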
a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc index d98721cfff..84f5479a61 100644 --- a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc +++ b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc @@ -23,9 +23,9 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp { protected: using BatchSizeLikeOp::BatchSizeLikeOp; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 0f81d7fec3..03c1c4dd64 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -51,22 +51,24 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index b28c8bdc1a..207589edd5 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -254,7 +254,7 @@ paddle::small_vector outputs { {% macro get_expected_kernel(op) %} {% set kernel = op["kernel"] %} -framework::OpKernelType GetExpectedKernelType( +phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { {%if kernel["data_type"] is not none %}{# data type ---------------------------------#} {% if kernel["data_type"]["candidates"] | length == 1 %} @@ -273,7 +273,7 @@ framework::OpKernelType GetExpectedKernelType( } {% endif %} {% endif %} - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } {% endmacro %} diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index 658352d844..7df3a292e5 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -47,11 +47,10 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/graph_khop_sampler_op.cc b/paddle/fluid/operators/graph_khop_sampler_op.cc index 4702d66c3c..1cb5ac3c30 100644 --- a/paddle/fluid/operators/graph_khop_sampler_op.cc +++ b/paddle/fluid/operators/graph_khop_sampler_op.cc @@ -90,11 +90,10 @@ class GraphKhopSamplerOP : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Row"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Row"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/graph_reindex_op.cc b/paddle/fluid/operators/graph_reindex_op.cc index 7bdd1708b6..c24af3f16d 100644 --- a/paddle/fluid/operators/graph_reindex_op.cc +++ b/paddle/fluid/operators/graph_reindex_op.cc @@ -25,11 +25,10 @@ class GraphReindexOP : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/graph_sample_neighbors_op.cc b/paddle/fluid/operators/graph_sample_neighbors_op.cc index 14f17f77dc..0e7a1c97b7 100644 --- a/paddle/fluid/operators/graph_sample_neighbors_op.cc +++ b/paddle/fluid/operators/graph_sample_neighbors_op.cc @@ -25,11 +25,10 @@ class GraphSampleNeighborsOP : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Row"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Row"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index 9e57884c14..afdbaf0ca7 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -25,11 +25,10 @@ class GraphSendRecvOP : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -43,11 +42,11 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/graph_send_ue_recv_op.cc b/paddle/fluid/operators/graph_send_ue_recv_op.cc index 561c7e06f0..2a252bcf70 100644 --- a/paddle/fluid/operators/graph_send_ue_recv_op.cc +++ b/paddle/fluid/operators/graph_send_ue_recv_op.cc @@ -25,11 +25,10 @@ class GraphSendUERecvOP : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -45,11 +44,11 @@ class GraphSendUERecvGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 7331c792ea..90e15ef273 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -114,7 +114,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const auto *var = ctx.InputVar(framework::GradVarName("Y")); @@ -132,8 +132,8 @@ class GroupNormGradOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Input(Y@GRAD) phi::DenseTensor of " "GroupNormGradOp should not be null")); - return framework::OpKernelType(framework::TransToProtoVarType(t->dtype()), - ctx.GetPlace()); + return phi::KernelKey(framework::TransToProtoVarType(t->dtype()), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 1c10692d15..ed7dfa0349 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -305,11 +305,11 @@ class GRUGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(weight_grad_name, weight_dims); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Hidden")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Hidden")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc index 8e05454f1a..7bd104472f 100644 --- a/paddle/fluid/operators/gru_unit_op.cc +++ b/paddle/fluid/operators/gru_unit_op.cc @@ -270,11 +270,11 @@ 
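// phi::KernelKey has no DeviceContext constructor, hence the recurring
// "ctx.device_context()" -> "ctx.device_context().GetPlace()" edit in the
// graph_* and gru ops above. A hypothetical helper making the equivalence
// explicit (both spellings name the same place):
phi::KernelKey MakeKeyOnOpPlace(const framework::ExecutionContext& ctx,
                                framework::proto::VarType::Type data_type) {
  // Equivalent to phi::KernelKey(data_type, ctx.GetPlace()).
  return phi::KernelKey(data_type, ctx.device_context().GetPlace());
}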
class GRUUnitGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(weight_grad_name, weight_dims); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Hidden")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Hidden")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 7255abcb7b..e1de4a9a4d 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -66,10 +66,10 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -213,10 +213,10 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/identity_loss_op.cc b/paddle/fluid/operators/identity_loss_op.cc index bc9986c7ff..76e7f8a733 100644 --- a/paddle/fluid/operators/identity_loss_op.cc +++ b/paddle/fluid/operators/identity_loss_op.cc @@ -27,11 +27,10 @@ class IdentityLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -59,11 +58,11 @@ class IdentityLossGradOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", framework::GradVarName("X")); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, platform::CPUPlace()); + return phi::KernelKey(input_data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc index e2274d87c4..a2fdd53e03 100644 --- a/paddle/fluid/operators/imag_op.cc +++ b/paddle/fluid/operators/imag_op.cc @@ -58,12 +58,12 @@ class ImagGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType( ctx, 
framework::GradVarName("Out")); auto complex_dtype = framework::ToComplexType(dtype); - return framework::OpKernelType(complex_dtype, ctx.GetPlace()); + return phi::KernelKey(complex_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/increment_op.cc b/paddle/fluid/operators/increment_op.cc index 342ef41d41..5fbde4f449 100644 --- a/paddle/fluid/operators/increment_op.cc +++ b/paddle/fluid/operators/increment_op.cc @@ -39,11 +39,12 @@ class IncrementOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // IncrementOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); return kt; } }; diff --git a/paddle/fluid/operators/index_add_op.cc b/paddle/fluid/operators/index_add_op.cc index b856e479fb..da3b720ae3 100644 --- a/paddle/fluid/operators/index_add_op.cc +++ b/paddle/fluid/operators/index_add_op.cc @@ -26,10 +26,10 @@ class IndexAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -79,11 +79,11 @@ class IndexAddGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc index a80324d5d3..5acc9f1bd1 100644 --- a/paddle/fluid/operators/inplace_abn_op.cc +++ b/paddle/fluid/operators/inplace_abn_op.cc @@ -30,7 +30,7 @@ class InplaceABNOp : public paddle::operators::BatchNormOp { using paddle::operators::BatchNormOp::BatchNormOp; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -61,11 +61,7 @@ class InplaceABNOp : public paddle::operators::BatchNormOp { platform::errors::InvalidArgument( "Variance input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -135,7 +131,7 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { } protected: - framework::OpKernelType GetExpectedKernelType( 
+ phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto* var = ctx.InputVar(framework::GradVarName("Y")); auto input_data_type = framework::TransToProtoVarType( @@ -154,11 +150,8 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index c9f33799c9..289df565b8 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -29,7 +29,7 @@ limitations under the License. */ namespace paddle { namespace operators { -framework::OpKernelType InstanceNormOp::GetExpectedKernelType( +phi::KernelKey InstanceNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -54,7 +54,7 @@ framework::OpKernelType InstanceNormOp::GetExpectedKernelType( "Bias input should be of float type")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } void InstanceNormOpMaker::Make() { @@ -98,7 +98,7 @@ NCHW `[batch, in_channels, in_height, in_width]` )DOC"); } -framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( +phi::KernelKey InstanceNormGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -115,11 +115,11 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } -framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( +phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar("DY"); if (var == nullptr) { @@ -136,8 +136,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer, diff --git a/paddle/fluid/operators/instance_norm_op.h b/paddle/fluid/operators/instance_norm_op.h index 05e2bde973..9a885e47e4 100644 --- a/paddle/fluid/operators/instance_norm_op.h +++ b/paddle/fluid/operators/instance_norm_op.h @@ -29,7 +29,7 @@ class InstanceNormOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override; }; @@ -38,7 +38,7 @@ class InstanceNormGradOp : public 
framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override; }; @@ -47,7 +47,7 @@ class InstanceNormDoubleGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override; }; diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index c1b2ae3ea5..999e6df67c 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -337,18 +337,18 @@ class InterpolateOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -357,16 +357,17 @@ class InterpolateOp : public framework::OperatorWithKernel { // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif if (var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -589,22 +590,24 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index 95404bbd4a..e3c4b0be18 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -441,18 +441,18 @@ class InterpolateV2Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -461,18 +461,19 @@ class InterpolateV2Op : public framework::OperatorWithKernel { // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif if (var_name == "OutSize" || var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -692,23 +693,25 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "OutSize" || var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index f03051e2a5..8f68ef13e4 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -47,7 +47,7 @@ class OverflowOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { int dtype = -1; auto *x_var = ctx.InputVar("X"); @@ -65,8 +65,8 @@ class OverflowOp : public framework::OperatorWithKernel { "The input type mismatch, the type of Input(X) must be Tensor or " "SelectedRows, please check your input.")); } - return framework::OpKernelType(framework::proto::VarType::Type(dtype), - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::Type(dtype), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 9a06fd369f..e45e686dd0 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -24,10 +24,10 @@ class KLDivLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -104,11 +104,11 @@ class KLDivLossOpGrad : public 
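// The interpolate GetKernelTypeForVar overrides above keep their oneDNN
// special case, rewritten onto KernelKey accessors: the old
// expected_kernel_type.data_layout_ field becomes
// expected_kernel_type.layout(). Condensed sketch of that branch (dl is the
// layout parsed from the op's "data_format" attribute; the parsing lines are
// elided in this excerpt):
#ifdef PADDLE_WITH_MKLDNN
  if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    if (dl != phi::DataLayout::kAnyLayout) {
      // Reinterpret the non-oneDNN tensor under the model's data_format.
      return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
    }
  }
#endif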
framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc index 707d9a4700..6349ec65a9 100644 --- a/paddle/fluid/operators/kron_op.cc +++ b/paddle/fluid/operators/kron_op.cc @@ -29,26 +29,23 @@ class KronOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -110,27 +107,24 @@ class KronGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto out_grad_name = framework::GradVarName("Out"); - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 461d77f324..062e33f266 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -101,7 
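// kron above keeps its complex-promotion rule under the new accessors: when
// the expected dtype is complex, each input retains its own dtype so a real
// input can be promoted to complex later rather than force-cast up front.
// The whole rule in one hypothetical override:
phi::KernelKey MyKronLikeOp::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const {
  if (framework::IsComplexType(expected_kernel_type.dtype())) {
    // Only promote input types when the op involves complex inputs.
    return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
  }
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}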
+101,7 @@ class LayerNormOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -113,7 +113,7 @@ class LayerNormOp : public framework::OperatorWithKernel { } // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -203,7 +203,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const auto *var = ctx.InputVar(framework::GradVarName("Y")); PADDLE_ENFORCE_NOT_NULL( @@ -218,14 +218,8 @@ class LayerNormGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL( t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc index fbb091a78a..ffae23c702 100644 --- a/paddle/fluid/operators/limit_by_capacity_op.cc +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -35,7 +35,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // the dtype of the expert_count and capacity should be same as int64 auto expert_count_dtype = @@ -54,7 +54,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel { framework::proto::VarType::INT64, platform::errors::InvalidArgument("The dtype of the expert_count and " "capacity should be same as int64")); - return framework::OpKernelType(expert_count_dtype, ctx.GetPlace()); + return phi::KernelKey(expert_count_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc index 64fe6562a6..26f90851d5 100644 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -298,9 +298,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of linear_chain_crf // is determined by its input "Emission". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), platform::CPUPlace()); } @@ -343,12 +343,11 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of output of the linear_chain_crf_grad // operator is determined by its input: gradients of LogLikelihood. 
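// layer_norm's grad op above is one of many sites (inplace_abn and the resnet
// ops earlier in this diff) where the patch deletes the
// "LibraryType::kPlain / DataLayout::kAnyLayout" locals outright: the
// two-argument KernelKey already defaults layout and backend, so the
// placeholders carry no information. What survives is just (hypothetical op):
phi::KernelKey MyNormGradOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // No LibraryType/DataLayout locals left to declare.
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                        ctx.GetPlace());
}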
- framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("LogLikelihood")), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("LogLikelihood")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index d9dcfbed59..e3fade6d61 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -28,22 +28,24 @@ class LinspaceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (platform::is_xpu_place(tensor.place())) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 78c06e8c24..5f03e1304b 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -27,11 +27,9 @@ class LoadCombineOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = framework::OpKernelType( - framework::proto::VarType::FP32, ctx.GetPlace()); - return kt; + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/load_combine_op.h b/paddle/fluid/operators/load_combine_op.h index 16e53dbead..258275f403 100644 --- a/paddle/fluid/operators/load_combine_op.h +++ b/paddle/fluid/operators/load_combine_op.h @@ -116,14 +116,15 @@ class LoadCombineOpKernel : public framework::OpKernel { // Get data from fin to tensor paddle::framework::DeserializeFromStream(*buffer, tensor, dev_ctx); - auto in_dtype = framework::TransToProtoVarType(tensor->dtype()); - auto out_dtype = - load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor->dtype(); + auto out_dtype = load_as_fp16 ? 
phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { // convert to float16 tensor - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor fp16_tensor; // copy LoD info to the new tensor fp16_tensor.set_lod(tensor->lod()); diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 0c66dbd365..434c0db2b8 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -26,11 +26,9 @@ class LoadOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = framework::OpKernelType( - framework::proto::VarType::FP32, ctx.GetPlace()); - return kt; + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/load_op_npu.cc b/paddle/fluid/operators/load_op_npu.cc index 8c00f08683..0e8517fd7b 100644 --- a/paddle/fluid/operators/load_op_npu.cc +++ b/paddle/fluid/operators/load_op_npu.cc @@ -85,13 +85,15 @@ class LoadOpKernel : public framework::OpKernel { } auto load_as_fp16 = ctx.Attr("load_as_fp16"); - auto in_dtype = framework::TransToProtoVarType(tensor->dtype()); - auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor->dtype(); + auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { // convert to float16 tensor - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor fp16_tensor; // copy LoD info to the new tensor fp16_tensor.set_lod(tensor->lod()); diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 374bb8920f..502afbf0c7 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -62,20 +62,19 @@ class LoDResetOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; @@ -202,11 +201,11 @@ class LoDResetGradOp : public framework::OperatorWithKernel { } protected: - 
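// The load kernels above now express the optional fp16 cast with two
// phi::KernelKeys; phi::DataLayout::ALL_LAYOUT marks layout as irrelevant to a
// pure dtype transform, and dtypes stay in phi::DataType instead of detouring
// through proto VarType. Condensed sketch (tensor, place, and load_as_fp16 as
// in the kernels above; TransDataType is assumed to keep its existing role as
// the framework's dtype-cast helper):
auto in_dtype = tensor->dtype();
auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
if (in_dtype != out_dtype) {
  auto in_kernel_type =
      phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype);
  auto out_kernel_type =
      phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype);
  // framework::TransDataType(in_kernel_type, out_kernel_type, ...);  // assumed call site
}
// lod_reset's GetKernelTypeForVar above is the related variant: ALL_BACKEND
// plus the tensor's own layout, so only the dtype is constrained.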
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/log_softmax_op.cc b/paddle/fluid/operators/log_softmax_op.cc index 99da0b08af..eb3ee5b7cd 100644 --- a/paddle/fluid/operators/log_softmax_op.cc +++ b/paddle/fluid/operators/log_softmax_op.cc @@ -29,11 +29,11 @@ class LogSoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -86,11 +86,11 @@ class LogSoftmaxGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc index 5e5e25a56d..171ee209eb 100644 --- a/paddle/fluid/operators/logspace_op.cc +++ b/paddle/fluid/operators/logspace_op.cc @@ -28,9 +28,9 @@ class LogspaceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/lookup_table_dequant_op.cc b/paddle/fluid/operators/lookup_table_dequant_op.cc index e0ca707ffa..09636f600a 100644 --- a/paddle/fluid/operators/lookup_table_dequant_op.cc +++ b/paddle/fluid/operators/lookup_table_dequant_op.cc @@ -85,10 +85,10 @@ class LookupTableDequantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 8ad3966a1d..6bb9f9ee19 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -67,10 +67,10 @@ class LookupTableOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = 
OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; @@ -191,11 +191,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index 84f8c6cf64..3af95c484f 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -67,10 +67,10 @@ class LookupTableV2Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; @@ -135,11 +135,11 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index ce31108aa5..5a6ed73047 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -222,18 +222,18 @@ class LRNOp : public framework::OperatorWithKernel { ctx->SetOutputDim("MidOut", x_dim); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -242,13 +242,12 @@ class LRNOp : public framework::OperatorWithKernel { // Some models may have intentionally set "AnyLayout" for lrn // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -346,18 +345,18 @@ class LRNOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -366,13 +365,12 @@ class LRNOpGrad : public framework::OperatorWithKernel { // Some models may have intentionally set "AnyLayout" for lrn // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index b7310ed475..7250cf65e4 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -135,11 +135,10 @@ class LSTMOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -304,11 +303,10 @@ class LSTMGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index dc36b3431d..63cf07e35b 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -143,11 +143,10 @@ class LSTMPOp : public framework::OperatorWithKernel { } protected: - 
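// [Review note, not part of the patch] Sketch of the oneDNN branch that
// LRNOp/LRNOpGrad::GetKernelTypeForVar implement above: when the op
// expects ONEDNN layout but the incoming tensor is not in ONEDNN layout
// yet, the per-variable key carries the op's data_format attribute so the
// transform knows the tensor's logical layout. Assumes `ar` is the
// AttrReader built from Attrs() as in the hunk; the <std::string>
// template argument on Get is an assumption here.
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
    (tensor.layout() != phi::DataLayout::ONEDNN)) {
  const std::string data_format = ar.Get<std::string>("data_format");
  auto dl = phi::StringToDataLayout(data_format);
  // "AnyLayout" is treated as NCHW, the data_format default.
  if (dl != phi::DataLayout::kAnyLayout) {
    // New argument order throughout this patch: place, layout, dtype.
    return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
  }
}
#endif
return phi::KernelKey(
    tensor.place(), tensor.layout(), expected_kernel_type.dtype());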
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -388,11 +387,11 @@ class LSTMPGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"), - ctx.device_context()); + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index b02a2fe13a..bf19e28af0 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -26,7 +26,7 @@ class LstsqOp : public framework::OperatorWithKernel { protected: // The output of lstsq is always complex-valued even for real-valued inputs - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (dtype != framework::proto::VarType::FP32 && @@ -34,7 +34,7 @@ class LstsqOp : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "unsupported data type: %s!", dtype)); } - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc index 923c14f3db..5a111f278b 100644 --- a/paddle/fluid/operators/lu_op.cc +++ b/paddle/fluid/operators/lu_op.cc @@ -44,10 +44,10 @@ class LUOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -105,10 +105,10 @@ class LUGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cc b/paddle/fluid/operators/margin_cross_entropy_op.cc index 9e9ee9c561..5688ca2fc7 100644 --- a/paddle/fluid/operators/margin_cross_entropy_op.cc +++ b/paddle/fluid/operators/margin_cross_entropy_op.cc @@ -26,11 +26,11 @@ class MarginCrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + 
ctx.device_context().GetPlace()); } }; @@ -96,11 +96,11 @@ class MarginCrossEntropyOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/marker_op.cc b/paddle/fluid/operators/marker_op.cc index 3de4f4451d..0cd3ccd686 100644 --- a/paddle/fluid/operators/marker_op.cc +++ b/paddle/fluid/operators/marker_op.cc @@ -30,10 +30,9 @@ class MarkerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index cba18b3cdb..b1c623b002 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -665,39 +665,36 @@ class MatMulOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { #ifdef PADDLE_WITH_MKLDNN // When matmul is first oneDNN op in a chain (there was some non oneDNN op // previously) // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) && phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); + return phi::KernelKey(tensor.place(), + phi::DataLayout::kNHWC, + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -846,11 +843,11 @@ class MatMulOpGrad : public framework::OperatorWithKernel { } } - 
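// [Review note, not part of the patch] The complex-promotion rule kept by
// MatMulOp::GetKernelTypeForVar above, shown in isolation: if the
// kernel's expected dtype is complex, each input keeps its own dtype so
// real-valued inputs are promoted at kernel entry. Note that KernelKey
// takes phi::DataType directly, so the old
// TransToProtoVarType(tensor.dtype()) round-trip disappears.
if (framework::IsComplexType(expected_kernel_type.dtype())) {
  // Feed the tensor as-is; promotion happens inside the kernel.
  return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
}
return phi::KernelKey(
    tensor.place(), tensor.layout(), expected_kernel_type.dtype());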
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 0a76f43175..c52fc08c91 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -131,38 +131,35 @@ class MatMulV2Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { #ifdef PADDLE_WITH_MKLDNN // When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN // op previously) then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) && phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); + return phi::KernelKey(tensor.place(), + phi::DataLayout::kNHWC, + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -195,26 +192,23 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains 
complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc index a63d3cb86f..16ca2cf09e 100644 --- a/paddle/fluid/operators/matrix_rank_op.cc +++ b/paddle/fluid/operators/matrix_rank_op.cc @@ -84,12 +84,10 @@ class MatrixRankOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library{framework::LibraryType::kPlain}; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index 0e75629f71..3728fbee53 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -40,9 +40,9 @@ class MeanIoUOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 7715cf8773..0c628a4651 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -59,11 +59,11 @@ class MeanGradOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", framework::GradVarName("X")); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc index 82feee0f69..06af45d485 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.cc +++ b/paddle/fluid/operators/memcpy_d2h_op.cc @@ -37,20 +37,19 @@ class MemcpyD2HOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return 
framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/memcpy_h2d_op.cc b/paddle/fluid/operators/memcpy_h2d_op.cc index 1426b23dc1..8d3fc63154 100644 --- a/paddle/fluid/operators/memcpy_h2d_op.cc +++ b/paddle/fluid/operators/memcpy_h2d_op.cc @@ -38,20 +38,19 @@ class MemcpyH2DOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/memcpy_op.cc b/paddle/fluid/operators/memcpy_op.cc index 66cf6a00b7..f000a1cc0d 100644 --- a/paddle/fluid/operators/memcpy_op.cc +++ b/paddle/fluid/operators/memcpy_op.cc @@ -54,20 +54,19 @@ class MemcpyOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/meshgrid_op.cc b/paddle/fluid/operators/meshgrid_op.cc index 7921e8844c..f813b9e341 100644 --- a/paddle/fluid/operators/meshgrid_op.cc +++ b/paddle/fluid/operators/meshgrid_op.cc @@ -30,7 +30,7 @@ class MeshgridOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -47,7 +47,7 @@ class MeshgridOp : public framework::OperatorWithKernel { "All Inputs of Meshgrid OP are Empty!")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -94,11 +94,11 @@ class MeshgridGradOp : public framework::OperatorWithKernel { } protected: - 
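// [Review note, not part of the patch] In the memcpy_d2h/memcpy_h2d/
// memcpy hunks above, the old key's `expected_kernel_type.place_` is
// replaced by phi::Backend::ALL_BACKEND: the input is accepted on
// whichever device it already lives on, and only the layout and dtype are
// taken from the tensor and the expected key. Sketch of the full
// override, composed only of calls visible in those hunks:
phi::KernelKey GetKernelTypeForVar(
    const std::string &var_name,
    const phi::DenseTensor &tensor,
    const phi::KernelKey &expected_kernel_type) const override {
  // ALL_BACKEND means "do not force a device transfer for this input";
  // the memcpy ops perform the cross-device copy themselves.
  return phi::KernelKey(phi::Backend::ALL_BACKEND,
                        tensor.layout(),
                        expected_kernel_type.dtype());
}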
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index f8e57adc70..25e32b5197 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -24,10 +24,10 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Out"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 7529523bec..8910e61e42 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -26,11 +26,11 @@ class AucOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Predict"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index 30302ceb82..0652151320 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -143,11 +143,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/moe_op.cc b/paddle/fluid/operators/moe_op.cc index 6832beeaa8..186ac1fc43 100644 --- a/paddle/fluid/operators/moe_op.cc +++ b/paddle/fluid/operators/moe_op.cc @@ -27,10 +27,10 @@ class MoeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 02537512c9..8236bdd599 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -25,16 +25,16 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::OpKernelType; +using phi::KernelKey; class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -103,10 +103,10 @@ class MulGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index ba263427ca..c057d76730 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -28,11 +28,10 @@ class MultiplexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -104,11 +103,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc index d57c5f18bd..f0bc985f3e 100644 --- a/paddle/fluid/operators/nanmedian_op.cc +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -28,10 +28,10 @@ class NanmedianOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -99,11 +99,11 @@ class NanmedianGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), 
+ ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index b80de06279..286c851278 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -113,11 +113,10 @@ class NCEOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; @@ -279,11 +278,10 @@ class NCEOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/nop_op.cc b/paddle/fluid/operators/nop_op.cc index 876468f8a7..709b1f4f1f 100644 --- a/paddle/fluid/operators/nop_op.cc +++ b/paddle/fluid/operators/nop_op.cc @@ -25,10 +25,9 @@ class NopOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/number_count_op.cc b/paddle/fluid/operators/number_count_op.cc index 29f0a5bf57..e636bc98bf 100644 --- a/paddle/fluid/operators/number_count_op.cc +++ b/paddle/fluid/operators/number_count_op.cc @@ -28,7 +28,7 @@ class NumberCountOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // the dtype of the numbers should be same as int64 auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers"); @@ -37,7 +37,7 @@ class NumberCountOp : public framework::OperatorWithKernel { framework::proto::VarType::INT64, platform::errors::InvalidArgument( "The dtype of the number_dtype should be int64")); - return framework::OpKernelType(number_dtype, ctx.GetPlace()); + return phi::KernelKey(number_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 0cd6cab49e..ffb3081ca0 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -56,22 +56,23 @@ class OneHotOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& 
expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "depth_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc index f5b55fcf02..a2ef01a89e 100644 --- a/paddle/fluid/operators/one_hot_v2_op.cc +++ b/paddle/fluid/operators/one_hot_v2_op.cc @@ -29,22 +29,23 @@ class OneHotV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "depth_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index 262aa0fc35..aa78843724 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -24,10 +24,10 @@ class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 54643a39bc..fc260c7e99 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -29,10 +29,10 @@ class AdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index cf447bc593..2a7dc7f311 100644 --- 
a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -23,23 +23,25 @@ class AdamOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { if (var_name == "Beta1Pow" || var_name == "Beta2Pow" || var_name == "SkipUpdate") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index 12429933e0..51397e210a 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -24,10 +24,10 @@ class AdamaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index 6c73439c62..8ae9f86ac4 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -80,10 +80,10 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc index 2b4b1c1a10..e8b719dc62 100644 --- a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc @@ -35,13 +35,15 @@ class DGCMomentumOp : public MomentumOp { return MomentumOp::InferShape(ctx); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "current_step" || var_name == 
"nranks") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc index e32cf36251..ad99133f35 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc @@ -24,10 +24,10 @@ class DistributedFusedLambInitOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override {} - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto dtype = framework::proto::VarType::FP32; // dtype is not important - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc index d810f8df73..f7b8dacfc5 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc @@ -24,17 +24,19 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override {} - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto dtype = framework::proto::VarType::FP32; // dtype is not important - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return expected_kernel_type; + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc index f5710f2e7d..a752517ea8 100644 --- a/paddle/fluid/operators/optimizers/dpsgd_op.cc +++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc @@ -72,10 +72,10 @@ class DpsgdOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc index 22be1f5ac6..d8110b5bbb 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -69,11 +69,11 @@ class FTRLOp : public framework::OperatorWithKernel { ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim); } - framework::OpKernelType 
GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc index df55ffa116..c6c4397332 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.cc +++ b/paddle/fluid/operators/optimizers/lamb_op.cc @@ -29,21 +29,23 @@ class LambOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { if (var_name == "Beta1Pow" || var_name == "Beta2Pow") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index a5c641cc70..b5b15fa09e 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -133,11 +133,11 @@ class LarsMomentumOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cc b/paddle/fluid/operators/optimizers/merged_adam_op.cc index 867cfe0268..2be0d28a1a 100644 --- a/paddle/fluid/operators/optimizers/merged_adam_op.cc +++ b/paddle/fluid/operators/optimizers/merged_adam_op.cc @@ -23,23 +23,25 @@ class MergedAdamOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto param_dtype = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(param_dtype, ctx.GetPlace()); + return phi::KernelKey(param_dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Beta1Pow" || var_name == 
"Beta2Pow" || var_name == "SkipUpdate") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc index 85b2f818fe..17d31e35fd 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc @@ -25,11 +25,11 @@ class MergedMomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto param_dtype = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(param_dtype, ctx.GetPlace()); + return phi::KernelKey(param_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index ad1ae55074..316f742a2f 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -114,11 +114,11 @@ class MomentumOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc index d3d45ad3c6..8def9c961f 100644 --- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc +++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc @@ -31,11 +31,11 @@ class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel { ctx->SetOutputDim("StepOut", dim); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 598b84415f..076f5137ca 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -72,10 +72,10 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc index 21b145ee49..d7e01aa071 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -52,10 +52,10 @@ class ProximalGDOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index b8883f22e9..ac445d30c3 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -28,7 +28,7 @@ class SGDOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); @@ -46,21 +46,18 @@ class SGDOp : public framework::OperatorWithKernel { } // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "LearningRate") { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.h b/paddle/fluid/operators/optimizers/sparse_momentum_op.h index 9eea5c11cb..7ea3b29cfa 100644 --- a/paddle/fluid/operators/optimizers/sparse_momentum_op.h +++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.h @@ -176,11 +176,11 @@ class SparseMomentumOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index 6686912941..91eeed0e90 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -696,7 +696,7 @@ class Pad2dOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, 
"X"); #ifdef PADDLE_WITH_MKLDNN @@ -706,32 +706,31 @@ class Pad2dOp : public framework::OperatorWithKernel { ctx.Input("X") ->mem_desc() .data.format_desc.blocking.inner_nblks == 0) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::StringToDataLayout(data_format)); + return phi::KernelKey(tensor.place(), + phi::StringToDataLayout(data_format), + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -832,11 +831,11 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index f457151b70..0bfb02bc45 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -30,7 +30,7 @@ class Pad3dOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN @@ -40,32 +40,31 @@ class Pad3dOp : public framework::OperatorWithKernel { ctx.Input("X") ->mem_desc() .data.format_desc.blocking.inner_nblks == 0) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef 
PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::StringToDataLayout(data_format)); + return phi::KernelKey(tensor.place(), + phi::StringToDataLayout(data_format), + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -183,11 +182,11 @@ class Pad3dOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 28d264ba8e..9b08bb3fc1 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -62,11 +62,10 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Y"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context().GetPlace()); } }; @@ -210,11 +209,10 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Y"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc index 2951091508..fd23f57793 100644 --- a/paddle/fluid/operators/pad_op.cc +++ b/paddle/fluid/operators/pad_op.cc @@ -32,10 +32,10 @@ class PadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -107,11 +107,11 @@ class PadOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, 
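[Note] Hunks that previously requested the MKLDNN library through a fourth `framework::LibraryType::kMKLDNN` argument now encode the same request in the key's backend, and the proto data type has to be converted to a phi data type on that path. A hedged sketch of the branch shared by the pad and pool ops above; the op-specific eligibility check is elided and stands in for the real `mem_desc()` inspection:

    phi::KernelKey GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const override {
      auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    #ifdef PADDLE_WITH_MKLDNN
      bool use_onednn = /* op-specific check, elided in this sketch */ false;
      if (use_onednn) {
        // Backend::ONEDNN replaces the old (place, ONEDNN layout, kMKLDNN) triple.
        return phi::KernelKey(phi::Backend::ONEDNN,
                              phi::DataLayout::ONEDNN,
                              phi::TransToPhiDataType(input_data_type));
      }
    #endif
      return phi::KernelKey(input_data_type, ctx.GetPlace());
    }
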
framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/partial_concat_op.cc b/paddle/fluid/operators/partial_concat_op.cc index 01095b6d42..a8a7d82e46 100644 --- a/paddle/fluid/operators/partial_concat_op.cc +++ b/paddle/fluid/operators/partial_concat_op.cc @@ -89,7 +89,7 @@ class PartialConcatOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -105,7 +105,7 @@ class PartialConcatOp : public framework::OperatorWithKernel { 1, platform::errors::InvalidArgument( "All Inputs of PartialSum OP are Empty!")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -138,11 +138,11 @@ class PartialConcatGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/partial_sum_op.cc b/paddle/fluid/operators/partial_sum_op.cc index 6473f8d603..a2255d8e07 100644 --- a/paddle/fluid/operators/partial_sum_op.cc +++ b/paddle/fluid/operators/partial_sum_op.cc @@ -91,7 +91,7 @@ class PartialSumOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -108,7 +108,7 @@ class PartialSumOp : public framework::OperatorWithKernel { 1, platform::errors::InvalidArgument( "All Inputs of PartialSum OP are Empty!")); - return framework::OpKernelType(input_data_type, platform::CPUPlace()); + return phi::KernelKey(input_data_type, platform::CPUPlace()); } }; @@ -141,11 +141,11 @@ class PartialSumGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index c160dc28bf..b03f2954d2 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -42,7 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) { (src_tz[src_tz.size() - 2] % ksize[0] == 0)); } -framework::OpKernelType PoolOp::GetExpectedKernelType( +phi::KernelKey PoolOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -50,15 +50,15 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); // NOTE(jiahongyu) END: Above 
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index c160dc28bf..b03f2954d2 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -42,7 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
           (src_tz[src_tz.size() - 2] % ksize[0] == 0));
 }
 
-framework::OpKernelType PoolOp::GetExpectedKernelType(
+phi::KernelKey PoolOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -50,15 +50,15 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
   // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN
 
-  return framework::OpKernelType(data_type, ctx.GetPlace());
+  return phi::KernelKey(data_type, ctx.GetPlace());
 }
 
-framework::OpKernelType PoolOp::GetKernelTypeForVar(
+phi::KernelKey PoolOp::GetKernelTypeForVar(
     const std::string& var_name,
     const phi::DenseTensor& tensor,
-    const framework::OpKernelType& expected_kernel_type) const {
+    const phi::KernelKey& expected_kernel_type) const {
 #ifdef PADDLE_WITH_MKLDNN
-  if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+  if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
       (tensor.layout() != phi::DataLayout::ONEDNN)) {
     auto attrs = Attrs();
     auto ar = paddle::framework::AttrReader(attrs);
@@ -67,16 +67,15 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
     // Some models may have intentionally set "AnyLayout" for pool
     // op. Treat this as NCHW (default data_format value)
     if (dl != phi::DataLayout::kAnyLayout) {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), dl);
+      return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
     }
   }
 #endif
-  return framework::OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+  return phi::KernelKey(
+      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
 }
 
-framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
+phi::KernelKey PoolOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -84,26 +83,26 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
   // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
 
-  return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
-framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
+phi::KernelKey PoolOpGrad::GetKernelTypeForVar(
     const std::string& var_name,
     const phi::DenseTensor& tensor,
-    const framework::OpKernelType& expected_kernel_type) const {
+    const phi::KernelKey& expected_kernel_type) const {
 #ifdef PADDLE_WITH_MKLDNN
-  if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+  if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    auto attrs = Attrs();
     auto ar = paddle::framework::AttrReader(attrs);
     const std::string data_format = ar.Get<std::string>("data_format");
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(),
-                                   phi::StringToDataLayout(data_format));
+    return phi::KernelKey(tensor.place(),
+                          phi::StringToDataLayout(data_format),
+                          expected_kernel_type.dtype());
   }
 #endif
-  return framework::OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+  return phi::KernelKey(
+      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
 }
 
 void Pool2dOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index 9bb7572c10..a935c6b14f 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -24,13 +24,13 @@ class PoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override;
+      const phi::KernelKey& expected_kernel_type) const override;
 };
 
 class PoolOpGrad : public framework::OperatorWithKernel {
@@ -38,13 +38,13 @@ class PoolOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override;
+      const phi::KernelKey& expected_kernel_type) const override;
 };
 
 class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 
diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc
index 57aef714a0..74b98069bf 100644
--- a/paddle/fluid/operators/pool_with_index_op.cc
+++ b/paddle/fluid/operators/pool_with_index_op.cc
@@ -36,11 +36,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
@@ -49,11 +48,11 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc
index dc8a088ad2..3f4d812567 100644
--- a/paddle/fluid/operators/positive_negative_pair_op.cc
+++ b/paddle/fluid/operators/positive_negative_pair_op.cc
@@ -167,11 +167,10 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Score"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Score"),
+                          ctx.device_context().GetPlace());
   }
 };
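[Note] The per-variable hook follows the same reordering: place first, layout second, dtype last, with the dtype now read through `expected_kernel_type.dtype()` instead of the public `data_type_` field. A condensed sketch of the ONEDNN re-layout branch shared by the pad and pool ops above (attribute plumbing copied from those hunks; the surrounding class is elided):

    phi::KernelKey GetKernelTypeForVar(
        const std::string& var_name,
        const phi::DenseTensor& tensor,
        const phi::KernelKey& expected_kernel_type) const override {
    #ifdef PADDLE_WITH_MKLDNN
      if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
          (tensor.layout() != phi::DataLayout::ONEDNN)) {
        // Re-derive the layout from the op's data_format attribute.
        auto ar = paddle::framework::AttrReader(Attrs());
        const std::string data_format = ar.Get<std::string>("data_format");
        return phi::KernelKey(tensor.place(),
                              phi::StringToDataLayout(data_format),
                              expected_kernel_type.dtype());
      }
    #endif
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
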
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index 8a2199e023..5100b4f869 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -30,11 +30,11 @@ class PReluOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
@@ -93,11 +93,11 @@ class PReluGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc
index ca291187b9..d1c455331b 100644
--- a/paddle/fluid/operators/prroi_pool_op.cc
+++ b/paddle/fluid/operators/prroi_pool_op.cc
@@ -135,11 +135,10 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
@@ -161,11 +160,10 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cc b/paddle/fluid/operators/prune_gate_by_capacity_op.cc
index 14494f426d..388b65f3dd 100644
--- a/paddle/fluid/operators/prune_gate_by_capacity_op.cc
+++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cc
@@ -66,7 +66,7 @@ class PruneGateByCapacityOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto gate_idx_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx");
@@ -82,7 +82,7 @@ class PruneGateByCapacityOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "The dtype of the gate_idx and expert_count should "
                           "be same as int64"));
-    return framework::OpKernelType(gate_idx_data_type, ctx.device_context());
+    return phi::KernelKey(gate_idx_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc b/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc
index 046269a396..e080f96e88 100644
--- a/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc
+++ b/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc
@@ -78,9 +78,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc
index 97391bc0e8..1950991b7b 100644
--- a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc
+++ b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc
@@ -51,9 +51,9 @@ class DistributedPushSparseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
diff --git a/paddle/fluid/operators/pscore/send_and_recv_op.cc b/paddle/fluid/operators/pscore/send_and_recv_op.cc
index d3f1d17e7a..d252621116 100644
--- a/paddle/fluid/operators/pscore/send_and_recv_op.cc
+++ b/paddle/fluid/operators/pscore/send_and_recv_op.cc
@@ -60,10 +60,10 @@ class SendAndRecvOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, platform::CPUPlace());
+    return phi::KernelKey(data_type, platform::CPUPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc
index 1222f97c09..a853417923 100644
--- a/paddle/fluid/operators/psroi_pool_op.cc
+++ b/paddle/fluid/operators/psroi_pool_op.cc
@@ -83,11 +83,10 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
@@ -96,11 +95,10 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/pull_box_extended_sparse_op.cc b/paddle/fluid/operators/pull_box_extended_sparse_op.cc
index 36ebc2ef67..7b949fa433 100644
--- a/paddle/fluid/operators/pull_box_extended_sparse_op.cc
+++ b/paddle/fluid/operators/pull_box_extended_sparse_op.cc
@@ -72,10 +72,9 @@ class PullBoxExtendedSparseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -131,11 +130,11 @@ class PushBoxExtendedSparseOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc
index 14d8bacfa9..c58a176d52 100644
--- a/paddle/fluid/operators/pull_box_sparse_op.cc
+++ b/paddle/fluid/operators/pull_box_sparse_op.cc
@@ -56,10 +56,9 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -119,11 +118,11 @@ class PushBoxSparseOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/pull_gpups_sparse_op.cc b/paddle/fluid/operators/pull_gpups_sparse_op.cc
index 052c5d3c8b..821cfdab6f 100644
--- a/paddle/fluid/operators/pull_gpups_sparse_op.cc
+++ b/paddle/fluid/operators/pull_gpups_sparse_op.cc
@@ -64,10 +64,9 @@ class PullGpuPSSparseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -129,11 +128,11 @@ class PushGpuPSSparseOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/pull_sparse_op.cc b/paddle/fluid/operators/pull_sparse_op.cc
index 5023a620af..7dc9ae98e0 100644
--- a/paddle/fluid/operators/pull_sparse_op.cc
+++ b/paddle/fluid/operators/pull_sparse_op.cc
@@ -58,10 +58,9 @@ class PullSparseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -127,11 +126,11 @@ class PushSparseOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/pull_sparse_v2_op.cc b/paddle/fluid/operators/pull_sparse_v2_op.cc
index c0c7c4e036..88a0ac86c2 100644
--- a/paddle/fluid/operators/pull_sparse_v2_op.cc
+++ b/paddle/fluid/operators/pull_sparse_v2_op.cc
@@ -51,10 +51,9 @@ class PullSparseV2Op : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -119,11 +118,11 @@ class PushSparseV2Op : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/push_dense_op.cc b/paddle/fluid/operators/push_dense_op.cc
index 7ab49f2c2f..e13d757480 100644
--- a/paddle/fluid/operators/push_dense_op.cc
+++ b/paddle/fluid/operators/push_dense_op.cc
@@ -30,10 +30,9 @@ class PushDenseOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.device_context());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc
index a24b234a05..d445dca250 100644
--- a/paddle/fluid/operators/pyramid_hash_op.cc
+++ b/paddle/fluid/operators/pyramid_hash_op.cc
@@ -225,10 +225,10 @@ class PyramidHashOP : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"),
+                          ctx.GetPlace());
   }
 };
@@ -465,10 +465,10 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc
index 7f9d472cb5..f143bc3a50 100644
--- a/paddle/fluid/operators/quantize_linear_op.cc
+++ b/paddle/fluid/operators/quantize_linear_op.cc
@@ -124,10 +124,10 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc
index c98e15fcff..83be35f998 100644
--- a/paddle/fluid/operators/quantize_op.cc
+++ b/paddle/fluid/operators/quantize_op.cc
@@ -19,13 +19,13 @@ namespace paddle {
 namespace operators {
 
-framework::OpKernelType QuantOp::GetExpectedKernelType(
+phi::KernelKey QuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
+  return phi::KernelKey(
+      phi::Backend::ONEDNN,
       phi::DataLayout::ONEDNN,
-      framework::LibraryType::kMKLDNN);
+      phi::TransToPhiDataType(
+          OperatorWithKernel::IndicateVarDataType(ctx, "Input")));
 }
 
 void QuantOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/quantize_op.h b/paddle/fluid/operators/quantize_op.h
index 46a0469c80..3426af2b36 100644
--- a/paddle/fluid/operators/quantize_op.h
+++ b/paddle/fluid/operators/quantize_op.h
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using framework::OpKernelType;
+using phi::KernelKey;
 
 class QuantOp : public framework::OperatorWithKernel {
  public:
@@ -34,7 +34,7 @@ class QuantOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc
index 9752f21b7e..810680ea5d 100644
--- a/paddle/fluid/operators/randint_op.cc
+++ b/paddle/fluid/operators/randint_op.cc
@@ -86,9 +86,9 @@ class RandintOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc
index 6736cb4c87..11ba62197d 100644
--- a/paddle/fluid/operators/random_crop_op.cc
+++ b/paddle/fluid/operators/random_crop_op.cc
@@ -54,11 +54,10 @@ class RandomCropOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", phi::make_ddim(out_dim));
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
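[Note] Ops whose kernel dtype comes from a `dtype` attribute (`RandintOp` above, `RandpermOp` below) keep the `static_cast` to `framework::proto::VarType::Type` and hand the result to the two-argument `phi::KernelKey`. A minimal sketch of that shape, using only calls already present in these hunks:

    phi::KernelKey GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const override {
      // The attribute stores the proto enum as an int; the cast restores it.
      auto data_type =
          static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
      return phi::KernelKey(data_type, ctx.GetPlace());
    }
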
diff --git a/paddle/fluid/operators/random_routing_op.cc b/paddle/fluid/operators/random_routing_op.cc
index c20e124807..320f5cd1cf 100644
--- a/paddle/fluid/operators/random_routing_op.cc
+++ b/paddle/fluid/operators/random_routing_op.cc
@@ -55,7 +55,7 @@ class RandomRoutingOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // the dtype of the gate_idx should be same as int64
     const auto topk_idx_dtype =
@@ -67,7 +67,7 @@ class RandomRoutingOp : public framework::OperatorWithKernel {
     const auto& topk_value_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "TopK_Value");
-    return framework::OpKernelType(topk_value_type, ctx.GetPlace());
+    return phi::KernelKey(topk_value_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc
index 78366efc53..187b227f33 100644
--- a/paddle/fluid/operators/randperm_op.cc
+++ b/paddle/fluid/operators/randperm_op.cc
@@ -44,11 +44,11 @@ class RandpermOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc
index 8a965034ac..08706bc705 100644
--- a/paddle/fluid/operators/range_op.cc
+++ b/paddle/fluid/operators/range_op.cc
@@ -29,15 +29,17 @@ class RangeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (platform::is_xpu_place(tensor.place())) {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
-    return expected_kernel_type;
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
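[Note] `RangeOp` is the one place where the translation is not one-to-one: returning `expected_kernel_type` unchanged is no longer possible, because the incoming key is now tied to a concrete backend. The replacement builds a fresh key with `phi::Backend::ALL_BACKEND`, which signals that any backend is acceptable for this variable and therefore suppresses a forced device copy. The same idiom recurs in `ReshapeOp` and `RunProgramOp` below. Sketch of the substitution in isolation:

    // Old: return expected_kernel_type;
    // New: keep layout and dtype, but widen the backend so no transform fires.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
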
diff --git a/paddle/fluid/operators/rank_attention_op.cc b/paddle/fluid/operators/rank_attention_op.cc
index 80bd022aff..afc3388f42 100644
--- a/paddle/fluid/operators/rank_attention_op.cc
+++ b/paddle/fluid/operators/rank_attention_op.cc
@@ -79,11 +79,10 @@ class RankAttentionOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
@@ -118,11 +117,11 @@ class RankAttentionGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/read_file_op.cc b/paddle/fluid/operators/read_file_op.cc
index 602f98dadb..9b42a895a9 100644
--- a/paddle/fluid/operators/read_file_op.cc
+++ b/paddle/fluid/operators/read_file_op.cc
@@ -61,10 +61,10 @@ class ReadFileOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::UINT8,
-                                   platform::CPUPlace());
+    return phi::KernelKey(framework::proto::VarType::UINT8,
+                          platform::CPUPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc
index 617c47530c..94cdc2d658 100644
--- a/paddle/fluid/operators/real_op.cc
+++ b/paddle/fluid/operators/real_op.cc
@@ -58,12 +58,12 @@ class RealGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto dtype = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     auto complex_dtype = framework::ToComplexType(dtype);
-    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
+    return phi::KernelKey(complex_dtype, ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 0cc7bf2898..ecf8119ed2 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -455,12 +455,12 @@ class ReduceGradKernel : public framework::OpKernel<T> {
       phi::DenseTensor tmp_tensor;
       auto* pre_input =
           context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
-      auto in_kernel_type = framework::OpKernelType(
-          framework::TransToProtoVarType(pre_input->dtype()),
-          context.GetPlace());
-      auto out_kernel_type = framework::OpKernelType(
-          static_cast<framework::proto::VarType::Type>(in_dtype),
-          context.GetPlace());
+      auto in_kernel_type =
+          phi::KernelKey(framework::TransToProtoVarType(pre_input->dtype()),
+                         context.GetPlace());
+      auto out_kernel_type =
+          phi::KernelKey(static_cast<framework::proto::VarType::Type>(in_dtype),
+                         context.GetPlace());
       framework::TransDataType(
           in_kernel_type, out_kernel_type, *pre_input, &tmp_tensor);
       ComputeFromInput(&tmp_tensor, context);
@@ -584,7 +584,7 @@ class ReduceOp : public framework::OperatorWithKernel {
     return true;
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -606,7 +606,7 @@ class ReduceOp : public framework::OperatorWithKernel {
           platform::errors::InvalidArgument(
               "float16 can only be used on GPU or NPU or MLU place"));
     }
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
@@ -615,10 +615,11 @@ class ReduceOpUseInputPlace : public ReduceOp {
   using ReduceOp::ReduceOp;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
-    kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
+    phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    kt.set_backend(
+        phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
     return kt;
   }
 };
@@ -663,7 +664,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     int out_dtype = ctx.Attr<int>("out_dtype");
     auto input_data_type =
@@ -679,7 +680,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
     }
     // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
 
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index 7f5a174952..cd695511d3 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -49,18 +49,17 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     int in_dtype = ctx.Attr<int>("out_dtype");
     if (in_dtype >= 0) {
-      return framework::OpKernelType(
+      return phi::KernelKey(
           static_cast<framework::proto::VarType::Type>(in_dtype),
           ctx.GetPlace());
     }
-    return framework::OpKernelType(
-        framework::OperatorWithKernel::IndicateVarDataType(
-            ctx, framework::GradVarName("Out")),
-        ctx.GetPlace());
+    return phi::KernelKey(framework::OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
index 7b1b6bc831..38d526778e 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
@@ -83,12 +83,10 @@ class ReduceSumGradKernel : public framework::OpKernel<T> {
       phi::DenseTensor tmp_tensor;
       auto* pre_input =
          context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
-      auto in_kernel_type = framework::OpKernelType(
-          framework::TransToProtoVarType(pre_input->dtype()),
-          context.GetPlace());
-      auto out_kernel_type = framework::OpKernelType(
-          static_cast<framework::proto::VarType::Type>(in_dtype),
-          context.GetPlace());
+      auto in_kernel_type = phi::KernelKey(context.GetPlace(),
+                                           phi::DataLayout::ALL_LAYOUT,
+                                           pre_input->dtype());
+      auto out_kernel_type = phi::KernelKey(in_dtype, context.GetPlace());
       framework::TransDataType(
           in_kernel_type, out_kernel_type, *pre_input, &tmp_tensor);
       ComputeFromInput(&tmp_tensor, context);
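[Note] `framework::TransDataType` now takes the two `phi::KernelKey`s directly, and the reduce hunks show both accepted spellings: `reduce_op.h` still builds the keys from proto types, while `reduce_sum_op.h` switches the input key to the `(place, layout, phi dtype)` form. A hedged sketch of the dtype-only hop, with `src`, `place`, and `out_dtype` standing in for the kernel's locals:

    // Convert src to out_dtype in tmp_tensor; place and layout stay fixed.
    phi::DenseTensor tmp_tensor;
    auto in_kernel_type =
        phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, src.dtype());
    auto out_kernel_type = phi::KernelKey(
        static_cast<framework::proto::VarType::Type>(out_dtype), place);
    framework::TransDataType(in_kernel_type, out_kernel_type, src, &tmp_tensor);
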
diff --git a/paddle/fluid/operators/repeat_interleave_op.cc b/paddle/fluid/operators/repeat_interleave_op.cc
index aaef332bd0..44d022f4d5 100644
--- a/paddle/fluid/operators/repeat_interleave_op.cc
+++ b/paddle/fluid/operators/repeat_interleave_op.cc
@@ -86,10 +86,10 @@ class RepeatInterleaveOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
@@ -111,11 +111,11 @@ class RepeatInterleaveGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc
index d0cc991e95..354a5d820e 100644
--- a/paddle/fluid/operators/requantize_op.cc
+++ b/paddle/fluid/operators/requantize_op.cc
@@ -19,13 +19,13 @@ namespace paddle {
 namespace operators {
 
-framework::OpKernelType ReQuantOp::GetExpectedKernelType(
+phi::KernelKey ReQuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
+  return phi::KernelKey(
+      phi::Backend::ONEDNN,
       phi::DataLayout::ONEDNN,
-      framework::LibraryType::kMKLDNN);
+      phi::TransToPhiDataType(
+          OperatorWithKernel::IndicateVarDataType(ctx, "Input")));
 }
 
 void ReQuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/requantize_op.h b/paddle/fluid/operators/requantize_op.h
index 5b2f0148f1..a53ea52394 100644
--- a/paddle/fluid/operators/requantize_op.h
+++ b/paddle/fluid/operators/requantize_op.h
@@ -22,7 +22,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using framework::OpKernelType;
+using phi::KernelKey;
 
 class ReQuantOp : public framework::OperatorWithKernel {
  public:
@@ -34,7 +34,7 @@ class ReQuantOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index e980aa66e7..b4191fb46c 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -246,22 +246,24 @@ class ReshapeOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
@@ -359,12 +361,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
@@ -602,22 +603,24 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
@@ -630,22 +633,23 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "DDX"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "DDX"),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
    }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc
index 93877aa825..07c3aac520 100644
--- a/paddle/fluid/operators/reverse_op.cc
+++ b/paddle/fluid/operators/reverse_op.cc
@@ -28,12 +28,12 @@ class ReverseOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc
index 3528cc957f..2f75d5aaf2 100644
--- a/paddle/fluid/operators/rnn_op.cc
+++ b/paddle/fluid/operators/rnn_op.cc
@@ -30,11 +30,10 @@ class RNNOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.GetPlace());
   }
 };
@@ -116,11 +115,11 @@ class RNNGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc
index 4407fbf1a8..1a08a01542 100644
--- a/paddle/fluid/operators/roi_align_op.cc
+++ b/paddle/fluid/operators/roi_align_op.cc
@@ -25,11 +25,10 @@ class ROIAlignOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
@@ -51,11 +50,10 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc
index e79975e625..dadbd1115b 100644
--- a/paddle/fluid/operators/roi_pool_op.cc
+++ b/paddle/fluid/operators/roi_pool_op.cc
@@ -28,11 +28,10 @@ class ROIPoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
@@ -53,11 +52,10 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc
index 823eb03aff..53f6969695 100644
--- a/paddle/fluid/operators/rrelu_op.cc
+++ b/paddle/fluid/operators/rrelu_op.cc
@@ -27,10 +27,10 @@ class RReluOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc
index eb4f1b88c6..88d51eabaf 100644
--- a/paddle/fluid/operators/run_program_op.cc
+++ b/paddle/fluid/operators/run_program_op.cc
@@ -47,17 +47,18 @@ class RunProgramOp : public framework::OperatorWithKernel {
    *
    * Of course, the data type here is also not important.
    */
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.GetPlace());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return expected_kernel_type;
+      const phi::KernelKey& expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
@@ -173,17 +174,18 @@ class RunProgramGradOp : public framework::OperatorWithKernel {
  protected:
   /* see [Why use single type kernel] */
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.GetPlace());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return expected_kernel_type;
+      const phi::KernelKey& expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc
index ee9abf6f35..db9944ffb1 100644
--- a/paddle/fluid/operators/sample_logits_op.cc
+++ b/paddle/fluid/operators/sample_logits_op.cc
@@ -177,12 +177,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits");
-    framework::OpKernelType kt =
-        framework::OpKernelType(data_type, ctx.device_context());
-    return kt;
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
@@ -234,13 +232,11 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("SampledLogits"));
-    framework::OpKernelType kt =
-        framework::OpKernelType(data_type, ctx.device_context());
-    return kt;
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
TODO(lujun): The override here is just to bypass transform // in operator impl, which is not elegant enough. - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(tensor.place(), + phi::DataLayout::ALL_LAYOUT, + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/save_combine_op.h b/paddle/fluid/operators/save_combine_op.h index bf5e2a5e4d..10acb286ee 100644 --- a/paddle/fluid/operators/save_combine_op.h +++ b/paddle/fluid/operators/save_combine_op.h @@ -99,12 +99,14 @@ void SaveCombineTensorKernel(const Context& dev_ctx, "The Tensor with Index (%d) to be saved is not initialized.", i)); // Serialize tensors one by one // Check types to see if a fp16 transformation is required - auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); - auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor.dtype(); + auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { auto place = dev_ctx.GetPlace(); - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor out; framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); // copy LoD info to the new tensor diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 179a18ba8d..3af82952f4 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -30,10 +30,10 @@ class SaveOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/save_op.h b/paddle/fluid/operators/save_op.h index 7b78ac1ece..e33fc68f39 100644 --- a/paddle/fluid/operators/save_op.h +++ b/paddle/fluid/operators/save_op.h @@ -90,12 +90,14 @@ class SaveOpKernel : public framework::OpKernel { "Cannot open %s to save variables.", filename)); auto save_as_fp16 = ctx.Attr("save_as_fp16"); - auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); - auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor.dtype(); + auto out_dtype = save_as_fp16 ? 
phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor out; framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); // copy LoD info to the new tensor diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 7416269e33..2cfd096986 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -27,11 +27,11 @@ class ScaleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc index 93d57aedd8..f6d1974968 100644 --- a/paddle/fluid/operators/seed_op.cc +++ b/paddle/fluid/operators/seed_op.cc @@ -26,10 +26,9 @@ class SeedOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::INT32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::INT32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/segment_pool_op.cc b/paddle/fluid/operators/segment_pool_op.cc index 2cdc574661..c2199b7036 100644 --- a/paddle/fluid/operators/segment_pool_op.cc +++ b/paddle/fluid/operators/segment_pool_op.cc @@ -28,11 +28,10 @@ class SegmentPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -113,11 +112,11 @@ class SegmentPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index 117fc4ebe0..63aef4a628 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -121,11 +121,11 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc index b1223618ee..a9e0b21b7b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc @@ -83,10 +83,10 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { ctx->ShareLoD("Y", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -166,11 +166,11 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc index 67573b543d..4a3100b14b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc @@ -128,10 +128,10 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -238,11 +238,11 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index c380779861..940d5caaaa 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -39,21 +39,22 @@ class SequenceMaskOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return 
phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "depth_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc index 6957920131..12c8225015 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc @@ -134,10 +134,10 @@ class SequencePadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -246,11 +246,11 @@ class SequencePadGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 778b2f8854..938b23a22a 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -154,11 +154,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index 17961181fb..a626d487b5 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -126,11 +126,10 @@ class SequenceScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -146,11 +145,11 @@ class SequenceScatterGradOp : public 
framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index 9375cea85c..b7e2ff766f 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -56,11 +56,10 @@ class SequenceSliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -81,11 +80,11 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 80f13a51ab..4089cdb9fc 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -36,14 +36,15 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); phi::DataLayout layout_ = DataLayout::kAnyLayout; if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; @@ -120,14 +121,15 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); phi::DataLayout layout_ = DataLayout::kAnyLayout; if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc index b19dfe40ed..c57cd94951 100644 --- 
a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc @@ -100,10 +100,10 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc index fe91dd00d4..bddad088fe 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc @@ -86,10 +86,10 @@ class SequenceUnpadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -156,11 +156,11 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index d635feee58..19ce77b6b4 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -45,22 +45,24 @@ class SetValue : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StepsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -214,23 +216,25 @@ class SetValueGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto in_tensor = ctx.Input(framework::GradVarName("Out")); - return 
framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - in_tensor->place()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + in_tensor->place()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StepsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index 6849b4e427..24d2f1104d 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -26,21 +26,21 @@ class ShapeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/shuffle_batch_op.cc b/paddle/fluid/operators/shuffle_batch_op.cc index 6eeec76112..d34c102d0e 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cc +++ b/paddle/fluid/operators/shuffle_batch_op.cc @@ -55,18 +55,20 @@ class ShuffleBatchOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "Seed") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( var_name, tensor, expected_kernel_type); @@ -123,11 +125,11 @@ class ShuffleBatchOpGrad : public framework::OperatorWithKernel { } protected: - 
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index 7e98514cde..b72d3557b6 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -35,11 +35,11 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -89,11 +89,11 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc index 5c5343bf42..536e878c6f 100644 --- a/paddle/fluid/operators/similarity_focus_op.cc +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -74,11 +74,10 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc index 094e87f384..695807a4c3 100644 --- a/paddle/fluid/operators/size_op.cc +++ b/paddle/fluid/operators/size_op.cc @@ -25,17 +25,19 @@ class SizeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; // dtype is not important - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index a418719907..426eec0b0e 100644 --- a/paddle/fluid/operators/slice_op.cc +++ 
b/paddle/fluid/operators/slice_op.cc @@ -132,7 +132,7 @@ class SliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *in_var = ctx.InputVar("Input"); if (in_var->IsType()) { @@ -144,9 +144,8 @@ class SliceOp : public framework::OperatorWithKernel { "The tensor Input (Input) of Slice op is not initialized.")); // NOTE: cuda pinned tensor need to copy its data to target place if (platform::is_cuda_pinned_place(in_tensor.place())) { - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()), + ctx.GetPlace()); } #ifdef PADDLE_WITH_MKLDNN @@ -162,33 +161,37 @@ class SliceOp : public framework::OperatorWithKernel { // created, so in that scenario a fallback is needed if (ctx.Input("Input") ->mem_desc() - .data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + .data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); + } } #endif - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor.dtype()), in_tensor.place()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()), + in_tensor.place()); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -322,7 +325,7 @@ class SliceOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); @@ -335,28 +338,32 @@ class SliceOpGrad : public framework::OperatorWithKernel { // created, so in that scenario a fallback is needed if (ctx.Input(framework::GradVarName("Out")) ->mem_desc() - .data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + .data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); 
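// [Editor's note] A sketch of what the ONEDNN fallback above encodes
// (illustrative only; it assumes phi::KernelKey exposes a backend()
// accessor alongside the layout()/dtype() accessors this patch already
// uses): the library type that framework::LibraryType::kMKLDNN used to
// carry is now folded into the backend field of the key.
//
//   phi::KernelKey onednn_key(phi::Backend::ONEDNN,
//                             phi::DataLayout::ONEDNN,
//                             phi::TransToPhiDataType(input_data_type));
//   // onednn_key.backend() == phi::Backend::ONEDNN
//   // onednn_key.layout()  == phi::DataLayout::ONEDNN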
+ } } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index bc11f53e00..99383363e6 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -31,7 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. std::string data_format = ctx.Attr("data_format"); @@ -48,7 +48,8 @@ class SoftmaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; @@ -116,7 +117,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. 
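// [Editor's note] Layout-carrying keys also change argument order in this
// patch: OpKernelType took (proto dtype, place, layout), while phi::KernelKey
// takes (place, layout, phi dtype), so the legacy proto type must be
// converted first. A minimal sketch of the migrated pattern, assuming the
// templated ExecutionContext::Attr<T> accessor:
//
//   auto layout =
//       phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
//   return phi::KernelKey(ctx.GetPlace(),
//                         layout,
//                         phi::TransToPhiDataType(input_data_type));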
std::string data_format = ctx.Attr("data_format"); @@ -132,7 +133,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index a2ca77cc60..df142f3350 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -218,11 +218,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + return phi::KernelKey( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace()); } }; @@ -310,11 +309,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index 0d4af9c0ce..ed9c82c34f 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -204,11 +204,11 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/sparse_attention_op.cc b/paddle/fluid/operators/sparse_attention_op.cc index 48dc3d7824..26dfc0fbbc 100644 --- a/paddle/fluid/operators/sparse_attention_op.cc +++ b/paddle/fluid/operators/sparse_attention_op.cc @@ -122,11 +122,11 @@ class SparseAttentionOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "Q", "K"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -169,11 +169,11 @@ class SparseAttentionOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - 
ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 372e31aa9a..85bd867665 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -25,9 +25,9 @@ class SpectralNormOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; @@ -143,9 +143,9 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index fc7e8a869e..47f6306acb 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -108,7 +108,7 @@ class SplitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -120,25 +120,27 @@ class SplitOp : public framework::OperatorWithKernel { // 16(depending on which blocking format is used) submemory cannot be // created, so in that scenario a fallback is needed const auto x_md = ctx.Input("X")->mem_desc(); - if (x_md.data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + if (x_md.data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); + } } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor" || var_name == "SectionsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/squared_l2_distance_op.cc b/paddle/fluid/operators/squared_l2_distance_op.cc index dc1848b3ee..f1ed2d3ee6 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.cc +++ b/paddle/fluid/operators/squared_l2_distance_op.cc @@ -200,9 +200,9 @@ class SquaredL2DistanceGradOp : public 
framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 7b023bcdf6..115901d3ee 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -120,11 +120,11 @@ class SqueezeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -139,11 +139,11 @@ class SqueezeGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc index d30320f995..9cc78eb300 100644 --- a/paddle/fluid/operators/stack_op.cc +++ b/paddle/fluid/operators/stack_op.cc @@ -31,11 +31,11 @@ class StackOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/stft_op.cc b/paddle/fluid/operators/stft_op.cc index 986911a139..8c9507bc89 100644 --- a/paddle/fluid/operators/stft_op.cc +++ b/paddle/fluid/operators/stft_op.cc @@ -79,10 +79,10 @@ class StftOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); + return phi::KernelKey(in_dtype, ctx.GetPlace()); } }; @@ -140,12 +140,12 @@ class StftGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto in_dtype = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); const auto kernel_dtype = framework::ToRealType(in_dtype); - return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + return phi::KernelKey(kernel_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index c08f214ab5..fffd99ae76 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ 
b/paddle/fluid/operators/strided_slice_op.cc @@ -31,7 +31,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *in_var = ctx.InputVar("Input"); auto is_in_var_array = in_var->IsType(); @@ -50,35 +50,37 @@ class StridedSliceOp : public framework::OperatorWithKernel { string::to_string(tensor.place()))); } } - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + ctx.GetPlace()); } // NOTE: cuda pinned tensor need to copy its data to target place auto in_tensor = ctx.Input("Input"); if (platform::is_cuda_pinned_place(in_tensor->place())) { - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor->dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor->dtype()), + ctx.GetPlace()); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - in_tensor->place()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + in_tensor->place()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor" || var_name == "StridesTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StridesTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -164,26 +166,30 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor" || var_name == "StridesTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StridesTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + 
expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc index f1a7688372..35128b0085 100644 --- a/paddle/fluid/operators/string/faster_tokenizer_op.cc +++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc @@ -469,19 +469,19 @@ class FasterTokenizerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::INT64, - paddle::platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::INT64, + paddle::platform::CPUPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 098167cb69..a4902a85fc 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -30,7 +30,7 @@ class SumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto x_vars = ctx.MultiInputVar("X"); auto x_vars_name = ctx.InputNames("X"); @@ -87,27 +87,24 @@ class SumOp : public framework::OperatorWithKernel { } // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } else if (x_vars[0]->IsType()) { for (auto& var : x_vars) { auto& value = var->Get().value(); if (value.IsInitialized()) { - return framework::OpKernelType( - framework::TransToProtoVarType(value.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(value.dtype()), + ctx.GetPlace()); } } // if input sparse vars are not initialized, use an default kernel type. 
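// [Editor's note] Sketch of the fallback below, illustrative only: when no
// SelectedRows input is initialized there is nothing to infer a dtype from,
// so the op pins an arbitrary FP32 key on the current place. The two-argument
// constructor is assumed to accept the legacy proto dtype directly, as every
// other hunk in this patch relies on:
//
//   return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());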
- return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } else if (x_vars[0]->IsType()) { for (auto& x_var : x_vars) { auto& array = x_var->Get(); for (auto& each : array) { if (each.numel() != 0 && each.IsInitialized()) { - return framework::OpKernelType( - framework::TransToProtoVarType(each.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(each.dtype()), + ctx.GetPlace()); } } } diff --git a/paddle/fluid/operators/tdm_child_op.cc b/paddle/fluid/operators/tdm_child_op.cc index c91f0b989e..0ec2c1e85b 100644 --- a/paddle/fluid/operators/tdm_child_op.cc +++ b/paddle/fluid/operators/tdm_child_op.cc @@ -102,10 +102,10 @@ class TDMChildOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/tdm_sampler_op.cc b/paddle/fluid/operators/tdm_sampler_op.cc index 7480c10394..66e9728d88 100644 --- a/paddle/fluid/operators/tdm_sampler_op.cc +++ b/paddle/fluid/operators/tdm_sampler_op.cc @@ -118,10 +118,10 @@ class TDMSamplerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index bad4479868..fdb78f9da3 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -74,11 +74,10 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // teacher_student_sigmoid_loss // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -186,11 +185,10 @@ class TeacherStudentSigmoidLossGradientOp // Explicitly set that the data type of computation kernel of // teacher_student_sigmoid_loss // is determined by its input "X". 
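// [Editor's note] The comment above describes the convention this patch
// leaves intact: the kernel dtype still comes from IndicateVarDataType, and
// only the key type changes. A minimal sketch of the migrated method,
// mirroring the hunk below:
//
//   phi::KernelKey GetExpectedKernelType(
//       const framework::ExecutionContext& ctx) const override {
//     return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
//                           ctx.GetPlace());
//   }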
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc
index 119fcf4f49..32fc06f578 100644
--- a/paddle/fluid/operators/temporal_shift_op.cc
+++ b/paddle/fluid/operators/temporal_shift_op.cc
@@ -28,10 +28,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -120,11 +120,11 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/tile_op.cc b/paddle/fluid/operators/tile_op.cc
index 172e967370..9ea804b244 100644
--- a/paddle/fluid/operators/tile_op.cc
+++ b/paddle/fluid/operators/tile_op.cc
@@ -29,22 +29,23 @@ class TileOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
@@ -121,22 +122,24 @@ class TileGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc
index 42d5433792..22eb23c93f 100644
--- a/paddle/fluid/operators/top_k_op.cc
+++ b/paddle/fluid/operators/top_k_op.cc
@@ -65,15 +65,10 @@ class TopkOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context(),
-        layout_,
-        library_);
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -128,11 +123,11 @@ class TopkOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/transfer_layout_op.cc b/paddle/fluid/operators/transfer_layout_op.cc
index 5bba1c225a..a197546b35 100644
--- a/paddle/fluid/operators/transfer_layout_op.cc
+++ b/paddle/fluid/operators/transfer_layout_op.cc
@@ -42,7 +42,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     // kernel's device type is decided by input tensor place
     auto *in = ctx.InputVar("X");
@@ -59,14 +59,16 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
         in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();
     // dtype is not important
-    return framework::OpKernelType(framework::proto::VarType::FP32, place);
+    return phi::KernelKey(framework::proto::VarType::FP32, place);
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return expected_kernel_type;
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index e81c619db4..d49cbad114 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -97,12 +97,13 @@ class TransposeOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     auto &data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
 
@@ -192,13 +193,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
 
@@ -229,13 +231,14 @@ class Transpose2Op : public TransposeOp {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "X");
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
 
@@ -333,14 +336,15 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx,
                                                 framework::GradVarName("Out"));
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
 
diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc
index 525dd17c39..0e78aa20fa 100644
--- a/paddle/fluid/operators/tree_conv_op.cc
+++ b/paddle/fluid/operators/tree_conv_op.cc
@@ -151,11 +151,11 @@ class TreeConvOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };
 
@@ -215,11 +215,11 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc
index 62dc419fd0..66e4c3a578 100644
--- a/paddle/fluid/operators/triangular_solve_op.cc
+++ b/paddle/fluid/operators/triangular_solve_op.cc
@@ -23,10 +23,10 @@ class TriangularSolveOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/tril_indices_op.cc b/paddle/fluid/operators/tril_indices_op.cc
index bae34fa5f5..4631900e3b 100644
--- a/paddle/fluid/operators/tril_indices_op.cc
+++ b/paddle/fluid/operators/tril_indices_op.cc
@@ -27,9 +27,9 @@ class TrilIndicesOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/triu_indices_op.cc b/paddle/fluid/operators/triu_indices_op.cc
index d02b54f608..8167cb3e3f 100644
--- a/paddle/fluid/operators/triu_indices_op.cc
+++ b/paddle/fluid/operators/triu_indices_op.cc
@@ -24,9 +24,9 @@ class TriuIndicesOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc
index 1d29a9c518..c5a4a1268f 100644
--- a/paddle/fluid/operators/truncated_gaussian_random_op.cc
+++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc
@@ -31,15 +31,11 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    phi::DataLayout layout{phi::DataLayout::kAnyLayout};
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context(),
-        layout,
-        library);
+        ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
index a9191c09ea..6d4206d343 100644
--- a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
+++ b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
@@ -23,9 +23,9 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
  protected:
   using BatchSizeLikeOp::BatchSizeLikeOp;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cc b/paddle/fluid/operators/uniform_random_inplace_op.cc
index 09870c8401..d43d1cd125 100644
--- a/paddle/fluid/operators/uniform_random_inplace_op.cc
+++ b/paddle/fluid/operators/uniform_random_inplace_op.cc
@@ -57,10 +57,10 @@ class UniformRandomInplaceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index 7ba22baff9..e2605332cc 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -136,22 +136,24 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensorList" || var_name == "ShapeTensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/unique_consecutive_op.cc b/paddle/fluid/operators/unique_consecutive_op.cc
index 97cd31141d..d57a9ceacf 100644
--- a/paddle/fluid/operators/unique_consecutive_op.cc
+++ b/paddle/fluid/operators/unique_consecutive_op.cc
@@ -26,10 +26,10 @@ class UniqueConsecutiveOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/unique_op.cc b/paddle/fluid/operators/unique_op.cc
index c99f60ca87..5484a16ca6 100644
--- a/paddle/fluid/operators/unique_op.cc
+++ b/paddle/fluid/operators/unique_op.cc
@@ -98,18 +98,17 @@ class UniqueOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // Return CPUPlace when Attr("is_sorted") is false. Because it means
     // that fluid.layers.unique is called, but there is no cuda kernel.
     if (!ctx.Attr<bool>("is_sorted")) {
-      return framework::OpKernelType(
-          OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-          platform::CPUPlace());
+      return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                            platform::CPUPlace());
     } else {
       // new version paddle.unique is called.
-      return framework::OpKernelType(
-          OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+      return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                            ctx.GetPlace());
     }
   }
 };
 
diff --git a/paddle/fluid/operators/unique_with_counts_op.cc b/paddle/fluid/operators/unique_with_counts_op.cc
index 6e60078f6a..3726fd978b 100644
--- a/paddle/fluid/operators/unique_with_counts_op.cc
+++ b/paddle/fluid/operators/unique_with_counts_op.cc
@@ -44,11 +44,10 @@ class UniqueWithCountsOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          platform::CPUPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc
index 92e2082013..6eb6b81eb4 100644
--- a/paddle/fluid/operators/unpool_op.cc
+++ b/paddle/fluid/operators/unpool_op.cc
@@ -148,11 +148,10 @@ int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) {
 
 class UnpoolOp : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
  public:
@@ -161,11 +160,10 @@ class UnpoolOp : public framework::OperatorWithKernel {
 
 class Unpool3dOp : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
  public:
@@ -204,11 +202,10 @@ class Unpool3dOpGradMaker : public framework::SingleGradOpMaker {
 
 class UnpoolOpGrad : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
  public:
@@ -217,11 +214,10 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
 
 class Unpool3dOpGrad : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
  public:
diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc
index d092c03a56..5c6816a171 100644
--- a/paddle/fluid/operators/unsqueeze_op.cc
+++ b/paddle/fluid/operators/unsqueeze_op.cc
@@ -144,23 +144,24 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("X")->type()),
-        ctx.device_context());
+    return phi::KernelKey(framework::TransToProtoVarType(
+                              ctx.Input<phi::DenseTensor>("X")->type()),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "AxesTensor" || var_name == "AxesTensorList") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
@@ -225,11 +226,11 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc
index 3f09b20068..93811358c7 100644
--- a/paddle/fluid/operators/warpctc_op.cc
+++ b/paddle/fluid/operators/warpctc_op.cc
@@ -28,15 +28,10 @@ class WarpCTCOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
-        ctx.GetPlace(),
-        layout_,
-        library_);
+    return phi::KernelKey(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
   }
 };
 
@@ -146,11 +141,11 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Loss")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Loss")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/where_index_op.cc b/paddle/fluid/operators/where_index_op.cc
index 52448b08c5..2b19b62595 100644
--- a/paddle/fluid/operators/where_index_op.cc
+++ b/paddle/fluid/operators/where_index_op.cc
@@ -25,10 +25,10 @@ class WhereIndexOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Condition");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h
index 89ae772f30..ad0a83546e 100644
--- a/paddle/phi/core/kernel_factory.h
+++ b/paddle/phi/core/kernel_factory.h
@@ -26,6 +26,7 @@
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/type_defs.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/utils/flat_hash_map.h"
 #include "paddle/utils/small_vector.h"
 
@@ -53,10 +54,27 @@ class KernelKey {
   KernelKey(Backend backend, DataLayout layout, DataType dtype)
       : backend_(backend), layout_(layout), dtype_(dtype) {}
 
+  explicit KernelKey(Place place)
+      : backend_(TransToPhiBackend(place)),
+        layout_(DataLayout::ALL_LAYOUT),
+        dtype_(DataType::ALL_DTYPE) {}
+
+  explicit KernelKey(const int& dtype, Place place)
+      : backend_(TransToPhiBackend(place)),
+        layout_(DataLayout::ALL_LAYOUT),
+        dtype_(phi::TransToPhiDataType(dtype)) {}
+
+  explicit KernelKey(Place place, DataLayout layout, DataType dtype)
+      : backend_(TransToPhiBackend(place)), layout_(layout), dtype_(dtype) {}
+
   Backend backend() const { return backend_; }
   DataLayout layout() const { return layout_; }
   DataType dtype() const { return dtype_; }
 
+  void set_backend(const Backend& backend) { backend_ = backend; }
+  void set_layout(const DataLayout& layout) { layout_ = layout; }
+  void set_dtype(const DataType& dtype) { dtype_ = dtype; }
+
   struct Hash {
     // Note: Now the number of bits we need does not exceed 32 bits, so there is
     // no need to use 64 bits. If needed in the future, it can be expanded,
diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h
index 6879c62065..edb841aeb1 100644
--- a/paddle/phi/core/utils/data_type.h
+++ b/paddle/phi/core/utils/data_type.h
@@ -125,6 +125,7 @@ enum ProtoDataType {
   FP16 = 4,
   FP32 = 5,
   FP64 = 6,
+  RAW = 17,
   UINT8 = 20,
   INT8 = 21,
   BF16 = 22,
@@ -163,6 +164,8 @@ inline DataType TransToPhiDataType(const int& dtype) {
       return DataType::BOOL;
     case ProtoDataType::PSTRING:
       return DataType::PSTRING;
+    case ProtoDataType::RAW:
+      return DataType::ALL_DTYPE;
     default:
       return DataType::UNDEFINED;
   }
@@ -198,6 +201,8 @@ inline int TransToProtoVarType(const DataType& dtype) {
       return ProtoDataType::BOOL;
     case DataType::PSTRING:
       return ProtoDataType::PSTRING;
+    case DataType::UNDEFINED:
+      return ProtoDataType::RAW;
     default:
       PADDLE_THROW(phi::errors::Unimplemented(
           "Unsupported data type `%s` when casting it into "
diff --git a/paddle/phi/kernels/impl/searchsorted_kernel_impl.h b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
index e3cd6f5828..6c0891e59b 100644
--- a/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
+++ b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
@@ -147,7 +147,7 @@ class SearchSortedFunctor {
 };
 
 template <typename Visitor>
-static void VisitDataType(DataType type, Visitor visitor) {
+void VisitDataTypeForSearchSorted(DataType type, Visitor visitor) {
   if (type == DataType::FLOAT32) {
     visitor.template apply<float>();
   } else if (type == DataType::FLOAT64) {
@@ -178,13 +178,13 @@ void SearchsortedKernel(const Context& ctx,
     int* out_data = out->data<int>();
     SearchSortedFunctor<Context, T, int> functor(
         ctx, &sorted_sequence, &value, right, out_data);
-    VisitDataType(value.dtype(), functor);
+    VisitDataTypeForSearchSorted(value.dtype(), functor);
   } else {
     ctx.template Alloc<int64_t>(out);
     int64_t* out_data = out->data<int64_t>();
     SearchSortedFunctor<Context, T, int64_t> functor(
         ctx, &sorted_sequence, &value, right, out_data);
-    VisitDataType(value.dtype(), functor);
+    VisitDataTypeForSearchSorted(value.dtype(), functor);
   }
 }
-- 
GitLab
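Reviewer note: the operator hunks above all converge on one pattern, sketched once below for reference. The sketch is illustrative only and is not part of the patch; ExampleOp and its "ShapeTensor" input are hypothetical names, and the code assumes a Paddle source tree with this patch applied (i.e. the new phi::KernelKey constructors and accessors exist).

// Illustrative only -- hypothetical op, not included in this patch.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/kernel_factory.h"

namespace paddle {
namespace operators {

class ExampleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Kernel dtype follows input "X"; the backend is derived from the place by
  // the KernelKey(const int& dtype, Place place) constructor added above, so
  // ctx.device_context() is no longer needed.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
  }

  // Auxiliary shape-like inputs opt out of device transfer by declaring
  // Backend::ALL_BACKEND, mirroring the tile/uniform_random/unsqueeze hunks;
  // all other inputs keep their own place and layout.
  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    if (var_name == "ShapeTensor") {  // hypothetical auxiliary input
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
    }
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
  }
};

}  // namespace operators
}  // namespace paddle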