diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index c34e727486bf29c088b035f6a327684c875a5f8f..e0a6fbd37dafffc6fb3d334355b09127be368ee3 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -422,9 +422,9 @@ class CustomOperator : public OperatorWithKernel {
    * The RAW type is used here as the data type, indicating that
    * it can only be determined at runtime.
    */
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace());
+    return phi::KernelKey(ctx.GetPlace());
   }
 
   /**
@@ -432,13 +432,13 @@ class CustomOperator : public OperatorWithKernel {
    * Because the kernel data type is RAW, we should skip the cast for
    * data type difference when PrepareData.
    */
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const OpKernelType& expected_kernel_type) const override {
-    return OpKernelType(expected_kernel_type.data_type_,
-                        expected_kernel_type.place_,
-                        tensor.layout());
+      const phi::KernelKey& expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          tensor.layout(),
+                          expected_kernel_type.dtype());
   }
 };
diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu
index 777d7a67700d3b176b8147551df92e01cd5e6cd0..9723fde1cc8fe51d5e95af725c3422ef48ed0f8a 100644
--- a/paddle/fluid/framework/data_device_transform_test.cu
+++ b/paddle/fluid/framework/data_device_transform_test.cu
@@ -47,15 +47,17 @@ class TestOpWithKernel : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
     if (Attr<bool>("use_gpu")) {
       VLOG(3) << "force use gpu kernel";
-      return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
+      return phi::KernelKey(phi::Backend::GPU,
+                            phi::DataLayout::ALL_LAYOUT,
+                            phi::DataType::FLOAT32);
     } else {
       VLOG(3) << "use default kernel";
-      return OpKernelType(proto::VarType::FP32,
-                          ctx.Input<phi::DenseTensor>("input")->place());
+      return phi::KernelKey(proto::VarType::FP32,
+                            ctx.Input<phi::DenseTensor>("input")->place());
     }
   }
 };
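Review note: the two overrides above show the core of this migration: GetExpectedKernelType now returns a phi::KernelKey carrying (backend, layout, dtype) instead of OpKernelType's (data_type_, place_, data_layout_, library_type_). A minimal sketch of the new override for a hypothetical operator (the class name and the FLOAT32 choice are illustrative, not part of the patch):

    // Hypothetical MyOp, sketching the post-migration override.
    class MyOp : public paddle::framework::OperatorWithKernel {
     public:
      using OperatorWithKernel::OperatorWithKernel;

     protected:
      phi::KernelKey GetExpectedKernelType(
          const paddle::framework::ExecutionContext& ctx) const override {
        // Backend replaces Place; ALL_LAYOUT replaces kAnyLayout.
        return phi::KernelKey(phi::Backend::CPU,
                              phi::DataLayout::ALL_LAYOUT,
                              phi::DataType::FLOAT32);
      }
    };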
diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc
index 3c0e8d4f0ec722c6866fc31dd3acbd868e655a87..73ce635f57ce592b61e7ad21dc71f78ac1c0570e 100644
--- a/paddle/fluid/framework/data_layout_transform.cc
+++ b/paddle/fluid/framework/data_layout_transform.cc
@@ -50,13 +50,14 @@ void CastDataLayout::apply() {
   }
 }
 
-void TransDataLayout(const OpKernelType& kernel_type_for_var,
-                     const OpKernelType& expected_kernel_type,
+void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
+                     const phi::KernelKey& expected_kernel_type,
                      const phi::DenseTensor& in,
-                     phi::DenseTensor* out) {
+                     phi::DenseTensor* out,
+                     const phi::Place& place) {
   PADDLE_ENFORCE(
-      platform::places_are_same_class(kernel_type_for_var.place_,
-                                      expected_kernel_type.place_),
+      backends_are_same_class(kernel_type_for_var.backend(),
+                              expected_kernel_type.backend()),
       platform::errors::PreconditionNotMet(
           "TransDataLayout only support DataLayout transform on same place."));
 
@@ -72,21 +73,20 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
   auto src_dim = in.dims();
   std::vector<int64_t> dst_dim;
 
-  auto axis = GetAxis(kernel_type_for_var.data_layout_,
-                      expected_kernel_type.data_layout_);
+  auto axis =
+      GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
   dst_dim.resize(axis.size());
   for (size_t i = 0; i < axis.size(); i++) {
     dst_dim[i] = src_dim[axis[i]];
   }
 
   out->Resize(phi::make_ddim(dst_dim));
-  out->mutable_data(expected_kernel_type.place_, in.dtype());
+  out->mutable_data(place, in.dtype());
 
-  framework::VisitDataType(
-      framework::TransToProtoVarType(in.dtype()),
-      CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
+  framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
+                           CastDataLayout(pool.Get(place), axis, in, out));
 
-  out->set_layout(expected_kernel_type.data_layout_);
+  out->set_layout(expected_kernel_type.layout());
 }
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h
index bad13e7e90384b9fcbd902f59c705384911d7153..3bc55b8ad86050fa82fddd4f173d742263ec9714 100644
--- a/paddle/fluid/framework/data_layout_transform.h
+++ b/paddle/fluid/framework/data_layout_transform.h
@@ -54,10 +54,11 @@ struct CastDataLayout {
 
 std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
 
-void TransDataLayout(const OpKernelType& kernel_type_for_var,
-                     const OpKernelType& expected_kernel_type,
+void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
+                     const phi::KernelKey& expected_kernel_type,
                      const phi::DenseTensor& in,
-                     phi::DenseTensor* out);
+                     phi::DenseTensor* out,
+                     const phi::Place& place);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/paddle/fluid/framework/data_layout_transform_test.cc
index 9b314fbb2c1609afda244b36d6022baa8e96bf13..880fa5b057df6ec391194ba80e7549cbbe5eb0bc 100644
--- a/paddle/fluid/framework/data_layout_transform_test.cc
+++ b/paddle/fluid/framework/data_layout_transform_test.cc
@@ -24,22 +24,16 @@ TEST(DataTransform, DataLayoutFunction) {
   in.set_layout(phi::DataLayout::kNHWC);
 
   auto kernel_nhwc =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
-                                      place,
-                                      phi::DataLayout::kNHWC,
-                                      paddle::framework::LibraryType::kPlain);
+      phi::KernelKey(place, phi::DataLayout::kNHWC, phi::DataType::FLOAT32);
   auto kernel_ncwh =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
-                                      place,
-                                      phi::DataLayout::kNCHW,
-                                      paddle::framework::LibraryType::kPlain);
+      phi::KernelKey(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32);
 
-  paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
+  paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out, place);
 
   EXPECT_TRUE(out.layout() == phi::DataLayout::kNCHW);
   EXPECT_TRUE(out.dims() == phi::make_ddim({2, 2, 3, 1}));
 
-  TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
+  paddle::framework::TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out, place);
 
   EXPECT_TRUE(in.layout() == phi::DataLayout::kNHWC);
   EXPECT_TRUE(in.dims() == phi::make_ddim({2, 3, 1, 2}));
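Review note: TransDataLayout no longer reads the target place from expected_kernel_type.place_; callers must pass it explicitly, as the updated test above does. A minimal sketch of a call under the new signature (tensor contents assumed initialized elsewhere):

    phi::DenseTensor in;   // assumed to hold NHWC float32 data
    phi::DenseTensor out;
    phi::CPUPlace place;
    auto nhwc =
        phi::KernelKey(place, phi::DataLayout::kNHWC, phi::DataType::FLOAT32);
    auto nchw =
        phi::KernelKey(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32);
    // The destination place is now the explicit last argument.
    paddle::framework::TransDataLayout(nhwc, nchw, in, &out, place);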
diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc
index fff4f6acb3e74f740d21cc5157e973eb87b19106..38e1ce1c3141f83f65adc40ce89b4fc0717236b7 100644
--- a/paddle/fluid/framework/data_transform.cc
+++ b/paddle/fluid/framework/data_transform.cc
@@ -36,16 +36,17 @@ static void PassTensorData(phi::DenseTensor *from, phi::DenseTensor *to) {
   *from = phi::DenseTensor();
 }
 
-void TransformData(const OpKernelType &expected_kernel_type,
-                   const OpKernelType &kernel_type_for_var,
+void TransformData(const phi::KernelKey &expected_kernel_type,
+                   const phi::KernelKey &kernel_type_for_var,
                    const phi::DenseTensor &input_tensor,
-                   phi::DenseTensor *output_tensor) {
+                   phi::DenseTensor *output_tensor,
+                   const phi::Place &place) {
   bool transformed = false;
   phi::DenseTensor in;
   in.ShareDataWith(input_tensor);
   phi::DenseTensor out;
-  const DataLayout lin = kernel_type_for_var.data_layout_;
-  const DataLayout lout = expected_kernel_type.data_layout_;
+  const DataLayout lin = kernel_type_for_var.layout();
+  const DataLayout lout = expected_kernel_type.layout();
   // do layout transform
   if (NeedTransformLayout(lout, lin)) {
 #ifdef PADDLE_WITH_MKLDNN
@@ -79,43 +80,42 @@ void TransformData(const OpKernelType &expected_kernel_type,
     } else {
       // Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel
       // Do transform via ONEDNN lib
-      PADDLE_ENFORCE(
-          kernel_type_for_var.data_layout_ == DataLayout::ONEDNN &&
-              expected_kernel_type.data_layout_ != DataLayout::ONEDNN,
-          platform::errors::InvalidArgument(
-              "TransDataLayoutFromOneDNN only supports "
-              "transform from ONEDNN to non-ONEDNN"));
+      PADDLE_ENFORCE(lin == DataLayout::ONEDNN && lout != DataLayout::ONEDNN,
+                     platform::errors::InvalidArgument(
+                         "TransDataLayoutFromOneDNN only supports "
+                         "transform from ONEDNN to non-ONEDNN"));
 
       phi::funcs::TransDataLayoutFromOneDNN(
-          kernel_type_for_var.data_layout_,
+          lin,
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
          in,
          &out,
-          expected_kernel_type.place_);
+          place);
     }
   } else {
     // Case3 - transfrom between Non-ONEDNN OPKernels
-    TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
+    TransDataLayout(
+        kernel_type_for_var, expected_kernel_type, in, &out, place);
   }
 #else
   // Case3 - transfrom between Non-ONEDNN OPKernels
-  TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
+  TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out, place);
 #endif
    transformed = true;
    PassTensorData(&out, &in);
  }
 
  // do data type transform
-  if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
+  if (NeedTransformDataType(expected_kernel_type, kernel_type_for_var)) {
    TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
    transformed = true;
    PassTensorData(&out, &in);
  }
 
  // do device transform
-  if (!platform::is_same_place(kernel_type_for_var.place_,
-                               expected_kernel_type.place_)) {
-    TransDataDevice(in, expected_kernel_type.place_, &out);
+  if (kernel_type_for_var.backend() != phi::Backend::ALL_BACKEND &&
+      !platform::is_same_place(in.place(), place)) {
+    TransDataDevice(in, place, &out);
     transformed = true;
     PassTensorData(&out, &in);
   }
diff --git a/paddle/fluid/framework/data_transform.h b/paddle/fluid/framework/data_transform.h
index 2fcea7803ed31fef21ceeb7fb884a1c9b00bb9fc..27bc0086c233ded09605fedadaa703672bed6e41 100644
--- a/paddle/fluid/framework/data_transform.h
+++ b/paddle/fluid/framework/data_transform.h
@@ -33,10 +33,11 @@ namespace framework {
 class OpKernelType;
 class Variable;
 
-void TransformData(const OpKernelType &expected_kernel_type,
-                   const OpKernelType &kernel_type_for_var,
+void TransformData(const phi::KernelKey &expected_kernel_type,
+                   const phi::KernelKey &kernel_type_for_var,
                    const phi::DenseTensor &input_tensor,
-                   phi::DenseTensor *out);
+                   phi::DenseTensor *out,
+                   const phi::Place &place);
 
 /**
  * Set OutVar from InVar, except the tensor is shared with `tensor`
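Review note: TransformData keeps its three stages in order (layout, then dtype, then device), but the device stage is now guarded so that a place-agnostic kernel_type_for_var (Backend::ALL_BACKEND) never triggers a copy, and the copy destination is the explicit place argument. A condensed restatement of the new device condition:

    // Same semantics as the new `do device transform` branch above.
    bool needs_device_copy(const phi::KernelKey& kernel_type_for_var,
                           const phi::DenseTensor& in,
                           const phi::Place& place) {
      return kernel_type_for_var.backend() != phi::Backend::ALL_BACKEND &&
             !paddle::platform::is_same_place(in.place(), place);
    }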
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index fd1c06fc6458e1eb20e668e27b4ca7cc33deb0e2..a05f2858c0df3b0d80e6dbfd4758270902e1403e 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/data_type.h"
 
 namespace paddle {
 namespace framework {
@@ -226,6 +227,11 @@ extern inline bool IsComplexType(const proto::VarType::Type& type) {
           type == proto::VarType::COMPLEX128);
 }
 
+extern inline bool IsComplexType(const phi::DataType& type) {
+  return (type == phi::DataType::COMPLEX64 ||
+          type == phi::DataType::COMPLEX128);
+}
+
 extern proto::VarType::Type PromoteTypesIfComplexExists(
     const proto::VarType::Type type_a, const proto::VarType::Type type_b);
diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc
index 0768d2d82fb81c147aed0998bc65322ad7cdfa89..0f2e244af0ab3cdf91f96a47eccfd654c8f87c38 100644
--- a/paddle/fluid/framework/data_type_transform.cc
+++ b/paddle/fluid/framework/data_type_transform.cc
@@ -129,19 +129,18 @@ struct CastDataType {
   }
 };
 
-void TransDataType(const OpKernelType& kernel_type_for_var,
-                   const OpKernelType& expected_kernel_type,
+void TransDataType(const phi::KernelKey& kernel_type_for_var,
+                   const phi::KernelKey& expected_kernel_type,
                    const phi::DenseTensor& in,
                    phi::DenseTensor* out) {
-  PADDLE_ENFORCE_EQ(
-      framework::TransToProtoVarType(in.dtype()),
-      kernel_type_for_var.data_type_,
-      platform::errors::InvalidArgument(
-          "The src dtype(%s) of input tensor and kernel_type(%s) "
-          "are not conststent.",
-          DataTypeToString(framework::TransToProtoVarType(in.dtype())),
-          DataTypeToString(kernel_type_for_var.data_type_)));
-  auto dst_type = expected_kernel_type.data_type_;
+  PADDLE_ENFORCE_EQ(in.dtype(),
+                    kernel_type_for_var.dtype(),
+                    platform::errors::InvalidArgument(
+                        "The src dtype(%s) of input tensor and kernel_type(%s) "
+                        "are not consistent.",
+                        DataTypeToString(in.dtype()),
+                        DataTypeToString(kernel_type_for_var.dtype())));
+  auto dst_type = framework::TransToProtoVarType(expected_kernel_type.dtype());
   TransDataType(in, dst_type, out);
 }
diff --git a/paddle/fluid/framework/data_type_transform.h b/paddle/fluid/framework/data_type_transform.h
index 619e15b6045e883a64ec788d85554ac52619fc05..2ec193b675097d61d68c60c95d544269b1ac726a 100644
--- a/paddle/fluid/framework/data_type_transform.h
+++ b/paddle/fluid/framework/data_type_transform.h
@@ -28,8 +28,8 @@ class OpKernelType;
 
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
 
-void TransDataType(const OpKernelType& kernel_type_for_var,
-                   const OpKernelType& expected_kernel_type,
+void TransDataType(const phi::KernelKey& kernel_type_for_var,
+                   const phi::KernelKey& expected_kernel_type,
                    const phi::DenseTensor& in,
                    phi::DenseTensor* out);
 
 void TransDataType(const phi::DenseTensor& in,
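Review note: the new IsComplexType overload lets phi-typed callers skip the proto round trip. A sketch (hypothetical helper, assuming framework::TransToProtoVarType from convert_utils.h):

    bool demo(const phi::DenseTensor& t) {
      // Direct: overload resolution picks the phi::DataType version.
      bool a = paddle::framework::IsComplexType(t.dtype());
      // Equivalent legacy path through proto::VarType:
      bool b = paddle::framework::IsComplexType(
          paddle::framework::TransToProtoVarType(t.dtype()));
      return a == b;  // expected to always hold
    }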
diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc
index e57d9d6d268e5bb9b77a7ff9186f3db41a3d42b0..44ebdc96e6afe366675a85710e9d6446aea92f5e 100644
--- a/paddle/fluid/framework/data_type_transform_test.cc
+++ b/paddle/fluid/framework/data_type_transform_test.cc
@@ -19,47 +19,26 @@ limitations under the License. */
 TEST(DataTypeTransform, CPUTransform) {
   auto place = paddle::platform::CPUPlace();
 
-  auto kernel_fp16 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_bf16 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::BF16,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_fp32 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_fp64 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
+  auto kernel_fp16 = phi::KernelKey(
+      place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
+
+  auto kernel_bf16 = phi::KernelKey(
+      place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BFLOAT16);
+
+  auto kernel_fp32 = phi::KernelKey(
+      place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
+
+  auto kernel_fp64 = phi::KernelKey(
+      place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
 
   auto kernel_int32 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
+      phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
 
   auto kernel_int64 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
+      phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
 
   auto kernel_bool =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL,
-                                      place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
+      phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
 
   // data type transform from float32
   {
diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu
index 6e047bbbf15982bd4605510d2b7dba62e32ddf31..f9394bea7fc372885ba67bf4326683501256c5b7 100644
--- a/paddle/fluid/framework/data_type_transform_test.cu
+++ b/paddle/fluid/framework/data_type_transform_test.cu
@@ -24,41 +24,24 @@ TEST(DataTypeTransform, GPUTransform) {
           .GetAllocator(gpu_place, context.stream())
           .get());
   context.PartialInitWithAllocator();
-  auto kernel_fp16 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_fp32 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_fp64 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_int32 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_int64 =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
-
-  auto kernel_bool =
-      paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL,
-                                      gpu_place,
-                                      phi::DataLayout::kAnyLayout,
-                                      paddle::framework::LibraryType::kPlain);
+
+  auto kernel_fp16 = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
+
+  auto kernel_fp32 = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
+
+  auto kernel_fp64 = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
+
+  auto kernel_int32 = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
+
+  auto kernel_int64 = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
+
+  auto kernel_bool = phi::KernelKey(
+      gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
 
   // data type transform from float32
   {
diff --git a/paddle/fluid/framework/details/build_strategy_test.cc b/paddle/fluid/framework/details/build_strategy_test.cc
index c39388fa5bc8616ae96adc50ba7aa7a251d5e208..7ec7d93ee661092514874cddf6465dbea35e50d2 100644
--- a/paddle/fluid/framework/details/build_strategy_test.cc
+++ b/paddle/fluid/framework/details/build_strategy_test.cc
@@ -50,10 +50,10 @@ class SumOpWithKernel : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext &ctx) const override {
-    return OpKernelType(proto::VarType::FP32,
-                        ctx.Input<phi::DenseTensor>("X")->place());
+    return phi::KernelKey(proto::VarType::FP32,
+                          ctx.Input<phi::DenseTensor>("X")->place());
   }
 };
diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/paddle/fluid/framework/infershape_utils_test.cc
index 6aef5b7a891777004f0501c5744209cae91369cf..43fbb7d550ee65c2fa4d36916f56250cf6af9358 100644
--- a/paddle/fluid/framework/infershape_utils_test.cc
+++ b/paddle/fluid/framework/infershape_utils_test.cc
@@ -84,9 +84,9 @@ class InferShapeUtilsTestOp : public OperatorWithKernel {
  public:
   using OperatorWithKernel::OperatorWithKernel;
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
+    return phi::KernelKey(proto::VarType::FP32, ctx.GetPlace());
   }
 };
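Review note: the updated tests exercise the KernelKey construction forms this patch relies on: phi::DataLayout::ALL_LAYOUT and phi::DataType replace kAnyLayout and proto::VarType, and LibraryType disappears entirely (the GPUDNN/ONEDNN backends subsume it). The three forms seen above, side by side (the proto-based overload is presumably a compatibility constructor provided with this change):

    phi::CPUPlace place;
    auto a = phi::KernelKey(phi::Backend::GPUDNN,
                            phi::DataLayout::ALL_LAYOUT,
                            phi::DataType::FLOAT32);  // explicit backend
    auto b = phi::KernelKey(place,
                            phi::DataLayout::ALL_LAYOUT,
                            phi::DataType::FLOAT32);  // backend derived from place
    auto c = phi::KernelKey(paddle::framework::proto::VarType::FP32,
                            place);                   // fluid dtype + place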
diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
index d41a1dca448bc6351fc93c5636859acc58832251..7e97d82c78eaefe40c5c58caf81caa8c52973c53 100644
--- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
@@ -27,22 +27,26 @@ namespace paddle {
 namespace framework {
 namespace interpreter {
 
-bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
-                              const OpKernelType& expected_kernel_key,
-                              const std::string& var_name,
-                              std::string* new_var_name,
-                              std::vector<OpFuncNode>* op_func_nodes,
-                              bool use_local_scope,
-                              bool is_fetch_v2,
-                              bool skip_run) {
+bool DataTranferHelper::apply(
+    const phi::KernelKey& kernel_type_for_var,
+    const framework::OpKernelType& expected_kernel_key,
+    const phi::DenseTensor* tensor,
+    const std::string& var_name,
+    std::string* new_var_name,
+    std::vector<OpFuncNode>* op_func_nodes,
+    bool use_local_scope,
+    bool is_fetch_v2,
+    bool skip_run) {
   bool is_transferred = false;
   auto* src_var_name = &var_name;
 
   // 1. layout transform
-  if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
+  if (need_layout_transform(
+          kernel_type_for_var,
+          TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
     auto op = TransferLayout(*src_var_name,
                              new_var_name,
-                             kernel_type_for_var.data_layout_,
+                             kernel_type_for_var.layout(),
                              expected_kernel_key.data_layout_,
                              var_scope_,
                              scope_,
@@ -56,13 +60,16 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
     is_transferred = true;
   }
   // 2. dype transform
-  if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
-    auto op = TransferDtype(*src_var_name,
-                            new_var_name,
-                            kernel_type_for_var.data_type_,
-                            expected_kernel_key.data_type_,
-                            var_scope_,
-                            scope_);
+  if (need_dtype_transform(
+          kernel_type_for_var,
+          TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
+    auto op = TransferDtype(
+        *src_var_name,
+        new_var_name,
+        framework::TransToProtoVarType(kernel_type_for_var.dtype()),
+        expected_kernel_key.data_type_,
+        var_scope_,
+        scope_);
     if (op) {
       RunAndConstructOpFuncNode(
           op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
@@ -72,8 +79,9 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
     is_transferred = true;
   }
   // 3. device transform
-  if (need_device_transform(kernel_type_for_var, expected_kernel_key)) {
-    auto src_place = kernel_type_for_var.place_;
+  if (need_device_transform(
+          kernel_type_for_var, tensor, expected_kernel_key.place_)) {
+    auto src_place = tensor->place();
     auto dst_place = expected_kernel_key.place_;
 
     auto op = TransferDevice(
@@ -526,11 +534,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
         auto kernel_type_for_var =
             static_cast<const framework::OperatorWithKernel*>(op_base)
                 ->GetKernelTypeForVar(
-                    var_name_item.first, *tensor_in, expected_kernel_key);
+                    var_name_item.first,
+                    *tensor_in,
+                    framework::TransOpKernelTypeToPhiKernelKey(
+                        expected_kernel_key));
         // apply data transform
         is_transferred =
             data_transfer_helper.apply(kernel_type_for_var,
                                        expected_kernel_key,
+                                       tensor_in,
                                        var_name,
                                        &new_var_name,
                                        new_op_func_nodes,
diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h
index e74fe8066e6b8ef911093e3c9e925f200aae6c9b..604f12038008daa91570c1ccd2b0621e5f782981 100644
--- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.h
+++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h
@@ -34,8 +34,9 @@ class DataTranferHelper {
                     Scope* local_scope)
       : place_(place), var_scope_(var_scope), scope_(local_scope) {}
 
-  bool apply(const OpKernelType& kernel_type_for_var,
-             const OpKernelType& expected_kernel_key,
+  bool apply(const phi::KernelKey& kernel_type_for_var,
+             const framework::OpKernelType& expected_kernel_key,
+             const phi::DenseTensor* tensor,
              const std::string& var_name,
              std::string* new_var_name,
              std::vector<OpFuncNode>* new_op_func_nodes,
@@ -79,28 +80,28 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
                                  framework::Scope* local_scope,
                                  bool skip_run = false);
 
-inline bool need_device_transform(const OpKernelType& kernel_type_for_var,
-                                  const OpKernelType& expected_kernel_key) {
-  auto& src_place = kernel_type_for_var.place_;
-  auto& dst_place = expected_kernel_key.place_;
-  if (platform::is_same_place(src_place, dst_place) ||
-      (platform::is_cuda_pinned_place(src_place) &&
-       platform::is_cpu_place(dst_place))) {
+inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
+                                  const phi::DenseTensor* tensor,
+                                  const phi::Place& expected_place) {
+  if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND ||
+      platform::is_same_place(tensor->place(), expected_place) ||
+      (platform::is_cuda_pinned_place(tensor->place()) &&
+       platform::is_cpu_place(expected_place))) {
     return false;
   }
   return true;
 }
 
-inline bool need_dtype_transform(const OpKernelType& kernel_type_for_var,
-                                 const OpKernelType& expected_kernel_key) {
+inline bool need_dtype_transform(const phi::KernelKey& kernel_type_for_var,
+                                 const phi::KernelKey& expected_kernel_key) {
   return framework::NeedTransformDataType(kernel_type_for_var,
                                           expected_kernel_key);
 }
 
-inline bool need_layout_transform(const OpKernelType& kernel_type_for_var,
-                                  const OpKernelType& expected_kernel_key) {
-  return framework::NeedTransformLayout(kernel_type_for_var.data_layout_,
-                                        expected_kernel_key.data_layout_);
+inline bool need_layout_transform(const phi::KernelKey& kernel_type_for_var,
+                                  const phi::KernelKey& expected_kernel_key) {
+  return framework::NeedTransformLayout(kernel_type_for_var.layout(),
+                                        expected_kernel_key.layout());
 }
 
 std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index f5b430e829a13c332c03c974ee6dbb5b9f0b5be7..f98acfdccfd7d70eca91a7e1ec3d4bb12bb8703d 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -730,8 +730,8 @@ bool BuildOpFuncList(const platform::Place& place,
       auto* dev_ctx = pool.Get(place);
       auto exec_ctx = ExecutionContext(
           *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
-      auto expected_kernel_key =
-          op_with_kernel->GetExpectedKernelType(exec_ctx);
+      auto expected_kernel_key = framework::TransPhiKernelKeyToOpKernelType(
+          op_with_kernel->GetExpectedKernelType(exec_ctx));
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       if (op_with_kernel->CanCUDNNBeUsed(exec_ctx,
                                          expected_kernel_key.data_type_)) {
@@ -741,6 +741,10 @@ bool BuildOpFuncList(const platform::Place& place,
       VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
       // change device by the device_guard()
       ApplyDeviceGuard(op, place, &expected_kernel_key);
+      if (platform::places_are_same_class(exec_ctx.GetPlace(),
+                                          expected_kernel_key.place_)) {
+        expected_kernel_key.place_ = exec_ctx.GetPlace();
+      }
 
       // step 2. select op kernel
       auto run_phi_kernel = false;
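Review note: need_device_transform can no longer compare two recorded places, because a phi::KernelKey carries no place; it now compares the tensor's actual placement against the expected place, with Backend::ALL_BACKEND opting out of copies entirely. A sketch of a call site (illustrative wrapper, not in the patch):

    bool must_copy(const phi::KernelKey& kernel_type_for_var,
                   const phi::DenseTensor& t,
                   const phi::Place& expected_place) {
      return paddle::framework::interpreter::need_device_transform(
          kernel_type_for_var, &t, expected_place);
    }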
diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h
index a609313e84800543f39904e5bc2f293bd3db32db..eb969a94d825617d37ef760c7bd12746637786a8 100644
--- a/paddle/fluid/framework/op_kernel_type.h
+++ b/paddle/fluid/framework/op_kernel_type.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/library_type.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/kernel_factory.h"
 
 namespace paddle {
 namespace framework {
@@ -108,15 +109,32 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
   return ret;
 }
 
-inline bool NeedTransformDataType(const OpKernelType& l,
-                                  const OpKernelType& r) {
-  return (l.data_type_ != r.data_type_);
+inline bool NeedTransformDataType(const phi::KernelKey& l,
+                                  const phi::KernelKey& r) {
+  return l.dtype() != phi::DataType::ALL_DTYPE &&
+         r.dtype() != phi::DataType::ALL_DTYPE && l.dtype() != r.dtype();
 }
 
-inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) {
-  return (!platform::places_are_same_class(l.place_, r.place_)) ||
-         (l.data_type_ != r.data_type_) ||
-         NeedTransformLayout(l.data_layout_, r.data_layout_);
+inline bool backends_are_same_class(const phi::Backend& l,
+                                    const phi::Backend& r) {
+  if (l == phi::Backend::ALL_BACKEND || r == phi::Backend::ALL_BACKEND) {
+    return true;
+  }
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  size_t num_backends = static_cast<size_t>(phi::Backend::NUM_BACKENDS);
+  if (static_cast<size_t>(l) > num_backends &&
+      static_cast<size_t>(r) > num_backends) {
+    return phi::TransToPhiPlace(l).GetDeviceType() ==
+           phi::TransToPhiPlace(r).GetDeviceType();
+  }
+#endif
+  return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r);
+}
+
+inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) {
+  return !backends_are_same_class(l.backend(), r.backend()) ||
+         NeedTransformDataType(l, r) ||
+         NeedTransformLayout(l.layout(), r.layout());
 }
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/op_registry_test.cc b/paddle/fluid/framework/op_registry_test.cc
index 9ef577f62855f7f648295d0e8c1657c3c151dfad..5a40a4df004b4bc23bd8944e594b7842edba5532 100644
--- a/paddle/fluid/framework/op_registry_test.cc
+++ b/paddle/fluid/framework/op_registry_test.cc
@@ -214,9 +214,10 @@ class OpWithKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(InferShapeContext* ctx) const override {}
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
+    return phi::KernelKey(proto::VarType::FP32,
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -275,12 +276,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(InferShapeContext* ctx) const override {}
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(proto::VarType::FP32,
-                                   platform::CUDAPlace(0),
-                                   DataLayout::kAnyLayout,
-                                   framework::LibraryType::kCUDNN);
+    return phi::KernelKey(phi::Backend::GPUDNN,
+                          phi::DataLayout::ALL_LAYOUT,
+                          phi::DataType::FLOAT32);
   }
 };
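Review note: backends_are_same_class is the backend-world analogue of places_are_same_class: ALL_BACKEND matches everything, custom-device backends compare by device type, and everything else compares via the place each backend maps to. Illustrative expectations, assuming the usual phi::TransToPhiPlace mapping (GPUDNN maps to a GPU place, ONEDNN to a CPU place):

    #include <cassert>
    using paddle::framework::backends_are_same_class;
    assert(backends_are_same_class(phi::Backend::ALL_BACKEND,
                                   phi::Backend::GPU));     // wildcard
    assert(backends_are_same_class(phi::Backend::GPU,
                                   phi::Backend::GPUDNN));  // same GPU place
    assert(backends_are_same_class(phi::Backend::CPU,
                                   phi::Backend::ONEDNN));  // same CPU place
    assert(!backends_are_same_class(phi::Backend::CPU,
                                    phi::Backend::GPU));    // different places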
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index dcb822afb4cadbfeac111e53cbd3cecedf05c77d..5e8d0b1b87ae27a0dd40ac3e29a31cf149e14576 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1380,8 +1380,7 @@ bool OperatorWithKernel::SupportXPU() const {
 #endif
 }
 
-bool OperatorWithKernel::SupportsMKLDNN(
-    const proto::VarType::Type data_type) const {
+bool OperatorWithKernel::SupportsMKLDNN(const phi::DataType data_type) const {
   auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
       phi::TransToPhiKernelName(type_));
   auto has_phi_kernel =
@@ -1389,8 +1388,7 @@ bool OperatorWithKernel::SupportsMKLDNN(
                   phi_kernels.end(),
                   [data_type](phi::KernelKeyMap::const_reference kern_pair) {
                     return kern_pair.first.backend() == phi::Backend::ONEDNN &&
-                           kern_pair.first.dtype() ==
-                               framework::TransToPhiDataType(data_type);
+                           kern_pair.first.dtype() == data_type;
                   });
   if (has_phi_kernel) {
     return true;
@@ -1406,25 +1404,22 @@ bool OperatorWithKernel::SupportsMKLDNN(
           [data_type](OpKernelMap::const_reference kern_pair) {
             return platform::is_cpu_place(kern_pair.first.place_) &&
                    kern_pair.first.library_type_ == LibraryType::kMKLDNN &&
-                   kern_pair.first.data_type_ == data_type;
+                   kern_pair.first.data_type_ == TransToProtoVarType(data_type);
           });
     }
   }
 }
 
-bool OperatorWithKernel::SupportsCUDNN(
-    const proto::VarType::Type data_type) const {
+bool OperatorWithKernel::SupportsCUDNN(const phi::DataType data_type) const {
   auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
       phi::TransToPhiKernelName(type_));
-  paddle::experimental::DataType phi_data_type =
-      framework::TransToPhiDataType(data_type);
-  auto has_phi_kernel = std::any_of(
-      phi_kernels.begin(),
-      phi_kernels.end(),
-      [phi_data_type](phi::KernelKeyMap::const_reference kern_pair) {
-        return kern_pair.first.backend() == phi::Backend::GPUDNN &&
-               kern_pair.first.dtype() == phi_data_type;
-      });
+  auto has_phi_kernel =
+      std::any_of(phi_kernels.begin(),
+                  phi_kernels.end(),
+                  [data_type](phi::KernelKeyMap::const_reference kern_pair) {
+                    return kern_pair.first.backend() == phi::Backend::GPUDNN &&
+                           kern_pair.first.dtype() == data_type;
+                  });
   if (has_phi_kernel) {
     return true;
   } else {
@@ -1433,13 +1428,15 @@ bool OperatorWithKernel::SupportsCUDNN(
       return false;
     } else {
       auto& op_kernels = op_kernel_iter->second;
+      proto::VarType::Type fluid_data_type =
+          framework::TransToProtoVarType(data_type);
       return std::any_of(
           op_kernels.begin(),
           op_kernels.end(),
-          [data_type](OpKernelMap::const_reference kern_pair) {
+          [fluid_data_type](OpKernelMap::const_reference kern_pair) {
            return platform::is_gpu_place(kern_pair.first.place_) &&
                   kern_pair.first.library_type_ == LibraryType::kCUDNN &&
-                   kern_pair.first.data_type_ == data_type;
+                   kern_pair.first.data_type_ == fluid_data_type;
          });
     }
   }
@@ -1509,14 +1506,19 @@ bool OperatorWithKernel::SupportsKernelType(
 }
 
 bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
-                                         proto::VarType::Type data_type) const {
+                                         phi::DataType data_type) const {
   return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
          platform::is_cpu_place(ctx.GetPlace()) &&
          this->SupportsMKLDNN(data_type);
 }
 
+bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                                         proto::VarType::Type data_type) const {
+  return this->CanMKLDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
+}
+
 bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
-                                        proto::VarType::Type data_type) const {
+                                        phi::DataType data_type) const {
   bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
                    paddle::platform::is_gpu_place(ctx.GetPlace());
 
@@ -1528,7 +1530,7 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
 #endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
 
 #if defined(PADDLE_WITH_CUDA)
-  if (use_cudnn && data_type == framework::proto::VarType::BF16) {
+  if (use_cudnn && data_type == phi::DataType::BFLOAT16) {
     PADDLE_ENFORCE_GE(
         platform::DnnVersion(),
         8100,
@@ -1540,6 +1542,11 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
   return use_cudnn && this->SupportsCUDNN(data_type);
 }
 
+bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
+                                        proto::VarType::Type data_type) const {
+  return this->CanCUDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
+}
+
 void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
   PADDLE_THROW(platform::errors::PermissionDenied(
       "The default InferShape function of OperatorWithKernel is not allowed to "
@@ -1839,8 +1846,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                                      1,
                                      platform::EventRole::kInnerOp);
     if (need_prepare_data_) {
-      transfer_scope = PrepareData(
-          scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
+      transfer_scope =
+          PrepareData(scope,
+                      framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
+                      &transfered_inplace_vars,
+                      runtime_ctx,
+                      dev_ctx->GetPlace());
     }
   }
   // exec scope is the scope that kernel actually executed on.
@@ -1960,7 +1971,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
 
 OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
     const ExecutionContext& ctx) const {
-  auto expected_kernel_key = this->GetExpectedKernelType(ctx);
+  phi::KernelKey phi_kernel_key = this->GetExpectedKernelType(ctx);
+  auto expected_kernel_key =
+      framework::TransPhiKernelKeyToOpKernelType(phi_kernel_key);
 
   // NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function
   // GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and
@@ -2063,6 +2076,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
       }
     }
   }
+
+  if (platform::places_are_same_class(expected_kernel_key.place_,
+                                      ctx.GetPlace())) {
+    expected_kernel_key.place_ = ctx.GetPlace();
+  }
+
   VLOG(3) << "op type:" << type_
           << ", expected_kernel_key:" << expected_kernel_key;
   return expected_kernel_key;
@@ -2333,9 +2352,10 @@ void OperatorWithKernel::HandleComplexGradToRealGrad(
 
 Scope* OperatorWithKernel::PrepareData(
     const Scope& scope,
-    const OpKernelType& expected_kernel_key,
+    const phi::KernelKey& expected_kernel_key,
     std::vector<std::string>* transfered_inplace_vars,
-    RuntimeContext* ctx) const {
+    RuntimeContext* ctx,
+    const phi::Place& place) const {
   Scope* new_scope = nullptr;
 
   const std::unordered_set<std::string>* no_buffer_ins = nullptr;
@@ -2378,7 +2398,7 @@ Scope* OperatorWithKernel::PrepareData(
         // has to be created and registered
         if ((tensor_in->layout() == DataLayout::ONEDNN) &&
             (var->IsType<phi::DenseTensor>() == true) &&
-            (expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
+            (expected_kernel_key.layout() != DataLayout::ONEDNN) &&
             (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
              DataLayout::kNHWC) &&
             (tensor_in->dims().size() >= 3)) {
@@ -2411,35 +2431,33 @@ Scope* OperatorWithKernel::PrepareData(
       auto kernel_type_for_var =
           GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key);
       bool need_trans_dtype =
-          kernel_type_for_var.data_type_ != expected_kernel_key.data_type_;
+          NeedTransformDataType(expected_kernel_key, kernel_type_for_var);
       bool need_trans_layout = NeedTransformLayout(
-          kernel_type_for_var.data_layout_, expected_kernel_key.data_layout_);
+          kernel_type_for_var.layout(), expected_kernel_key.layout());
       if (!need_trans_dtype && !need_trans_layout) {
         if (!run_phi_kernel_ &&
-            platform::places_are_same_class(kernel_type_for_var.place_,
-                                            expected_kernel_key.place_)) {
+            backends_are_same_class(kernel_type_for_var.backend(),
+                                    expected_kernel_key.backend())) {
           continue;
         }
       }
 
-      std::unique_ptr<OpKernelType> new_expected_kernel_key = nullptr;
+      std::unique_ptr<phi::KernelKey> new_expected_kernel_key = nullptr;
       if (run_phi_kernel_ && in_def != nullptr &&
           in_def->backend != phi::Backend::ALL_BACKEND) {
         auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
         if ((in_def->backend != tensor_backend &&
-             (in_def->backend != phi::Backend::GPUDNN ||
-              tensor_backend != phi::Backend::GPU) &&
-             (in_def->backend != phi::Backend::KPS ||
-              tensor_backend != phi::Backend::XPU) &&
-             (in_def->backend != phi::Backend::ONEDNN ||
-              tensor_backend != phi::Backend::CPU)) ||
+             !(in_def->backend == phi::Backend::GPUDNN &&
+               tensor_backend == phi::Backend::GPU) &&
+             !(in_def->backend == phi::Backend::KPS &&
+               tensor_backend == phi::Backend::XPU) &&
+             !(in_def->backend == phi::Backend::ONEDNN &&
+               tensor_backend == phi::Backend::CPU)) ||
            tensor_in->place().GetType() == AllocationType::GPUPINNED) {
-          new_expected_kernel_key = std::make_unique<OpKernelType>(
-              expected_kernel_key.data_type_,
-              phi::TransToPhiPlace(in_def->backend),
-              expected_kernel_key.data_layout_,
-              expected_kernel_key.library_type_,
-              expected_kernel_key.customized_type_value_);
+          new_expected_kernel_key =
+              std::make_unique<phi::KernelKey>(in_def->backend,
+                                               expected_kernel_key.layout(),
+                                               expected_kernel_key.dtype());
         }
       }
 
@@ -2474,14 +2492,18 @@ Scope* OperatorWithKernel::PrepareData(
       enable_cache_transfer_scope_ = false;
       if (!run_by_executor_) {
         if (new_expected_kernel_key) {
-          if ((platform::is_gpu_place(kernel_type_for_var.place_) ||
-               platform::is_gpu_place(new_expected_kernel_key->place_))) {
+          if (kernel_type_for_var.backend() == phi::Backend::GPU ||
+              kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
+              new_expected_kernel_key->backend() == phi::Backend::GPU ||
+              new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
            new_scope = TryCreateTransferScope(
                kernel_type_for_var, *new_expected_kernel_key, &scope);
            enable_cache_transfer_scope_ = true;
          }
-        } else if ((platform::is_gpu_place(kernel_type_for_var.place_) ||
-                    platform::is_gpu_place(expected_kernel_key.place_))) {
+        } else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
+                   kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
+                   expected_kernel_key.backend() == phi::Backend::GPU ||
+                   expected_kernel_key.backend() == phi::Backend::GPUDNN) {
          new_scope = TryCreateTransferScope(
              kernel_type_for_var, expected_kernel_key, &scope);
          enable_cache_transfer_scope_ = true;
@@ -2523,11 +2545,15 @@ Scope* OperatorWithKernel::PrepareData(
 
       // Do transfer
       phi::DenseTensor out;
-      TransformData(new_expected_kernel_key ? *new_expected_kernel_key
-                                            : expected_kernel_key,
-                    kernel_type_for_var,
-                    *tensor_in,
-                    &out);
+      TransformData(
+          new_expected_kernel_key ? *new_expected_kernel_key
+                                  : expected_kernel_key,
+          kernel_type_for_var,
+          *tensor_in,
+          &out,
+          new_expected_kernel_key
+              ? phi::TransToPhiPlace(new_expected_kernel_key->backend())
+              : place);
       SetTensorToVariable(*var, out, trans_var);
     }
   };
@@ -2818,30 +2844,29 @@ proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes(
   return target_type;
 }
 
-OpKernelType OperatorWithKernel::GetExpectedKernelType(
+phi::KernelKey OperatorWithKernel::GetExpectedKernelType(
     const ExecutionContext& ctx) const {
-  return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
+  return phi::KernelKey(IndicateDataType(ctx), ctx.GetPlace());
 }
 
-OpKernelType OperatorWithKernel::GetKernelTypeForVar(
+phi::KernelKey OperatorWithKernel::GetKernelTypeForVar(
     const std::string& var_name,
     const phi::DenseTensor& tensor,
-    const OpKernelType& expected_kernel_type) const {
+    const phi::KernelKey& expected_kernel_type) const {
 #ifdef PADDLE_WITH_MKLDNN
   // When the op is first oneDNN op (there was some non oneDNN op
   // previously)
   // then we also need to rotate shape NHWC -> NCWH
-  if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+  if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
       (tensor.layout() != phi::DataLayout::ONEDNN) &&
       phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
           phi::DataLayout::kNHWC) {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(),
-                                   phi::DataLayout::kNHWC);
+    return phi::KernelKey(
+        tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
   }
 #endif
-  return OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+  return phi::KernelKey(
+      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
 }
 
 phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
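Review note: with the place removed from the kernel key, PrepareData receives the destination place as an explicit argument, and the default GetKernelTypeForVar builds the per-variable key from the tensor itself (its place and layout) plus the expected dtype. The RunImpl hunk above shows the canonical call:

    // As in RunImpl above: translate the cached fluid OpKernelType and pass
    // the device context's place explicitly.
    transfer_scope =
        PrepareData(scope,
                    framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
                    &transfered_inplace_vars,
                    runtime_ctx,
                    dev_ctx->GetPlace());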
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 07e1a26c7c0abd111f34f2cadb4ed70c8e7e9c54..b4e0c94c20be2d189a03fe73d59bd46447acb1c8 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -638,16 +638,22 @@ class OperatorWithKernel : public OperatorBase {
 
   bool SupportXPU() const override;
 
-  bool SupportsMKLDNN(proto::VarType::Type data_type) const;
+  bool SupportsMKLDNN(phi::DataType data_type) const;
 
-  bool SupportsCUDNN(proto::VarType::Type data_type) const;
+  bool SupportsCUDNN(phi::DataType data_type) const;
 
   bool SupportsKernelType(const OpKernelType& kernel_type,
                           const ExecutionContext& exe_ctx) const;
 
+  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                       phi::DataType data_type) const;
+
   bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                        proto::VarType::Type data_type) const;
 
+  bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
+                      phi::DataType data_type) const;
+
   bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
                       proto::VarType::Type data_type) const;
 
@@ -665,14 +671,15 @@ class OperatorWithKernel : public OperatorBase {
                                  const std::string& name1,
                                  const std::string& name2) const;
 
-  virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
+  virtual phi::KernelKey GetExpectedKernelType(
+      const ExecutionContext& ctx) const;
 
   // change this to public so that in dygraph mode we can call it to check if we
   // need transform data
-  virtual OpKernelType GetKernelTypeForVar(
+  virtual phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const OpKernelType& expected_kernel_type) const;
+      const phi::KernelKey& expected_kernel_type) const;
 
   platform::Place GetExecutionPlace(
       const platform::Place& platform) const override {
@@ -734,9 +741,10 @@ class OperatorWithKernel : public OperatorBase {
    * transfered_inplace_vars is a output vector.
    */
   Scope* PrepareData(const Scope& scope,
-                     const OpKernelType& expected_kernel_key,
+                     const phi::KernelKey& expected_kernel_key,
                      std::vector<std::string>* transfered_inplace_vars,
-                     RuntimeContext* ctx) const;
+                     RuntimeContext* ctx,
+                     const phi::Place& place) const;
 
   void TransferInplaceVarsBack(const Scope& scope,
                                const std::vector<std::string>& inplace_vars,
diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc
index a7b597fdb30766fe8dccb6ab9aba10b9b8ba7c83..1d57efd875f069ce14743adf9feee2bdb4beff40 100644
--- a/paddle/fluid/framework/operator_test.cc
+++ b/paddle/fluid/framework/operator_test.cc
@@ -127,14 +127,10 @@ class OpWithKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    int sub_type = ctx.Attr<int>("kernel_sub_type");
-    return OpKernelType(proto::VarType::FP32,
-                        ctx.GetPlace(),
-                        phi::DataLayout::kAnyLayout,
-                        framework::LibraryType::kPlain,
-                        sub_type);
+    return phi::KernelKey(
+        ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
   }
 };
 
@@ -256,16 +252,6 @@ TEST(OpKernel, all) {
   // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
   ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
-
-  attr = op_desc.mutable_attrs()->Add();
-  attr->set_name("kernel_sub_type");
-  attr->set_type(paddle::framework::proto::AttrType::INT);
-  attr->set_i(1);
-  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
-  op2->Run(scope, cpu_place);
-  // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
-  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
-  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
 }
 
 REGISTER_OP_WITHOUT_GRADIENT(
@@ -339,11 +325,11 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
     auto data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "phi::DenseTensor");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
@@ -361,11 +347,11 @@ class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
     auto data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 class IndicateSelectedRowsDataTypeTestProtoMaker
@@ -383,10 +369,10 @@ class IndicateOtherDataTypeTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
@@ -597,10 +583,10 @@ class OpUnusedVarTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
 
-  OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    return OpKernelType(
-        proto::VarType::FP32, ctx.GetPlace(), phi::DataLayout::kAnyLayout);
+    return phi::KernelKey(
+        ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
   }
 };
diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc
index c812a6dc95a2551c75c696879352f430f75f2c9c..60c2516c0047d1985a2ab6fd918ae6567016a1d9 100644
--- a/paddle/fluid/framework/transfer_scope_cache.cc
+++ b/paddle/fluid/framework/transfer_scope_cache.cc
@@ -34,12 +34,13 @@ global_transfer_scope_key() {
   return *x;
 }
 
-Scope* TryCreateTransferScope(OpKernelType type0,
-                              OpKernelType type1,
+Scope* TryCreateTransferScope(const phi::KernelKey& type0,
+                              const phi::KernelKey& type1,
                               const Scope* scope) {
   Scope* new_scope{nullptr};
   size_t infer_cache_key =
-      CombineHash(OpKernelType::Hash()(type0), OpKernelType::Hash()(type1));
+      CombineHash(static_cast<size_t>(phi::KernelKey::Hash()(type0)),
+                  static_cast<size_t>(phi::KernelKey::Hash()(type1)));
   infer_cache_key =
       CombineHash(infer_cache_key, std::hash<const Scope*>()(scope));
diff --git a/paddle/fluid/framework/transfer_scope_cache.h b/paddle/fluid/framework/transfer_scope_cache.h
index da2e319d5ba5e41280d3f265217a872bdf992913..58707f501a70ec5a50a236e92c2c4c5f2da1b2ca 100644
--- a/paddle/fluid/framework/transfer_scope_cache.h
+++ b/paddle/fluid/framework/transfer_scope_cache.h
@@ -39,8 +39,8 @@ static size_t CombineHash(size_t seed, size_t a) {
   return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 }
 
-Scope* TryCreateTransferScope(OpKernelType type0,
-                              OpKernelType type1,
+Scope* TryCreateTransferScope(const phi::KernelKey& type0,
+                              const phi::KernelKey& type1,
                               const Scope* scope);
 
 }  // namespace framework
diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index cc3ed77c391d87be4fbb0dc3bccf34a7d0f35bac..43c83a7237757556de6ba5c0461dd69d457bae8e 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -23,7 +23,8 @@ if(WITH_XPU)
       scalar
       int_array
       var_helper
-      profiler)
+      profiler
+      place)
 else()
   cc_library(
     prepared_operator
@@ -40,7 +41,8 @@ else()
       scalar
       int_array
       var_helper
-      profiler)
+      profiler
+      place)
 endif()
 cc_library(
   layer
diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h
index d3f163da2a8e341eaf69352e6fb391f5a969df1e..6f1f54de8a95c348a5c3294352cdf4f2c23e909a 100644
--- a/paddle/fluid/imperative/infer_shape_context.h
+++ b/paddle/fluid/imperative/infer_shape_context.h
@@ -24,6 +24,7 @@
 #include "paddle/fluid/imperative/var_helper.h"
 #include "paddle/fluid/imperative/variable_wrapper.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/kernel_factory.h"
 
 namespace paddle {
 namespace imperative {
@@ -39,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
       const framework::AttributeMap* attr,
       const framework::AttributeMap* default_attr,
       const std::string op_type,
-      const framework::OpKernelType* op_kernel_type = nullptr,
+      const phi::KernelKey* op_kernel_key = nullptr,
      const phi::ArgumentMappingFn* arg_map_fn = nullptr,
      const phi::KernelSignature* default_kernel_signature = nullptr)
       : var_map_in_(in),
@@ -47,7 +48,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
         attrs_(attr),
         default_attrs_(default_attr),
         op_type_(op_type),
-        op_kernel_type_(op_kernel_type),
+        op_kernel_key_(op_kernel_key),
         arg_map_fn_(arg_map_fn),
         default_kernel_signature_(default_kernel_signature) {}
 
@@ -250,8 +251,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   bool IsRuntime() const override { return true; }
 
   bool IsRunMKLDNNKernel() const override {
-    return (op_kernel_type_ &&
-            (op_kernel_type_->data_layout_ == phi::DataLayout::ONEDNN));
+    return (op_kernel_key_ &&
+            (op_kernel_key_->layout() == phi::DataLayout::ONEDNN));
   }
 
   paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
@@ -497,7 +498,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   const framework::AttributeMap* attrs_;
   const framework::AttributeMap* default_attrs_;
   const std::string op_type_;
-  const framework::OpKernelType* op_kernel_type_;
+  const phi::KernelKey* op_kernel_key_;
   // arg_map_fn_ and default_kernel_signature_ may be nullptr
   const phi::ArgumentMappingFn* arg_map_fn_;
   const phi::KernelSignature* default_kernel_signature_;
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 89398c5246d54779ed5e91ca88482aeb50de13ec..2ac43c39d72abb590d9d712a46230488fec8eeb1 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -519,8 +519,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
    */
   auto prepared_op =
       PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
-  auto tmp_ins_ptr =
-      PrepareData(*op_kernel, ins, prepared_op.kernel_type());
+  auto tmp_ins_ptr = PrepareData(
+      *op_kernel, ins, prepared_op.kernel_key(), prepared_op.place());
   if (tmp_ins_ptr == nullptr) {
     prepared_op.Run(ins, outs, attrs, default_attrs);
   } else {
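Review note: the imperative side now threads phi::KernelKey through PreparedOp (the kernel_type() accessor becomes kernel_key(), per the layer.cc hunk above), and DygraphInferShapeContext detects oneDNN execution via key.layout() == DataLayout::ONEDNN rather than data_layout_. Condensed restatement of the updated call site:

    auto prepared_op =
        PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
    // PrepareData now also needs the prepared place, since the key alone
    // no longer pins one down.
    auto tmp_ins_ptr = PrepareData(
        *op_kernel, ins, prepared_op.kernel_key(), prepared_op.place());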
--git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 5eb045a0c522396e42450ca14ece251f626292e3..32a4515624570542442ead65f35c5b0722fb7c0c 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -30,6 +30,7 @@ #endif #include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h" @@ -116,14 +117,14 @@ void TestHandleComplexGradToRealGradEager( PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), - kernel_type_(kernel_type), + kernel_key_(kernel_key), func_(func), dev_ctx_(dev_ctx), arg_map_fn_(arg_map_fn), @@ -132,7 +133,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, phi::KernelSignature&& kernel_signature, @@ -140,7 +141,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), - kernel_type_(kernel_type), + kernel_key_(kernel_key), func_(nullptr), dev_ctx_(dev_ctx), run_phi_kernel_(true), @@ -228,7 +229,6 @@ PreparedOp PrepareImpl( const phi::KernelSignature* default_kernel_signature = nullptr; phi::KernelSignature kernel_signature; - phi::KernelKey phi_kernel_key; std::string phi_kernel_name; // NOTE(jiahongyu): The registered MKLDNN kernel have library_type = @@ -240,29 +240,27 @@ PreparedOp PrepareImpl( // 3. Whether mkldnn kernel can be used. 
#ifdef PADDLE_WITH_MKLDNN if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) && - op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) { - expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN; - expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN; + op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) { + expected_kernel_key.set_backend(phi::Backend::ONEDNN); + expected_kernel_key.set_layout(phi::DataLayout::ONEDNN); } #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) { - expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN; + if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) { + expected_kernel_key.set_backend(phi::Backend::GPUDNN); } #endif #if defined(PADDLE_WITH_XPU) - bool is_xpu_unsupport = - paddle::platform::is_xpu_place(expected_kernel_key.place_) && - !paddle::platform::is_xpu_support_op( - op.Type(), - framework::TransToPhiDataType(expected_kernel_key.data_type_)); + bool is_xpu_unsupport = expected_kernel_key.backend() == phi::Backend::XPU && + !paddle::platform::is_xpu_support_op( + op.Type(), expected_kernel_key.dtype()); #endif #ifdef PADDLE_WITH_MLU if (is_in_mlu_black_list(op.Type())) { - expected_kernel_key.place_ = platform::CPUPlace(); + expected_kernel_key.set_backend(phi::Backend::CPU); } #endif @@ -290,12 +288,10 @@ PreparedOp PrepareImpl( // But the default library_type is Plain, so we need to modify the // library_type here, otherwise it can't work. #ifdef PADDLE_WITH_XPU_KP - if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + if (expected_kernel_key.backend() == phi::Backend::XPU) { bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( - op.Type(), - framework::TransToPhiDataType(expected_kernel_key.data_type_)); + FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op( + op.Type(), expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(op.Type()); if (use_xpu_kp_kernel_rt) { @@ -307,17 +303,14 @@ PreparedOp PrepareImpl( bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { - auto expected_kernel_key_library_type = - expected_kernel_key.library_type_; - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + auto expected_kernel_key_backend = expected_kernel_key.backend(); + expected_kernel_key.set_backend(phi::Backend::KPS); VLOG(3) << "modifing XPU KP kernel: " << phi_kernel_name << ", using_kernel_key:" << expected_kernel_key; - phi::KernelKey try_phi_kernel_key = - TransOpKernelTypeToPhiKernelKey(expected_kernel_key); if (!phi_kernel_factory.HasKernel(phi_kernel_name, - try_phi_kernel_key)) { - expected_kernel_key.library_type_ = expected_kernel_key_library_type; + expected_kernel_key)) { + expected_kernel_key.set_backend(expected_kernel_key_backend); VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name << " in dynamic graph is failed " << expected_kernel_key; } else { @@ -328,9 +321,8 @@ PreparedOp PrepareImpl( } #endif - phi_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); auto& phi_kernel = - phi_kernel_factory.SelectKernel(phi_kernel_name, phi_kernel_key); + phi_kernel_factory.SelectKernel(phi_kernel_name, expected_kernel_key); if (phi_kernel.IsValid() #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) @@ -338,13 +330,14 @@ PreparedOp PrepareImpl( #endif ) { 
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name - << " | kernel key: " << phi_kernel_key + << " | kernel key: " << expected_kernel_key << " | kernel: " << phi_kernel; - if (expected_kernel_key.place_ != place) { - dev_ctx = pool.Get(expected_kernel_key.place_); + if (!framework::backends_are_same_class( + expected_kernel_key.backend(), + phi::TransToPhiBackend(dev_ctx->GetPlace()))) { + dev_ctx = pool.Get(phi::TransToPhiPlace(expected_kernel_key.backend())); } - return PreparedOp(op, empty_ctx, expected_kernel_key, @@ -368,22 +361,23 @@ PreparedOp PrepareImpl( // registered in KP use library_type[KP], we need to modify it. #ifdef PADDLE_WITH_XPU_KP bool use_xpu_kp_kernel_rt = - paddle::platform::is_xpu_place(expected_kernel_key.place_) && + expected_kernel_key.backend() == phi::Backend::XPU && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( - op.Type(), - framework::TransToPhiDataType(expected_kernel_key.data_type_)); + paddle::platform::is_xpu_support_op(op.Type(), + expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = - paddle::platform::is_xpu_place(expected_kernel_key.place_) && + expected_kernel_key.backend() == phi::Backend::XPU && paddle::platform::is_in_xpu_kpwhite_list(op.Type()); bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + expected_kernel_key.set_backend(phi::Backend::KPS); } #endif + paddle::framework::OpKernelType fluid_kernel_type = + paddle::framework::TransPhiKernelKeyToOpKernelType(expected_kernel_key); if ((kernels_iter == all_op_kernels.end() || - kernels_iter->second.find(expected_kernel_key) == + kernels_iter->second.find(fluid_kernel_type) == kernels_iter->second.end()) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || is_xpu_unsupport @@ -393,7 +387,7 @@ PreparedOp PrepareImpl( #endif ) { if (has_phi_kernel) { - auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op); + auto phi_cpu_kernel_key = FallBackToCpu(expected_kernel_key, op); auto& phi_cpu_kernel = phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key); if (phi_cpu_kernel.IsValid()) { @@ -401,15 +395,14 @@ PreparedOp PrepareImpl( << " | kernel key: " << phi_cpu_kernel_key << " | kernel: " << phi_cpu_kernel; auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp( - op, - empty_ctx, - framework::TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key), - arg_map_fn, - default_kernel_signature, - std::move(kernel_signature), - phi_cpu_kernel, - cpu_ctx); + return PreparedOp(op, + empty_ctx, + phi_cpu_kernel_key, + arg_map_fn, + default_kernel_signature, + std::move(kernel_signature), + phi_cpu_kernel, + cpu_ctx); } } } @@ -422,21 +415,21 @@ PreparedOp PrepareImpl( op.Type())); auto& kernels = kernels_iter->second; - auto kernel_iter = kernels.find(expected_kernel_key); + auto kernel_iter = kernels.find(fluid_kernel_type); #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) - if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && + if (paddle::platform::is_xpu_place(fluid_kernel_type.place_) && (kernel_iter == kernels.end() || is_xpu_unsupport)) { VLOG(3) << "fluid missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", fallbacking to CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = 
platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_XPU_KP - if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + if (paddle::platform::is_xpu_place(fluid_kernel_type.place_)) { if (use_xpu_kp_kernel_rt) { VLOG(3) << "fluid xpu_kp using rt mode "; } @@ -444,60 +437,60 @@ PreparedOp PrepareImpl( VLOG(3) << "fluid xpu_kp using debug mode "; } if (is_xpu_kp_support) { - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.library_type_ = paddle::framework::LibraryType::kKP; + kernel_iter = kernels.find(fluid_kernel_type); VLOG(3) << "using fluid XPU KP kernel: " << op.Type() - << ", using_kernel_key:" << expected_kernel_key; + << ", using_kernel_key:" << fluid_kernel_type; } if (!is_xpu_kp_support && (kernel_iter == kernels.end() || is_xpu_unsupport)) { VLOG(3) << "fluid missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", falling back to the CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } } #endif #ifdef PADDLE_WITH_ASCEND_CL if (kernel_iter == kernels.end() && - paddle::platform::is_npu_place(expected_kernel_key.place_)) { VLOG(3) << "missing NPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", falling back to the CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_IPU if (kernel_iter == kernels.end() && - paddle::platform::is_ipu_place(expected_kernel_key.place_)) { + paddle::platform::is_ipu_place(fluid_kernel_type.place_)) { VLOG(3) << "missing IPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", falling back to the CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_MLU if (kernel_iter == kernels.end() && - paddle::platform::is_mlu_place(expected_kernel_key.place_)) { + paddle::platform::is_mlu_place(fluid_kernel_type.place_)) { VLOG(3) << "missing MLU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << fluid_kernel_type << ", falling back to the CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE if (kernel_iter == kernels.end() && - paddle::platform::is_custom_place(fluid_kernel_type.place_)) { VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type() << ", expected_kernel_key:" << expected_kernel_key << ", falling back to the CPU one!"; - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); + 
fluid_kernel_type.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(fluid_kernel_type); } #endif // TODO(jiabin): Add operator.cc's line 1000 part back when we need that @@ -507,19 +500,20 @@ PreparedOp PrepareImpl( kernels.end(), platform::errors::NotFound("Operator %s does not have kernel for %s.", op.Type(), - KernelTypeToString(expected_kernel_key))); - - if (!(expected_kernel_key.place_ == place)) { - dev_ctx = pool.Get(expected_kernel_key.place_); - } - - return PreparedOp(op, - empty_ctx, - expected_kernel_key, - kernel_iter->second, - arg_map_fn, - default_kernel_signature, - dev_ctx); + KernelTypeToString(fluid_kernel_type))); + + if (!platform::places_are_same_class(fluid_kernel_type.place_, + dev_ctx->GetPlace())) { + dev_ctx = pool.Get(fluid_kernel_type.place_); + } + return PreparedOp( + op, + empty_ctx, + framework::TransOpKernelTypeToPhiKernelKey(fluid_kernel_type), + kernel_iter->second, + arg_map_fn, + default_kernel_signature, + dev_ctx); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, @@ -576,7 +570,7 @@ template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, @@ -597,7 +591,7 @@ static void PreparedOpRunImpl( &attrs, &default_attrs, op.Type(), - &kernel_type, + &kernel_key, arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); @@ -641,7 +635,7 @@ static void PreparedOpRunImpl( * grad op kernel executed, we need to recognize this situation and * convert dx to float32 type. HandleComplexGradToRealGrad does this thing. 
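
A self-contained illustration of that conversion, on std::vector rather than DenseTensor; ComplexGradToReal is a hypothetical helper, not Paddle's API, and the complex64-to-float32 narrowing is assumed here to keep only the real component, which is how the gradient becomes real-valued again:

#include <complex>
#include <cstdio>
#include <vector>

// Model of the idea behind HandleComplexGradToRealGrad: a gradient that was
// computed in a complex dtype is narrowed back to float by dropping the
// imaginary part.
std::vector<float> ComplexGradToReal(
    const std::vector<std::complex<float>>& dx) {
  std::vector<float> out;
  out.reserve(dx.size());
  for (const std::complex<float>& v : dx) {
    out.push_back(v.real());
  }
  return out;
}

int main() {
  const std::vector<std::complex<float>> dx = {{1.5f, 2.0f}, {-0.5f, 3.0f}};
  for (float g : ComplexGradToReal(dx)) {
    std::printf("%g ", g);  // prints: 1.5 -0.5
  }
  std::printf("\n");
  return 0;
}
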
*/ - if (framework::IsComplexType(kernel_type.data_type_)) { + if (framework::IsComplexType(kernel_key.dtype())) { HandleComplexGradToRealGrad(outs); } } @@ -649,7 +643,7 @@ static void PreparedOpRunImpl( template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature& kernel_signature, @@ -669,7 +663,7 @@ static void PreparedOpRunPtImpl( &attrs, &default_attrs, op.Type(), - &kernel_type, + &kernel_key, arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); @@ -712,7 +706,7 @@ static void PreparedOpRunPtImpl( #endif } - if (framework::IsComplexType(kernel_type.data_type_)) { + if (framework::IsComplexType(kernel_key.dtype())) { HandleComplexGradToRealGrad(outs); } } @@ -723,7 +717,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -736,7 +730,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, @@ -754,7 +748,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -767,7 +761,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, @@ -785,7 +779,7 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl(op_, - kernel_type_, + kernel_key_, arg_map_fn_, default_kernel_signature_, kernel_signature_, @@ -798,7 +792,7 @@ void PreparedOp::Run(const NameVarMap& ins, } else { PreparedOpRunImpl(op_, ctx_, - kernel_type_, + kernel_key_, func_, arg_map_fn_, default_kernel_signature_, diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index a3d90939faeaed9450218fcd4f77757a0c252c6a..fb36a03e01890849b34faeca06f2368901c6d450 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -29,6 +29,7 @@ #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/var_helper.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" @@ -75,7 +76,8 @@ template std::shared_ptr> PrepareData( const framework::OperatorWithKernel& op, const NameVarMap& ins, - const framework::OpKernelType& expected_kernel_key) { + const phi::KernelKey& expected_kernel_key, + const phi::Place& place) { std::shared_ptr> tmp_ins_ptr = nullptr; for (const auto& name_pair : ins) { for (size_t i = 0; i < name_pair.second.size(); ++i) { @@ -85,7 +87,8 @@ std::shared_ptr> PrepareData( if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) { auto kernel_type_for_var = op.GetKernelTypeForVar( name_pair.first, *tensor, expected_kernel_key); - if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { + if (!framework::NeedTransform(kernel_type_for_var, + 
expected_kernel_key)) { continue; } else { VLOG(3) << "Transform Variable " << GetNameFromVar(template_var) @@ -111,10 +114,10 @@ std::shared_ptr> PrepareData( (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; } else { phi::DenseTensor out; - TransformData( - expected_kernel_key, kernel_type_for_var, *tensor, &out); - if (NeedTransformDataType(kernel_type_for_var, - expected_kernel_key)) { + framework::TransformData( + expected_kernel_key, kernel_type_for_var, *tensor, &out, place); + if (framework::NeedTransformDataType(kernel_type_for_var, + expected_kernel_key)) { // To avoid NameVarMap copy construction overhead in general // scenarios, if inplace transformed, return original input // directly @@ -149,7 +152,7 @@ class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const framework::OperatorWithKernel::OpKernelFunc& func, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, @@ -157,7 +160,7 @@ class PreparedOp { PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, - const framework::OpKernelType& kernel_type, + const phi::KernelKey& kernel_key, const phi::ArgumentMappingFn* arg_map_fn, const phi::KernelSignature* default_kernel_signature, phi::KernelSignature&& kernel_signature, @@ -200,12 +203,14 @@ class PreparedOp { const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); - const framework::OpKernelType& kernel_type() const { return kernel_type_; } + const phi::KernelKey& kernel_key() const { return kernel_key_; } + + const phi::Place& place() const { return dev_ctx_->GetPlace(); } private: const framework::OperatorBase& op_; const framework::RuntimeContext& ctx_; - framework::OpKernelType kernel_type_; + phi::KernelKey kernel_key_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; // NOTE(chenweihang): Similar op members are used to adapt to diff --git a/paddle/fluid/imperative/tests/test_eager.cc b/paddle/fluid/imperative/tests/test_eager.cc index 3eec90462d293f5f35f93bb32b7b0a7990fd5d96..6c27dead2709da5b4e3f834abd8d1049ef4eec77 100644 --- a/paddle/fluid/imperative/tests/test_eager.cc +++ b/paddle/fluid/imperative/tests/test_eager.cc @@ -92,15 +92,15 @@ TEST(test_var_helper, eager_var_helper) { ASSERT_TRUE(platform::is_cpu_place(GetPlace(egr_tensor))); ASSERT_TRUE(GetDataType(egr_tensor) == framework::proto::VarType::FP32); - GetCachedValue( - egr_tensor, - framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace())); - SetCachedValue( - egr_tensor, - framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()), - egr_tensor2); + GetCachedValue(egr_tensor, + phi::KernelKey(phi::Backend::CPU, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32)); + SetCachedValue(egr_tensor, + phi::KernelKey(phi::Backend::CPU, + phi::DataLayout::ALL_LAYOUT, + phi::DataType::FLOAT32), + egr_tensor2); ASSERT_ANY_THROW(GetPlace(egr_tensor2)); ASSERT_ANY_THROW(SetType( egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY)); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index 613b72919714cdb9b4323465962887cf46b2a51b..76510b39ce85d43edde3cddc91456dc9f0d31f57 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -172,7 +172,8 @@ TEST(test_prepare_op, 
test_prepare_data) { PrepareData( dynamic_cast(*op), ins, - prepared_op.kernel_type()); + prepared_op.kernel_key(), + gpu_place); for (const auto& name_pair : ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( @@ -229,7 +230,8 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { PrepareData( dynamic_cast(*op), ins, - prepared_op.kernel_type()); + prepared_op.kernel_key(), + cpu_place); for (const auto& name_pair : ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( diff --git a/paddle/fluid/imperative/var_helper.cc b/paddle/fluid/imperative/var_helper.cc index b5f1c8f1fdd54f6ab7f932a0d02fec7e2ac6f8a4..bafea5a720d3a71be940d397eed1204b78d96f0d 100644 --- a/paddle/fluid/imperative/var_helper.cc +++ b/paddle/fluid/imperative/var_helper.cc @@ -239,35 +239,31 @@ template void SetDataLayout( /* CheckCachedKey */ template -bool CheckCachedKey(std::shared_ptr var, - const paddle::framework::OpKernelType &key) { +bool CheckCachedKey(std::shared_ptr var, const phi::KernelKey &key) { return GetVariableWrapper(var)->hasCacheKey(key); } template <> bool CheckCachedKey( - std::shared_ptr tensor, - const paddle::framework::OpKernelType &key) { + std::shared_ptr tensor, const phi::KernelKey &key) { // TODO(jiabin): Support this later // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is // equal to self: " << key == key. return false; } -template bool CheckCachedKey( - std::shared_ptr var, const paddle::framework::OpKernelType &key); +template bool CheckCachedKey(std::shared_ptr var, + const phi::KernelKey &key); template bool CheckCachedKey( - std::shared_ptr var, - const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); /* GetCachedValue */ template -std::shared_ptr GetCachedValue( - std::shared_ptr var, const paddle::framework::OpKernelType &key) { +std::shared_ptr GetCachedValue(std::shared_ptr var, + const phi::KernelKey &key) { return GetVariableWrapper(var)->getCacheValue(key); } template <> std::shared_ptr GetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key) { + std::shared_ptr var, const phi::KernelKey &key) { // TODO(jiabin): Support this later // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this @@ -277,22 +273,21 @@ std::shared_ptr GetCachedValue( return std::make_shared(""); } template std::shared_ptr GetCachedValue( - std::shared_ptr var, const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); template std::shared_ptr GetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key); + std::shared_ptr var, const phi::KernelKey &key); /* SetCachedValue */ template void SetCachedValue(std::shared_ptr var, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res) { GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)); } template <> void SetCachedValue( std::shared_ptr tensor, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res) { // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this @@ -300,13 +295,12 @@ void SetCachedValue( // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key // is equal to self: " << key == key << " and res name is:" << res->Name(). 
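
The var_helper changes in this file re-key the per-variable cache from framework::OpKernelType to phi::KernelKey. Since VariableWrapper keeps that cache in a std::map, all the key type has to provide is a strict weak ordering; a minimal sketch of the same pattern with stand-in types (Key and Var here are hypothetical, not Paddle's classes):

#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <tuple>

// Stand-in for phi::KernelKey: three small fields plus operator< so the
// struct can serve as a std::map key.
struct Key {
  int backend;
  int layout;
  int dtype;
  bool operator<(const Key& o) const {
    return std::tie(backend, layout, dtype) <
           std::tie(o.backend, o.layout, o.dtype);
  }
};

// Stand-in for the cached VariableWrapper.
struct Var {
  std::string name;
};

int main() {
  std::map<Key, std::shared_ptr<Var>> var_cache;  // per-wrapper cache
  const Key fp32_cpu{0, 0, 0};
  var_cache[fp32_cpu] = std::make_shared<Var>(Var{"x_fp32"});  // setCacheValue
  auto it = var_cache.find(fp32_cpu);                          // hasCacheKey
  if (it != var_cache.end()) {                                 // getCacheValue
    std::printf("cached: %s\n", it->second->name.c_str());
  }
  return 0;
}
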
} -template void SetCachedValue( - std::shared_ptr var, - const paddle::framework::OpKernelType &key, - std::shared_ptr res); +template void SetCachedValue(std::shared_ptr var, + const phi::KernelKey &key, + std::shared_ptr res); template void SetCachedValue( std::shared_ptr var, - const paddle::framework::OpKernelType &key, + const phi::KernelKey &key, std::shared_ptr res); } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/var_helper.h b/paddle/fluid/imperative/var_helper.h index 5e96e865482cab22db46b259183083e8f17d51c4..ebf3e49c51870fcb48096068b980852b43d7d986 100644 --- a/paddle/fluid/imperative/var_helper.h +++ b/paddle/fluid/imperative/var_helper.h @@ -43,16 +43,14 @@ template const std::string& GetNameFromVar(std::shared_ptr var); template -bool CheckCachedKey(std::shared_ptr tensor, - const paddle::framework::OpKernelType& key); +bool CheckCachedKey(std::shared_ptr tensor, const phi::KernelKey& key); template void SetCachedValue(std::shared_ptr tensor, - const paddle::framework::OpKernelType& key, + const phi::KernelKey& key, std::shared_ptr res); template -std::shared_ptr GetCachedValue( - std::shared_ptr tensor, - const paddle::framework::OpKernelType& key); +std::shared_ptr GetCachedValue(std::shared_ptr tensor, + const phi::KernelKey& key); template void SetType(std::shared_ptr var, diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index c1024b6d58aec6a6811f73fb37b2aa106cf7f8a4..d4438e8b47b970c8c39f5943a617215b243719a9 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -234,16 +234,15 @@ class VariableWrapper { } } - bool hasCacheKey(const paddle::framework::OpKernelType& key) { + bool hasCacheKey(const phi::KernelKey& key) { return var_cache.find(key) != var_cache.end(); } - std::shared_ptr getCacheValue( - const paddle::framework::OpKernelType& key) { + std::shared_ptr getCacheValue(const phi::KernelKey& key) { return var_cache[key]; } - void setCacheValue(const paddle::framework::OpKernelType& key, + void setCacheValue(const phi::KernelKey& key, std::shared_ptr val) { var_cache[key] = val; return; @@ -323,8 +322,7 @@ class VariableWrapper { // Used to cache the dtype-promoted VariableWrapper in real and complex // compute of Paddle Quantum - std::map> - var_cache; + std::map> var_cache; // add this property because users may set stop_gradient themselves and this // should override the framework's setting (-1) unset, (1) true, (0) false int overrided_stop_gradient_{-1}; diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc index 3310bdbbe8254955bdeb1a3f4b6fdfc5c32748aa..0bf78f41d64691176a1c02d2c3b6517d8821fa04 100644 --- a/paddle/fluid/operators/abs_op.cc +++ b/paddle/fluid/operators/abs_op.cc @@ -29,11 +29,11 @@ class AbsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -70,11 +70,11 @@ class AbsGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const 
override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -124,20 +124,17 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } }; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 53cd5c92cda3c1e08add2d296524dea50101687e..649382ffc92f77a9feda1452cf6fdc896424a80c 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -80,9 +80,9 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker { } }; -framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, - const framework::OperatorWithKernel& oper, - const std::string& name) { +phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel& oper, + const std::string& name) { auto data_type = oper.IndicateVarDataType(ctx, name); // FIXME(liuwei1031) temporarily disable the code to unblock users // TODO(liuwei1031) figure out the reason behind @@ -94,7 +94,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, // library = framework::LibraryType::kCUDNN; // } // #endif - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } class ActivationOp : public framework::OperatorWithKernel { @@ -107,7 +107,7 @@ class ActivationOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } @@ -134,7 +134,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, framework::GradVarName("Out")); } @@ -341,7 +341,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -370,7 +370,7 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -411,7 +411,7 @@ class ActivationOpTripleGrad : public 
framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "DDX"); } @@ -487,20 +487,22 @@ class PowOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "FactorTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -515,20 +517,22 @@ class PowOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, framework::GradVarName("Out")); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "FactorTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -537,7 +541,7 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } @@ -548,7 +552,7 @@ class PowOpTripleGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index cd4a9fbdb332cceaf140eebd2a7ea8f89c85dbc3..0f52362c21e911d50339c94255cb0a8a5bd4a99d 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -34,11 +34,10 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return 
phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -54,11 +53,11 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 408d1c565e5d125d5f98f56cb20f54d5b131884e..90d8c8b0ce12fb9f355f46b2ccfd7e9707cd95f7 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -145,11 +145,11 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index b23d3670d5e80885a11a2759011c93e0eff552a6..a0cb5480d51b1831dd98439e586ac05e7faab2e0 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -130,10 +130,10 @@ class AffineGridOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -241,11 +241,11 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Output")); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc index fa6bc1d6f77574de4decc1fd055e2b9d03a08839..ab876921f9c8578de207ef45b4fc7a135c912923 100644 --- a/paddle/fluid/operators/allclose_op.cc +++ b/paddle/fluid/operators/allclose_op.cc @@ -65,11 +65,10 @@ class AllcloseOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/alloc_float_status_op.cc b/paddle/fluid/operators/amp/alloc_float_status_op.cc index 
fc96dd52e54a2b54ffa0eb7c0fd76cda0a2dc3d6..24e960867716ef0aac04b26c66b24bdfd0098808 100644 --- a/paddle/fluid/operators/amp/alloc_float_status_op.cc +++ b/paddle/fluid/operators/amp/alloc_float_status_op.cc @@ -34,10 +34,9 @@ class AllocFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc index a8d1f36f1159d40d3f49e3c083202f425384b596..c8faf2d6553e3318716f465718ec8faba9b9d157 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc @@ -29,13 +29,13 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; if (ctx.MultiInputVar("X").size() >= 1) { dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); } - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/clear_float_status_op.cc b/paddle/fluid/operators/amp/clear_float_status_op.cc index 7bfc2d34d296e1135fb225c0a31280bdf1d9b607..06e4b986fa7f860e098113ff00a1cf1bb9be5f90 100644 --- a/paddle/fluid/operators/amp/clear_float_status_op.cc +++ b/paddle/fluid/operators/amp/clear_float_status_op.cc @@ -34,10 +34,9 @@ class ClearFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/get_float_status_op.cc b/paddle/fluid/operators/amp/get_float_status_op.cc index 88a2affbcaabaff1633859242129b372bcaa5732..d5a924b8d842c3f0ce7170ae16496a205e831556 100644 --- a/paddle/fluid/operators/amp/get_float_status_op.cc +++ b/paddle/fluid/operators/amp/get_float_status_op.cc @@ -34,10 +34,9 @@ class GetFloatStatusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc index f8ccac27c19c9082ec3ddb7df85ae22054c20ee3..7f9b7da62f4d4ecd985a98ca1da965374908609f 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cc @@ -29,23 +29,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; if (ctx.MultiInputVar("X").size() >= 1) { dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); } - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifndef PADDLE_WITH_XPU if (var_name == "FoundInfinite" || var_name == "StopUpdate") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } #endif return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 0e44fd2fa27bbe3ab9c999e250d5ba98acd6c69e..090fdff31c1dacdbc780295f9cf328ffaa83f1c8 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -32,11 +32,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/ascend_trigger_op.cc b/paddle/fluid/operators/ascend_trigger_op.cc index abb39dce7a20a51bedc03b829f84dd20c421cc81..b312f97d3f93d1d8147950a64362297bb3b18cae 100644 --- a/paddle/fluid/operators/ascend_trigger_op.cc +++ b/paddle/fluid/operators/ascend_trigger_op.cc @@ -23,10 +23,10 @@ class AscendTriggerOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 1af424fa77dbe63949621fa21796bba09c804ac7..244b3aec9c904a5c13267161fdc4fc109dc0d47b 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -41,16 +41,16 @@ class AssignOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext &ctx) const override { const framework::Variable *var = ctx.InputVar("X"); if (var->IsType()) { @@ -58,14 +58,13 @@ class AssignOp : public framework::OperatorWithKernel { // NOTE(liym27): Support an empty tensor array as Input. // And set the kernel type to float. if (t_arr.size() == 0) { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_pos_op.cc b/paddle/fluid/operators/assign_pos_op.cc index 80412c7d6786ae0dca6b1d084bc9e6538d1e9c4b..24fc4adc60d94f6c0ce11e080dfb2501c5e4daaf 100644 --- a/paddle/fluid/operators/assign_pos_op.cc +++ b/paddle/fluid/operators/assign_pos_op.cc @@ -31,7 +31,7 @@ class AssignPosOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto cum_count_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "cum_count"); @@ -46,7 +46,7 @@ class AssignPosOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "The dtype of the cum_count_dtype, eff_num_len and " "X should be the same as int64")); - return framework::OpKernelType(cum_count_dtype, ctx.device_context()); + return phi::KernelKey(cum_count_dtype, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/assign_value_op.cc b/paddle/fluid/operators/assign_value_op.cc index b9806c7a696cc44c3397c8a33e77bd6575dc8e00..766e55b03168a7b15169261fdf115668a1ac265b 100644 --- a/paddle/fluid/operators/assign_value_op.cc +++ b/paddle/fluid/operators/assign_value_op.cc @@ -44,9 +44,9 @@ class AssignValueOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 330b13ab8b2901d1c9c655168c75e33a90e0f1a1..c4617138553917f7644006fb27cfcd7ddb00111b 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -198,10 +198,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Cell"); } -framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( +phi::KernelKey AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/attention_lstm_op.h b/paddle/fluid/operators/attention_lstm_op.h index 0ce83be93c6cce918c7185c2febc18566faada8b..391afc459f90f637c9dc7be30d76a33ab772f99a 100644 --- a/paddle/fluid/operators/attention_lstm_op.h +++ b/paddle/fluid/operators/attention_lstm_op.h @@ -25,7 
+25,7 @@ class AttentionLSTMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 9f8f295c249353e7645f0ed9dd3daf2aa6510662..a59b78c3cd44b595fd3729e04d3ac037629491a7 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -26,10 +26,10 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "param"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 38504e3ecdf189ac5261e2697bc352636b968619..9010cadd1533238d0e401be425330c3170595a4e 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -77,11 +77,10 @@ class BatchFCOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -106,11 +105,11 @@ class BatchFCGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 32cb10ec89031dfb47b27d1962c26f68966a1d36..21a06e5257acd4c3357b59dc3534540e521e88fa 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -171,7 +171,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { } } -framework::OpKernelType BatchNormOp::GetExpectedKernelType( +phi::KernelKey BatchNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -202,18 +202,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( platform::errors::InvalidArgument( "Variance input should be of float type")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } -framework::OpKernelType BatchNormOp::GetKernelTypeForVar( +phi::KernelKey BatchNormOp::GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const 
framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if ((var_name == "X") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -222,13 +222,12 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void BatchNormOpMaker::Make() { @@ -373,7 +372,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { } } -framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( +phi::KernelKey BatchNormGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -392,18 +391,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( +phi::KernelKey BatchNormGradOp::GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const phi::KernelKey &expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -412,13 +411,12 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
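
A compact stand-alone model of the layout decision made in this branch, with a stand-in Layout enum and a hypothetical ChooseVarLayout helper (the real code wraps the chosen layout into the phi::KernelKey it returns):

#include <cstdio>
#include <string>

enum class Layout { kAnyLayout, kNCHW, kNHWC, kONEDNN };

// When the kernel expects a oneDNN tensor but the input is still in a plain
// layout, reinterpret the input according to the op's data_format attribute,
// unless that attribute was left as "AnyLayout".
Layout ChooseVarLayout(Layout expected,
                       Layout tensor_layout,
                       const std::string& data_format) {
  if (expected == Layout::kONEDNN && tensor_layout != Layout::kONEDNN) {
    Layout dl = Layout::kAnyLayout;
    if (data_format == "NCHW") dl = Layout::kNCHW;
    if (data_format == "NHWC") dl = Layout::kNHWC;
    if (dl != Layout::kAnyLayout) return dl;
  }
  return tensor_layout;  // otherwise keep the tensor's own layout
}

int main() {
  Layout chosen = ChooseVarLayout(Layout::kONEDNN, Layout::kNCHW, "NHWC");
  std::printf("chosen layout = %d\n", static_cast<int>(chosen));  // 2 (kNHWC)
  return 0;
}
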
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } template @@ -515,7 +513,7 @@ void BatchNormDoubleGradOp::InferShape( } } -framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( +phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar("DY"); if (var == nullptr) { @@ -532,8 +530,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 0e579010a91d796831856128c971a05bd6af04c7..d6a1038c00167e223dd46e7c4fa80a284895c8cd 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -47,13 +47,13 @@ class BatchNormOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class BatchNormGradOp : public framework::OperatorWithKernel { @@ -62,13 +62,13 @@ class BatchNormGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class BatchNormDoubleGradOp : public framework::OperatorWithKernel { @@ -77,7 +77,7 @@ class BatchNormDoubleGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 3c775ced3f434cd1ef5e6fa9d516efd7fae0f74d..d1be450e815fc2361d740185d2eabfbce00f3d7f 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -28,11 +28,10 @@ class BCELossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - 
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -87,11 +86,10 @@ class BCELossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 49669f1b350d9fd5a631a490248ec7654212a60d..1e569c4bb27324a91bdcebdfca3f5af55a284802 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -108,7 +108,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *scores = ctx.Input("scores"); size_t level = ctx.Attr("level"); @@ -116,11 +116,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { // The current CUDA kernel only support cases with batch_size < 4. // Compute on CPU for cases with batch_size > 4. if (batch_size <= 4) { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), ctx.GetPlace()); } else { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/bilateral_slice_op.cc b/paddle/fluid/operators/bilateral_slice_op.cc index 8b7968d2a8839500db906990ef8b2484262a3d8d..c824fd9e6316046fd4b0009fb6a27087c930a44c 100644 --- a/paddle/fluid/operators/bilateral_slice_op.cc +++ b/paddle/fluid/operators/bilateral_slice_op.cc @@ -85,10 +85,10 @@ class BilateralSliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -147,11 +147,11 @@ class BilateralSliceOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/bincount_op.cc b/paddle/fluid/operators/bincount_op.cc index 5f5e19c585baea2695bd83074d97dc9d39655aae..484431eeefa60389ab68ba9e63c0e1f1840677e9 100644 --- a/paddle/fluid/operators/bincount_op.cc +++ b/paddle/fluid/operators/bincount_op.cc @@ -24,19 +24,17 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::OpKernelType; - class BincountOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto data_type = ctx.HasInput("Weights") ? OperatorWithKernel::IndicateVarDataType(ctx, "Weights") : OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 20ea0b187f64e2c3d9061427b20c37aaaec70af9..47aea124430373b3c322a2ce1a2045a265204de4 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -56,11 +56,10 @@ class BprLossOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of Seq-bpr // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -119,11 +118,10 @@ class BprLossGradientOp : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/broadcast_tensors_op.cc b/paddle/fluid/operators/broadcast_tensors_op.cc index 34a76e86aae0d87008f5026092b59600bb9dff69..6d924644192c9526cebbda27fc7423c778474c5e 100644 --- a/paddle/fluid/operators/broadcast_tensors_op.cc +++ b/paddle/fluid/operators/broadcast_tensors_op.cc @@ -27,14 +27,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // Broadcast semantics enforces all input variables having the same // DataType/VarType // This condition is also checked during VarType Inference // Here we simply copy input type to output - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -127,11 +127,11 @@ class BroadcastTensorsGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + 
ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 10b25fc4787447a5cc5c9ac0f9e1194f73232252..192fe35a9bb4ded923c77270223ff7319f22c50e 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -75,7 +75,7 @@ class CastOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { // CastOp kernel's device type is decided by input tensor place auto *tensor = ctx.Input("X"); @@ -86,9 +86,8 @@ class CastOp : public framework::OperatorWithKernel { auto &tensor_place = tensor->place(); // NOTE: a cuda pinned tensor needs to copy its data to the target place if (platform::is_cuda_pinned_place(tensor_place)) { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + ctx.device_context().GetPlace()); } // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN @@ -108,20 +107,19 @@ class CastOp : public framework::OperatorWithKernel { auto src_type = static_cast(ctx.Attr("in_dtype")); auto dst_type = static_cast(ctx.Attr("out_dtype")); if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) { - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), tensor_place); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + tensor_place); } else { VLOG(3) << "MLU does not support cast type: " << framework::DataTypeToString(src_type) << " to type: " << framework::DataTypeToString(dst_type) << ", falling back to the CPU kernel!"; - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), - platform::CPUPlace()); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + platform::CPUPlace()); } #endif - return framework::OpKernelType( - framework::TransToProtoVarType(tensor->dtype()), tensor_place); + return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()), + tensor_place); } }; diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index f168eb10ae76993d997987c0f0b70bede67f35f2..4639e53350eb40bf51bbfeeb488242be9d7ab074 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -53,11 +53,10 @@ class CenterLossOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -115,11 +114,11 @@ class CenterLossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), - ctx.device_context()); + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/chunk_eval_op.cc b/paddle/fluid/operators/chunk_eval_op.cc index 
6ad9f6d491ed7b2218b99043c79b2386631ef385..71268eb12df933510141e9caccb1aea6c3d0ce7f 100644 --- a/paddle/fluid/operators/chunk_eval_op.cc +++ b/paddle/fluid/operators/chunk_eval_op.cc @@ -88,10 +88,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc index c0dafd8534468829c52a746e1d2ad704dc296549..6e44a5ce2effcbb3f32c52a019676d03159a3f19 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc @@ -57,10 +57,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { * specified a data type here. * */ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc index 8147541cbaa932777689e8207d57d15100c17e42..4ce45aeea957d76b6792d3a9f98c9c48fda013fd 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -117,10 +117,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel { * Of course, the data type here is also not important. 
*/ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/class_center_sample_op.cc b/paddle/fluid/operators/class_center_sample_op.cc index cb766dae225c6a182e2a9ef4a82ffaddc2ef8ca4..54f0e981ca078dd525a79f5561643d9b84b04eee 100644 --- a/paddle/fluid/operators/class_center_sample_op.cc +++ b/paddle/fluid/operators/class_center_sample_op.cc @@ -26,11 +26,10 @@ class ClassCenterSampleOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Label"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Label"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index 997c017d3129cc4feb948f027f5fb8b65c07a792..1fdc4e9a123dc0d2faed4f289c4315c1d8a89fb1 100644 --- a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -26,11 +26,11 @@ namespace operators { class ClipOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -85,11 +85,11 @@ class ClipOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 75e6df4baf82b1d037fb63306571f9e5d81bf5fa..e16950e31d12922bd488e7b3757ce8b9e26d290d 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -405,20 +405,20 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &context) const override { auto dtype = static_cast( context.Attr("dtype")); - return framework::OpKernelType(dtype, context.GetPlace()); + return phi::KernelKey(dtype, context.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return 
phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/collective/allreduce_op.cc b/paddle/fluid/operators/collective/allreduce_op.cc index b3351dc82b7e755d43def2976f2aaf89edaa543c..e136d8ef6e3889fb92b80daaab289cbe0e569cfa 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cc @@ -27,10 +27,10 @@ class AllReduceOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/alltoall_op.cc b/paddle/fluid/operators/collective/alltoall_op.cc index b5512fdc52452eca7b8e0371c917b13f70f7ff0f..e6fa37e0e42d8f55e33dea841308efb3d087790d 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cc @@ -36,10 +36,10 @@ class AllToAllOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 268144f183fba03489e4409f1554ef0ad13fb62f..f63c4a9abcdc6e18a75e3b591a626be33ea30671 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -71,21 +71,23 @@ class CAllReduceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { if (var_name == "Cond") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index 49b1bd5fd9f8abdf3eddbde113463439968c5c81..35c395681e4c1e728b4906e1fe2e293fc56557b2 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -26,10 +26,10 @@ class CBroadcastOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType 
GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_concat_op.cc b/paddle/fluid/operators/collective/c_concat_op.cc index 75e41dba92e05599df1ba7fb9cb4ffb6089bcb7d..ed29654048c60ae9df4e3f2382dbb55a4daefb84 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cc @@ -58,10 +58,10 @@ class CConcatOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_embedding_op.cc b/paddle/fluid/operators/collective/c_embedding_op.cc index caea70c223bd349b5a5983253bf7557f579b9f09..aa16d8e1827735953b09c234e7699baae9264ae0 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cc +++ b/paddle/fluid/operators/collective/c_embedding_op.cc @@ -65,10 +65,10 @@ class CEmbeddingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -149,11 +149,11 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc index 8d743139d09e65b5f0ed2a7909c0d5278219fb58..55728b21fb6706f97cd2672d5e2c77fd73185628 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -35,10 +35,10 @@ class CIdentityOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 3e752011f152e2762580881bf87b41dcb76ac311..680af73af6b6c514196a94b0b109f850c4738207 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -66,10 +66,10 @@ class CReduceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/c_scatter_op.cc index d6d4cc03dc0c4bcf026dac21a614c51cd524d70a..d122ffbcd9d9b9836b1efd6b87a99eb6bf9c9c21 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cc @@ -52,10 +52,10 @@ class CScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc index a8c19b863810dbc166d7b85a55cc6d12aef85799..97d72457a24965c525e72109df55c4e92c96fa80 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cc @@ -73,11 +73,10 @@ class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + return phi::KernelKey( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace()); } }; @@ -150,11 +149,11 @@ class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_split_op.cc b/paddle/fluid/operators/collective/c_split_op.cc index 5c6f126b78ccf4694796b655cfa181a509c4db2c..52ce38cd179c4b73967fc320b1000f16a94b0caf 100644 --- a/paddle/fluid/operators/collective/c_split_op.cc +++ b/paddle/fluid/operators/collective/c_split_op.cc @@ -66,10 +66,10 @@ class CSplitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.h b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h index 5b26e47a8fdc7abd1e045b129ab5b2c103014592..da3fdd345393473918b69f819dec4428fafc1135 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.h +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.h 
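// Example: the collective ops above (allreduce, alltoall, c_broadcast, c_concat,
// c_embedding, c_identity, c_reduce, ...) all make the same mechanical
// substitution: framework::OpKernelType(dtype, place) becomes
// phi::KernelKey(dtype, place). A minimal sketch of the resulting overrides,
// assuming a hypothetical FooCollectiveOp that is not part of this diff; the
// phi::KernelKey overload taking a proto::VarType::Type and a Place, and the
// Backend::ALL_BACKEND "skip transform" form for auxiliary inputs such as
// "Cond", both mirror the hunks above.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class FooCollectiveOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Same (dtype, place) pair as before; only the carrier type changes.
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    // CPU-side flags such as "Cond" skip data transform entirely:
    // ALL_BACKEND means "no device-placement requirement for this variable".
    if (var_name == "Cond") {
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
    }
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
  }
};

}  // namespace operators
}  // namespace paddle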
@@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + +#pragma once + #include #include "paddle/fluid/framework/op_registry.h" @@ -25,10 +28,9 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 67fff7655133da515704f03006c1c15ec39085fe..ff7fb09f7a228c19fc8fc4cdf8652709f68edebd 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -23,10 +23,9 @@ class CSyncCommStreamOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/global_gather_op.cc b/paddle/fluid/operators/collective/global_gather_op.cc index ee8cc39f442609eca14892b45d61d7252f4b3756..f3380b4498331f7fa3260a41ca0b1b14317c4743 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cc @@ -49,10 +49,10 @@ class GlobalGatherOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/global_scatter_op.cc b/paddle/fluid/operators/collective/global_scatter_op.cc index 5d81acb226b9ced3b7c66d2b866e30d100893d46..d4469c5eadbbd832928a7d20c4d871c37df5d6b6 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cc @@ -52,10 +52,10 @@ class GlobalScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/partial_recv_op.cc b/paddle/fluid/operators/collective/partial_recv_op.cc index f0effde61b7e2426d247e58e6b95aee89cdcb855..37e060acc28c62c5e6837fdd392023935708da58 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cc @@ 
-80,12 +80,12 @@ class PartialRecvOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { int dtype = ctx.Attr("dtype"); framework::proto::VarType::Type type = framework::proto::VarType::Type(dtype); - return framework::OpKernelType(type, ctx.GetPlace()); + return phi::KernelKey(type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/partial_send_op.cc b/paddle/fluid/operators/collective/partial_send_op.cc index f11973e20c4f0da11b3ae2e7c7c96acfbe0c9ab1..59ab1cfa6ea1a37a8ed29148c811cdadecc29e1c 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cc @@ -51,10 +51,10 @@ class PartialSendOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc index a35e1a7dda8dcb414feebc98c2371e8b2abea5c9..2b51e913fb511a497b95ce67ea7b47632aa52629 100644 --- a/paddle/fluid/operators/collective/recv_v2_op.cc +++ b/paddle/fluid/operators/collective/recv_v2_op.cc @@ -69,12 +69,12 @@ class RecvOpV2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { int dtype = ctx.Attr("dtype"); framework::proto::VarType::Type type = framework::proto::VarType::Type(dtype); - return framework::OpKernelType(type, ctx.GetPlace()); + return phi::KernelKey(type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc index 8323b9841537aeefac1c7395e0a35c4f86880808..5652f079906025b120650235f90cb6abaeb0fffe 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cc @@ -38,7 +38,7 @@ class SendOpV2 : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const framework::Variable* var = ctx.InputVar("X"); if (var->IsType()) { @@ -46,12 +46,11 @@ class SendOpV2 : public framework::OperatorWithKernel { // NOTE(sandyhouse): Support an empty tensor array as Input. // And set the kernel type to float. 
if (t_arr.size() == 0) { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 0c6e7b31c9d2e1748b63ff5a07159febd96f48f6..21e4bfcf7093cec747aed71e8de03883096a72a5 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -32,7 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -48,18 +48,20 @@ class ConcatOp : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "All Inputs of Concat OP are Empty!")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -110,22 +112,24 @@ class ConcatOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cc b/paddle/fluid/operators/controlflow/bitwise_op.cc index 4b339f4bd58627c0e6d18b0ae0fa9610fda7c019..0922c9f5d437cca69abecf4a9a3c67be8c123dc9 100644 --- a/paddle/fluid/operators/controlflow/bitwise_op.cc +++ b/paddle/fluid/operators/controlflow/bitwise_op.cc @@ -97,11 +97,12 @@ class UnaryBitwiseOp : public framework::OperatorWithKernel { 
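// Example: the bitwise, compare, and logical hunks in this area can no longer
// assign kt.place_ directly, because phi::KernelKey stores a phi::Backend
// rather than a raw Place; the place is mapped through phi::TransToPhiBackend()
// and written with set_backend(). A sketch of the pattern, assuming a
// hypothetical DeviceFollowsInputOp (not part of this diff) and the usual
// header location for the conversion helper.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/compat/convert_utils.h"  // phi::TransToPhiBackend (path assumed)

namespace paddle {
namespace operators {

class DeviceFollowsInputOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Start from the default key (dtype inferred from inputs, executor place)...
    phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    // ...then pin the backend to wherever the input tensor actually lives,
    // mirroring the BitwiseOp/CompareOp/LogicalOp hunks.
    kt.set_backend(
        phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
    return kt;
  }
};

}  // namespace operators
}  // namespace paddle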
context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // BitwiseOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); return kt; } }; @@ -138,11 +139,12 @@ class BinaryBitwiseOp : public framework::OperatorWithKernel { context->ShareLoD("X", "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // BitwiseOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); return kt; } }; diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index ba580f4097e3324e6dc1543dc4a9be70cdc638e7..26d0dce91c3202b1886f3dbd12d5e5ec747c2db1 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -61,19 +61,20 @@ class CompareOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // CompareOp kernel's device type is decided by input tensor place bool force_cpu = ctx.Attr("force_cpu"); if (force_cpu) { - kt.place_ = platform::CPUPlace(); + kt.set_backend(phi::Backend::CPU); } else { if (ctx.Input("X")->place().GetType() != phi::AllocationType::GPUPINNED) { - kt.place_ = ctx.Input("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); } else { - kt.place_ = ctx.GetPlace(); + kt.set_backend(phi::TransToPhiBackend(ctx.GetPlace())); } } return kt; diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index b70211c1e167933b171d30462a0b7561575b54aa..5a99dd695c02ba41a2dd3357d4caf0daa47adb23 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -72,48 +72,49 @@ class FetchV2Op : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (!tensor.IsInitialized()) { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } - framework::OpKernelType 
GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *fetch_var = ctx.InputVar("X"); if (fetch_var == nullptr) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } if (fetch_var->IsType()) { auto &src_item = fetch_var->Get(); if (!src_item.IsInitialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } else if (fetch_var->IsType()) { auto &src_item = fetch_var->Get(); if (!src_item.initialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } else { auto &src_item = fetch_var->Get(); if (src_item.empty() || !src_item[0].IsInitialized()) { - return framework::OpKernelType(framework::proto::VarType::FP32, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, + platform::CPUPlace()); } } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/controlflow/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc index c6dde6f4ba53aaf7bd73d802505c4a200068742a..6a9fcaf852b15ce169072484057fed2db40673a3 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -69,11 +69,12 @@ class LogicalOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); // LogicalOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input("X")->place(); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); return kt; } }; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 0262c74923e1e80b8c5e4f5b61106682c26853fb..e41270de650a47333a50f25fe7d600480679dde3 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -186,7 +186,7 @@ std::vector ConvOp::ComputeOutputShape( return output_shape; } -framework::OpKernelType ConvOp::GetExpectedKernelType( +phi::KernelKey ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); // todo enable data layout when it's ready @@ -208,18 +208,18 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( paddle::framework::DataTypeToString(filter_data_type))); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } -framework::OpKernelType ConvOp::GetKernelTypeForVar( +phi::KernelKey ConvOp::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require 
reshaping, weights and // bias are having shape in NCHW order if ((var_name == "Input") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -228,13 +228,12 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for conv // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void Conv2DOpMaker::Make() { @@ -447,23 +446,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { } } -framework::OpKernelType ConvOpGrad::GetExpectedKernelType( +phi::KernelKey ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { // TODO(pzelazko-intel): enable MKLDNN layout when it's ready auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( +phi::KernelKey ConvOpGrad::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if (((var_name == "Input") || (var_name == framework::GradVarName("Output"))) && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -472,13 +471,12 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } template @@ -619,10 +617,10 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const { } } -framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( +phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 62bcfb545e00f40b3ffd356b430f3cf9423a51b6..29345f1432e8d6066b03e84c1ef193897e457323 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -196,13 +196,13 @@ class ConvOp : public framework::OperatorWithKernel { std::vector ComputeOutputShape( framework::InferShapeContext* ctx) const; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvOpGrad : public framework::OperatorWithKernel { @@ -211,13 +211,13 @@ class ConvOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvOpDoubleGrad : public framework::OperatorWithKernel { @@ -226,7 +226,7 @@ class ConvOpDoubleGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index ebc9f8afdb0b7f84644d718bd3f3427ac8d9938c..e5333c5ed5a286e6da3582e3611fed9c1926ec10 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -33,21 +33,21 @@ namespace operators { using DataLayout = phi::DataLayout; -framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( +phi::KernelKey ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } 
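// Example: ConvOp and ConvOpGrad above, and ConvTransposeOp below, apply one
// per-variable rule: when the kernel expects ONEDNN layout but the tensor is
// still in a plain layout, the layout requested for the variable comes from
// the "data_format" attribute; otherwise the variable keeps its own place and
// layout and only the dtype is taken from the expected key. A condensed
// sketch, with ConvLikeKernelTypeForVar being a hypothetical helper that is
// not part of this diff:
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

inline phi::KernelKey ConvLikeKernelTypeForVar(
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type,
    phi::DataLayout data_format_layout) {
#ifdef PADDLE_WITH_MKLDNN
  // "AnyLayout" means the model did not fix a data_format; fall through and
  // treat the tensor's own layout as authoritative (defaults to NCHW).
  if (expected_kernel_type.layout() == phi::DataLayout::ONEDNN &&
      tensor.layout() != phi::DataLayout::ONEDNN &&
      data_format_layout != phi::DataLayout::kAnyLayout) {
    return phi::KernelKey(
        tensor.place(), data_format_layout, expected_kernel_type.dtype());
  }
#endif
  // Default: keep the variable's place and layout, take only the dtype.
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}

}  // namespace operators
}  // namespace paddle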
-framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( +phi::KernelKey ConvTransposeOp::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN // Only input require reshaping, weights and // bias are having shape in NCHW order if ((var_name == "Input") && - (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -56,13 +56,12 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void Conv2DTransposeOpMaker::Make() { @@ -253,10 +252,10 @@ Example: )DOC"); } -framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( +phi::KernelKey ConvTransposeOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } template @@ -320,10 +319,10 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker { } }; -framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( +phi::KernelKey ConvTransposeOpDoubleGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } } // namespace operators diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index d47828e5bdc8f1ec15294cdc4878ef0ef0937ab3..61860b6907756946a6e95f7e355730943b59ba34 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -38,13 +38,13 @@ class ConvTransposeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class ConvTransposeOpGrad : public framework::OperatorWithKernel { @@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; @@ -61,7 +61,7 @@ class ConvTransposeOpDoubleGrad : public 
framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/correlation_op.cc b/paddle/fluid/operators/correlation_op.cc index 2b3450d0316079b7a0cf5e2626f393067199dcf5..c1b3fb25bc243a28b3c1361c565415871f67bec0 100644 --- a/paddle/fluid/operators/correlation_op.cc +++ b/paddle/fluid/operators/correlation_op.cc @@ -109,7 +109,7 @@ class CorrelationOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input1"); @@ -118,7 +118,7 @@ class CorrelationOp : public framework::OperatorWithKernel { ctx.Input("Input2")->dtype()), platform::errors::InvalidArgument( "X and Y should have the same datatype")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -158,9 +158,9 @@ class CorrelationOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index 62bd73374b3a189431572441e146546c0d3501ca..5844beb9c06d5be9667e18d6b5b112f6f6ca887d 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -202,9 +202,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 462764230f48480ff91885fee6b06344a65f9b5a..b615fbd58faeca23b837cd0770f7266050c62123 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -58,11 +58,10 @@ class CropOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -182,11 +181,11 @@ class CropOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc 
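// Example: gradient ops in this diff (BilateralSliceOpGrad and CropOpGrad
// above, CrossEntropyGradientOpBase in the file below) resolve the kernel
// dtype from the incoming output-gradient variable rather than a forward
// input, since the backward kernel must run in whatever dtype the upstream
// gradient arrives in. A sketch, assuming a hypothetical FooGradOp that is
// not part of this diff:
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class FooGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // framework::GradVarName("Out") expands to "Out@GRAD".
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Out")),
                          ctx.device_context().GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle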
index 99e67406c3cbadf4493e05737f6016d36a868102..bdee50e773e4a9d4519e3db636b5e05e9b4551e9 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -126,11 +126,10 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { @@ -192,11 +191,11 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { protected: // Explicitly set that the data type of computation kernel of cross_entropy // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Y")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context().GetPlace()); } virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index 7731b7207180a12a3725def8517595cfa729b747..1df3def180e5c57b4b884ad4e8391790a6d4cd75 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -35,11 +35,10 @@ class CTCAlignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index f5fd56edef900ef6c5fd60e0cafc633332179227..8a004ac4a27d6f2a7f4bac4b0af7440694cae5bc 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -95,11 +95,10 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -249,11 +248,11 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { SetOutGradDim("InitH"); SetOutGradDim("InitC"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + 
ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 29bc83bd9ae518ad11d3eaed4a626e15158a0bfc..6d1089ecf72a44850cfdd9e7cfd666fc15b5e7de 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -25,11 +25,11 @@ class CumOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 11af33df2f61b1d2ce6f869a16d825b67b19bcba..54fa8e0031961689f20ae0a1b939e702e73fff3c 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -48,11 +48,10 @@ class CVMOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // cvm // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -114,11 +113,11 @@ class CVMGradientOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // cvm // is determined by its input "X". 
- framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Y")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 6770a7e31c1a53b72a04feb37e2dbd6922f100ee..f4c850c423b61b4174ed6ff61f82e0f8dd2b8f6c 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -159,7 +159,7 @@ class DataNormOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -195,7 +195,7 @@ class DataNormOp : public framework::OperatorWithKernel { "bias input should be of float type")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -475,7 +475,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { @@ -494,7 +494,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { } auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/decode_jpeg_op.cc b/paddle/fluid/operators/decode_jpeg_op.cc index acb94f57bf4e8cf6d7b75522d07854ef3a6fea6e..521798e8ddce8194143e53cc195b7c4df5bd292b 100644 --- a/paddle/fluid/operators/decode_jpeg_op.cc +++ b/paddle/fluid/operators/decode_jpeg_op.cc @@ -32,24 +32,23 @@ class DecodeJpegOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "X") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } }; diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index b916d069d10d60264397ac9fc5d69a76ba8c0663..d6eff438e0bdb0439c680d2760006493aebdaca6 100644 --- 
a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -113,11 +113,10 @@ class DeformableConvOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -173,11 +172,10 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index ed70e546789814e3a9bac944df44695b5b782bd6..a597c1e00396d5f5f03be3ad2a4ca04d8d7b4257 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -118,11 +118,10 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -172,11 +171,10 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index 5240116c6a4f88dce97e405c95b85869747890c2..6e284d8e7bc70d13ada549c3979ad8551c258aea 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -290,11 +290,10 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context().GetPlace()); } }; @@ -338,11 +337,10 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_abs_max_op.cc b/paddle/fluid/operators/dequantize_abs_max_op.cc index 99c4fad0fa2ab27bd81f81049aa61ac817efdfcc..bf329324ff7bd6f86ac91fe9840233730bd5d411 100644 --- a/paddle/fluid/operators/dequantize_abs_max_op.cc +++ b/paddle/fluid/operators/dequantize_abs_max_op.cc @@ -68,11 +68,10 @@ class DequantizeMaxAbsOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - auto type = framework::OpKernelType(data_type, ctx.device_context()); - return type; + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_log_op.cc b/paddle/fluid/operators/dequantize_log_op.cc index 62359a2ce2124a26910da4aae7e2879efca350e4..94299e153dd9ff17bd209e9cce1fc49dd031fb3c 100644 --- a/paddle/fluid/operators/dequantize_log_op.cc +++ b/paddle/fluid/operators/dequantize_log_op.cc @@ -75,11 +75,10 @@ class DequantizeLogOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - auto type = framework::OpKernelType(data_type, ctx.device_context()); - return type; + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index c39f351fcb178243018796a5cf54e27c3b896f36..9bb7b0eaa1bb58646dda3f9ba0d7568715a15ad0 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -19,15 +19,14 @@ limitations under the License. */ namespace paddle { namespace operators { -framework::OpKernelType DeQuantOp::GetExpectedKernelType( +phi::KernelKey DeQuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/dequantize_op.h b/paddle/fluid/operators/dequantize_op.h index f319828a6be4b5b26cd5e2e31aa72bd5a20a44d4..4aee7502d6d58963ada48cb41487fe948fd01111 100644 --- a/paddle/fluid/operators/dequantize_op.h +++ b/paddle/fluid/operators/dequantize_op.h @@ -22,7 +22,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::OpKernelType; +using phi::KernelKey; class DeQuantOp : public framework::OperatorWithKernel { public: @@ -34,7 +34,7 @@ class DeQuantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 530b5a1ee13d9673547aeea4555444f932fece1d..7a1397ba08f17d75aa7abec2d3e2e426f9a02b59 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -61,11 +61,10 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index 583122b473d269f0d749f4e15bea366cb1002b8c..8bf542e17caed069e99beef5bb1f2a5858665ccd 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -50,9 +50,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index e07e4034f330f2ff9580b2cc791030ae510bc3f1..8c607c98c14b2c17861848f895b772e9bca44c6c 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -87,11 +87,11 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index 8b74f46cd315ffc8792401aa55647fe7e1a95430..def0f3f6d8dd3ef97272b5eddf0005862592f305 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -105,10 +105,10 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } 
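// ----------------------------------------------------------------------------
// [Reviewer sketch] Ops that used to request a oneDNN kernel through the
// extra LibraryType::kMKLDNN field (see dequantize_op.cc above) now encode
// that choice in the backend slot of the key. Assuming phi::TransToPhiDataType
// maps proto::VarType::Type to phi::DataType, as it is used elsewhere in this
// patch, the migrated shape is:
//
//   phi::KernelKey GetExpectedKernelType(
//       const framework::ExecutionContext& ctx) const override {
//     auto in_dtype =
//         framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
//     return phi::KernelKey(phi::Backend::ONEDNN,
//                           phi::DataLayout::ONEDNN,
//                           phi::TransToPhiDataType(in_dtype));
//   }
// ----------------------------------------------------------------------------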
}; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 20b8846bc4cc6cec908354a1a866b9c8096aa623..9fa761abcfabc2c17fa771d2b40a63c0c850adc2 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -27,10 +27,10 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 7ae5ba6ca8f9caeabfb9ca4888bd0f9ce8a0caad..6acc0431762cf78f858b54839016af4cd0b1b5db 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -110,10 +110,10 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois"); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return phi::KernelKey(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index b11030f1d086ad576e2089468ed32619636b6d17..dcffa170b6a3a49d5783c91e870988a7818092ff 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -160,10 +160,10 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois"); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return phi::KernelKey(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 030b99cd1dbd70db2743265ba4d81ba60333907a..d6987c7ba8c7ec22e2ac2ceac22a85d9744d5dd5 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -62,11 +62,11 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc index 0445c21b1de3bebe4ee919828c5449e15e1ca7e6..885a3575664287966c3df9e852922f5363622f6f 100644 --- 
a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc @@ -34,11 +34,11 @@ class GenerateProposalsV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), - ctx.device_context()); + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/locality_aware_nms_op.cc b/paddle/fluid/operators/detection/locality_aware_nms_op.cc index 1c5135fc4e8a78509c760343a3442a7b4147f3d3..9a230dc3224daf0b864da3214cb516d3fed0a8d6 100644 --- a/paddle/fluid/operators/detection/locality_aware_nms_op.cc +++ b/paddle/fluid/operators/detection/locality_aware_nms_op.cc @@ -79,9 +79,9 @@ class LocalityAwareNMSOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/matrix_nms_op.cc b/paddle/fluid/operators/detection/matrix_nms_op.cc index 1beeaf1ba3356f5b2c2a3807f834cbe701937745..8038e4a42cc11230f1f96ba5dbb5775ec2ce9dcb 100644 --- a/paddle/fluid/operators/detection/matrix_nms_op.cc +++ b/paddle/fluid/operators/detection/matrix_nms_op.cc @@ -25,9 +25,9 @@ class MatrixNMSOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index 28099630b8347c54bf2c8aa1e52532192fffdc80..a673d64c52d19a64182fbe9c47dc840df0e5b1ce 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -316,9 +316,9 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 79077b3086671deb828e95deeffcac3cde3ede00..9dc6a8cc1f29ecb574804f0916c91620f9c73376 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -112,9 +112,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } 
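// ----------------------------------------------------------------------------
// [Reviewer sketch] The detection ops in this stretch (bipartite_match,
// locality_aware_nms, matrix_nms, mine_hard_examples, multiclass_nms) pin the
// key to platform::CPUPlace() rather than ctx.GetPlace(), presumably because
// they only provide CPU kernels; the data-transform pass then copies any GPU
// inputs to the host before the kernel runs. "Scores" stands in for whichever
// input determines the dtype:
//
//   phi::KernelKey GetExpectedKernelType(
//       const framework::ExecutionContext& ctx) const override {
//     return phi::KernelKey(
//         OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
//         platform::CPUPlace());
//   }
// ----------------------------------------------------------------------------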
diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index 66682c67870baa23f0b4547511537501f99740ba..9171b9ab25ea46c4d541a4aba694ab5c63747137 100644 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -69,10 +69,10 @@ class NMSOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 28251c32ddee96376df767370aadadfd3be695c2..be1e224cd30fe6e2f0efb1d9704a8e853f8ddb99 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -26,11 +26,11 @@ class PriorBoxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_input_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_input_type, ctx.GetPlace()); + return phi::KernelKey(input_input_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc index d2654e086d08d11a2576da6af32bfbea8338503d..a36d6a9f6c703dc69554a33ad34fed8ad2a94939 100644 --- a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -166,12 +166,12 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Scores"); - return framework::OpKernelType(input_data_type, - platform::CPUPlace()); // ctx.GetPlace()); + return phi::KernelKey(input_data_type, + platform::CPUPlace()); // ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 9ba51850ebaaae4ba9e9daeb6a4885615c3ba77f..27442c5dadce9836194637058e628820248c59c3 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -559,11 +559,10 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -585,11 +584,10 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index ba7fe51383822ea95169845f84db6b64644a20bd..531c823442c3f8b25844f7974c822c3da755dc7d 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -94,9 +94,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } @@ -851,9 +851,9 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc index 91479a78b63b1c77919a1e04125db94bb286647d..ff27945d18720fee84e9d1edd02df6f6531eebc3 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -89,11 +89,10 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -180,11 +179,10 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index c3d79b0505070cc6858ebf0127286a08882c1973..155ec31fa92d712c8308c8353c7bcdc8159d3517 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -77,11 +77,10 @@ class TargetAssignOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc 
b/paddle/fluid/operators/detection/yolo_box_op.cc index fbf4b55dfe44e02226c2321a739298f1484a7f37..a60f42de66a68f68498a85b89af512f85a144751 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -129,10 +129,10 @@ class YoloBoxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 0b8fc79826f1c5192117f2866e28fcd5bf846727..21aca33f65a1aaf45522887430b837db51b20d88 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -26,11 +26,10 @@ class Yolov3LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -179,11 +178,10 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index ada4d18eb00c19c3e16becd20f64f3f96fd839a1..2620fa3a8fa7aa2ba769e5d5ba3892923ec377fd 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -89,9 +89,9 @@ class DetectionMAPOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"), platform::CPUPlace()); } diff --git a/paddle/fluid/operators/determinant_op.cc b/paddle/fluid/operators/determinant_op.cc index 56e39747afc5e2637f161f67dbd3a7588a8080d1..62cecbd36ae47b33eb77e3397dc13b5c6eda500e 100644 --- a/paddle/fluid/operators/determinant_op.cc +++ b/paddle/fluid/operators/determinant_op.cc @@ -70,11 +70,11 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/dgc_clip_by_norm_op.cc b/paddle/fluid/operators/dgc_clip_by_norm_op.cc index 7c759490393583771a6532d0322d027dae4cb6ef..2f8d7ca96f8670936a4e61b26560127e99373081 100644 --- a/paddle/fluid/operators/dgc_clip_by_norm_op.cc +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.cc @@ -31,13 +31,15 @@ class DGCClipByNormOp : public ClipByNormOp { return ClipByNormOp::InferShape(ctx); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "current_step") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/dgc_op.cc b/paddle/fluid/operators/dgc_op.cc index e247ab05ebadd30a697aa773a19db9c53b78a1c5..171dc84000c98db0d344150ac6a026681eba393f 100644 --- a/paddle/fluid/operators/dgc_op.cc +++ b/paddle/fluid/operators/dgc_op.cc @@ -45,13 +45,15 @@ class DGCOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "current_step" || var_name == "k" || var_name == "nranks") { VLOG(10) << "var_name:" << var_name << " need not to transform"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 804834a974aad314147a4f6482e98800ba340922..c6ee1180b5b66d39ad539c55fee6753373b4f0b9 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -27,24 +27,26 @@ class DropoutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Seed") { VLOG(10) << "var_name:" << var_name << " does not need to transform in dropout op"; - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -133,11 +135,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel { } 
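// ----------------------------------------------------------------------------
// [Reviewer sketch] Under the old API, returning expected_kernel_type from
// GetKernelTypeForVar verbatim meant "do not transform this input" (dgc's
// current_step/k/nranks and dropout's Seed above). With phi::KernelKey the
// same intent needs an explicit wildcard backend: layout and dtype are kept
// from the expected key, but Backend::ALL_BACKEND matches any place, so no
// device copy is triggered. The recurring shape, using dropout's "Seed":
//
//   phi::KernelKey GetKernelTypeForVar(
//       const std::string& var_name,
//       const phi::DenseTensor& tensor,
//       const phi::KernelKey& expected_kernel_type) const override {
//     if (var_name == "Seed") {
//       return phi::KernelKey(phi::Backend::ALL_BACKEND,
//                             expected_kernel_type.layout(),
//                             expected_kernel_type.dtype());
//     }
//     return phi::KernelKey(
//         tensor.place(), tensor.layout(), expected_kernel_type.dtype());
//   }
// ----------------------------------------------------------------------------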
protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/edit_distance_op.cc b/paddle/fluid/operators/edit_distance_op.cc index c4c5db6b50cdabc15c7c4dbd02dc7a85135d9cff..5eef3d72b39086c954091f76066571b9a120a95e 100644 --- a/paddle/fluid/operators/edit_distance_op.cc +++ b/paddle/fluid/operators/edit_distance_op.cc @@ -24,10 +24,10 @@ class EditDistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/eigvalsh_op.cc b/paddle/fluid/operators/eigvalsh_op.cc index 9d09b96280e2f0647994e4858b5a1f10800935d6..27c70f1e9b9a9ac737922b96858435ed18ab5b6c 100644 --- a/paddle/fluid/operators/eigvalsh_op.cc +++ b/paddle/fluid/operators/eigvalsh_op.cc @@ -66,11 +66,11 @@ class EigvalshGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Eigenvectors"), - ctx.device_context()); + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 5f169e20e3dc3f9e3b7e28e62c95fe22ac1ea01a..458fc7afb9de2205cf036e17badec6c465309762 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -68,11 +68,11 @@ class EinsumGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 8c7aa350b4372df621f61f793a897e5a1b740908..4c1afa3f6ca8fdc0fbf2d5ecfe34973770e11501 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -41,25 +41,22 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const 
std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 7f233eba88333e98e924707ddeabfcd4b9b9ba59..8d1b52325de69eb734f28512f529099b64ca4ec1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -27,26 +27,23 @@ class ElementwiseMulOp : public ElementwiseOp { public: using ElementwiseOp::ElementwiseOp; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 1ed8f4eb012a24881908f6604f78a5c0da4104b3..7048cf50293c926ebb404069495ccd7cd66227f6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -151,39 +151,36 @@ class ElementwiseOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + 
const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { #ifdef PADDLE_WITH_MKLDNN // When elementwise is first oneDNN op (there was some non oneDNN op // previously) // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) && phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); + return phi::KernelKey(tensor.place(), + phi::DataLayout::kNHWC, + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -300,26 +297,23 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -345,25 +339,22 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - 
framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -380,7 +371,7 @@ class ElementwiseOpDoubleGradWithoutDXDY } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::proto::VarType::Type input_data_type; if (ctx.HasInput("DDX") == false) { @@ -399,22 +390,19 @@ class ElementwiseOpDoubleGradWithoutDXDY input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY"); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -446,26 +434,23 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::proto::VarType::Type input_data_type; input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey &expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index 47dc2eb383249b0a82019f0f7f876bf02a7bbdf3..a5c707d460a2504411632ab23b559616c5fad540 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -53,21 +53,23 @@ class EmptyOp : public framework::OperatorWithKernel { using 
framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& context) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(context.Attr<int>("dtype")), context.GetPlace()); } diff --git a/paddle/fluid/operators/expand_as_op.cc b/paddle/fluid/operators/expand_as_op.cc index b793d835fca98b8798a9efb078bd2a126bf9f7b0..107fe9f6174b615cd48c41db98c5d425279507e2 100644 --- a/paddle/fluid/operators/expand_as_op.cc +++ b/paddle/fluid/operators/expand_as_op.cc @@ -106,11 +106,11 @@ class ExpandAsGradOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 09dc0f68cce2afe2b6b0b6c2a53d6c9fabe6b193..5e0f98c3eedb37d34f3c37dffa6871dddcd9032f 100644 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -25,11 +25,10 @@ class ExpandAsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -77,11 +76,11 @@ class ExpandAsV2GradOp : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 67b8102181e1b8355b3e739a780814a7a0f3ad1f..43fd505acdae415695c8f501a21b8805ed52a0af 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -77,22 +77,23 @@ class ExpandOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType 
GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_times_tensor" || var_name == "ExpandTimes") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -206,22 +207,24 @@ class ExpandGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_times_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 6bf40fd3bb6b8e0603eac44405f6f4955de2febc..cbd322f38767e9e3598a90ae28c2cd172c41552d 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -33,22 +33,24 @@ class ExpandV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - 
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -150,22 +152,24 @@ class ExpandV2GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "expand_shapes_tensor" || var_name == "Shape") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/exponential_op.cc b/paddle/fluid/operators/exponential_op.cc index 26e06e50a77840bc6aa8f55a145e84fff857fc14..52ddd9ebfa16f0a8990507ee40b45a6c0000872f 100644 --- a/paddle/fluid/operators/exponential_op.cc +++ b/paddle/fluid/operators/exponential_op.cc @@ -24,10 +24,10 @@ class ExponentialOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc index 629400a403e461115d8c98ef6ee927fe3af3df44..57582c694e8d997cf4602c4b5989d21d316f0add 100644 --- a/paddle/fluid/operators/eye_op.cc +++ b/paddle/fluid/operators/eye_op.cc @@ -25,9 +25,9 @@ class EyeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr<int>("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index a5742af7425a32578d46e4e9ae0dedf314bbcca7..65e4b28326a9fe55f3dd1657e60a70a4bc6da16c 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -405,11 +405,10 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + 
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index a5742af7425a32578d46e4e9ae0dedf314bbcca7..65e4b28326a9fe55f3dd1657e60a70a4bc6da16c 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -405,11 +405,10 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -472,10 +471,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -553,10 +552,10 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -631,11 +630,10 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -711,11 +709,10 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -791,10 +788,10 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -847,11 +844,11 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index d4d160d315def63d3c6ed5cf5045be78718c573b..94b5ba1c5c89a51f75b6e48b9001b7490da9b494 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -124,11 +124,11 @@ class FCOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Input");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc
index bf79a98d21df47c19e2dca836653e5dfaea04371..2efe0eeb720b0f95e8a8b75a0e8e8c63ae402135 100644
--- a/paddle/fluid/operators/fill_any_like_op.cc
+++ b/paddle/fluid/operators/fill_any_like_op.cc
@@ -31,23 +31,24 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
     const auto &data_type = ctx.Attr<int>("dtype");
     if (data_type >= 0) {
-      kt.data_type_ = static_cast<framework::proto::VarType::Type>(data_type);
+      kt.set_dtype(phi::TransToPhiDataType(
+          static_cast<framework::proto::VarType::Type>(data_type)));
     }
     return kt;
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          tensor.layout(),
+                          expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc
index 871a8314c5ad63eca8f03b6ecc24cef847a9724b..66c0470ac0b53b9e72fb212f9d3d224ad5b08c9a 100644
--- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc
+++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc
@@ -23,13 +23,13 @@ namespace operators {
 class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
 protected:
   using BatchSizeLikeOp::BatchSizeLikeOp;
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kernel_type = framework::OpKernelType(
+    phi::KernelKey kernel_type = phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context());
+        ctx.GetPlace());
     if (ctx.Attr<bool>("force_cpu")) {
-      kernel_type.place_ = platform::CPUPlace();
+      kernel_type.set_backend(phi::Backend::CPU);
     }
     return kernel_type;
   }
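fill_any_like and fill_constant_batch_size_like above show the other half of the migration: code that used to poke OpKernelType's public members (data_type_, place_) now goes through KernelKey setters. A condensed sketch of that mutation pattern (condensed from the two hunks above; the surrounding function is elided, and the dtype/force_cpu attribute names are the ops' own):

// Condensed from the hunks above; not a complete operator.
phi::KernelKey kt = phi::KernelKey(input_data_type, ctx.GetPlace());
const auto& data_type = ctx.Attr<int>("dtype");
if (data_type >= 0) {
  // Old: kt.data_type_ = static_cast<framework::proto::VarType::Type>(...);
  // KernelKey stores a phi::DataType, so the proto enum is converted first.
  kt.set_dtype(phi::TransToPhiDataType(
      static_cast<framework::proto::VarType::Type>(data_type)));
}
if (ctx.Attr<bool>("force_cpu")) {
  // Old: kt.place_ = platform::CPUPlace();  New: a Backend, not a Place.
  kt.set_backend(phi::Backend::CPU);
}
return kt;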
diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc
index 82c6b89063bea24bfee03ba14b87725b44c544ca..4b2ee8763cd659164f39deac26ca05bbb3e90ad3 100644
--- a/paddle/fluid/operators/fill_constant_op.cc
+++ b/paddle/fluid/operators/fill_constant_op.cc
@@ -56,46 +56,47 @@ class FillConstantOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     } else {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::proto::VarType::Type(ctx.Attr<int>("dtype"));
-    framework::OpKernelType kt =
-        framework::OpKernelType(input_data_type, ctx.GetPlace());
+    phi::KernelKey kt = phi::KernelKey(input_data_type, ctx.GetPlace());
     // TODO(zyfncg) The force_cpu and place_type are conflicted, it's an issue
     // left before, and we may merge them in the future.
     // In order to invoke new fill_constant kernel, the place of OpKernelType
     // will be set by force_cpu and place_type here.
     if (ctx.Attr<bool>("force_cpu")) {
-      kt.place_ = platform::CPUPlace();
+      kt.set_backend(phi::Backend::CPU);
     }
     auto place_type = ctx.Attr<int>("place_type");
     if (place_type != -1) {
       switch (place_type) {
         case 0:
-          kt.place_ = platform::CPUPlace();
+          kt.set_backend(phi::Backend::CPU);
           break;
         case 1:
         case 2:
-          kt.place_ = platform::CUDAPlace();
+          kt.set_backend(phi::Backend::GPU);
           break;
         case 3:
-          kt.place_ = platform::XPUPlace();
+          kt.set_backend(phi::Backend::XPU);
           break;
         case 4:
-          kt.place_ = platform::NPUPlace();
+          kt.set_backend(phi::Backend::NPU);
           break;
         default:
           PADDLE_THROW(platform::errors::Unimplemented(
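The fill_constant switch above maps the op's integer place_type attribute onto backends. Pulled out as a free function it reads as follows (a sketch: the helper name is invented, but the 0/1/2/3/4 codes and their CPU/GPU/GPU/XPU/NPU targets are exactly those of the switch above):

// Hypothetical helper, equivalent to the inline switch above.
phi::Backend BackendForPlaceType(int place_type) {
  switch (place_type) {
    case 0:
      return phi::Backend::CPU;
    case 1:  // both 1 and 2 select the GPU backend above
    case 2:
      return phi::Backend::GPU;
    case 3:
      return phi::Backend::XPU;
    case 4:
      return phi::Backend::NPU;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported place_type %d in fill_constant.", place_type));
  }
}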
diff --git a/paddle/fluid/operators/fill_diagonal_op.cc b/paddle/fluid/operators/fill_diagonal_op.cc
index 8a7f5daa9f857ee6e09e22516c1480c91c385433..373a63b7fffc61a26ab71951f914d967cd3fa00c 100644
--- a/paddle/fluid/operators/fill_diagonal_op.cc
+++ b/paddle/fluid/operators/fill_diagonal_op.cc
@@ -50,10 +50,10 @@ class FillIDiagonalOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -71,12 +71,12 @@ class FillIDiagonalGradOp : public framework::OperatorWithKernel {
 public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     // Note: don't get data type from ctx.Input<phi::DenseTensor>("Input");
     auto dtype = framework::TransToProtoVarType(
         ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))->type());
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    return phi::KernelKey(dtype, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc
index 8937676c344ff8ce3920a4b1834b0efcfa3e1b5a..bcb7081847111cbe029f561dc2452ec315a017d4 100644
--- a/paddle/fluid/operators/fill_op.cc
+++ b/paddle/fluid/operators/fill_op.cc
@@ -51,9 +51,9 @@ class FillOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
diff --git a/paddle/fluid/operators/fill_zeros_like_op.cc b/paddle/fluid/operators/fill_zeros_like_op.cc
index 8bd0e328c1f5b6a9cbb9bb377d5f3b5c4259f856..aff240ca4a42ce603f83bf423f68858193bd11e9 100644
--- a/paddle/fluid/operators/fill_zeros_like_op.cc
+++ b/paddle/fluid/operators/fill_zeros_like_op.cc
@@ -55,9 +55,9 @@ class FillZerosLikeOp2 : public FillZerosLikeOp {
   using FillZerosLikeOp::FillZerosLikeOp;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
diff --git a/paddle/fluid/operators/filter_by_instag_op.cc b/paddle/fluid/operators/filter_by_instag_op.cc
index 808792468ff38d892087d9a08762c8a241e413a6..3fe43017eb7623fa04631f808ac21d7c4d372dc5 100644
--- a/paddle/fluid/operators/filter_by_instag_op.cc
+++ b/paddle/fluid/operators/filter_by_instag_op.cc
@@ -59,10 +59,10 @@ class FilterByInstagOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.device_context().GetPlace());
   }
 };
 
@@ -126,11 +126,11 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 9b96e27ab761e6b835f295dfd9b24d309fd14106..54e35a6f03dd315b93078432fdb63e10c98925a0 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -81,11 +81,11 @@ class FlattenOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -153,11 +153,11 @@ class FlattenGradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -217,11 +217,11 @@ class Flatten2Op : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", "XShape");
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -269,11 +269,11 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -387,11 +387,11 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 
diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc
index c1acc9a38b915e86713ae89cd234533653f1eaeb..4f59e88bd039143367c98d0d54685dd5d7ab2f93 100644
--- a/paddle/fluid/operators/fsp_op.cc
+++ b/paddle/fluid/operators/fsp_op.cc
@@ -65,15 +65,10 @@ class FSPOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context(),
-        layout_,
-        library_);
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -131,11 +126,11 @@ class FSPOpGrad : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc
index f25dc393d3a191a7991ea3077d2acac5fa6e0ea9..6b1f533b34f4d52847a72d69d1cf43e5ad3eb926 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cc
+++ b/paddle/fluid/operators/fused/fused_attention_op.cc
@@ -253,11 +253,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input = ctx.Input<phi::DenseTensor>("X");
     auto input_data_type = framework::TransToProtoVarType(input->dtype());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -588,11 +588,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input = ctx.Input<phi::DenseTensor>("X");
     auto input_data_type = framework::TransToProtoVarType(input->dtype());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
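fsp_op above (and the fused_bn_*, fused_gemm_epilogue, and resnet_* files later in this patch) used the four-argument OpKernelType constructor solely to pass kPlain/kAnyLayout placeholders. Since KernelKey derives backend and layout from the place when they are not specified, the two locals simply disappear; a before/after sketch, abridged from the fsp hunk above:

// Before (abridged from fsp_op.cc above):
//   framework::LibraryType library_{framework::LibraryType::kPlain};
//   phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
//   return framework::OpKernelType(dtype, ctx.device_context(),
//                                  layout_, library_);
// After: the placeholders carry no information, so they are dropped.
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                      ctx.device_context().GetPlace());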
diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
index 02494e33e1241f646aeb8b4fd47e61763ddaf95b..a6fa80a4939728043b311127ee10db6c50325d56 100644
--- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
+++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc
@@ -60,11 +60,11 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input = ctx.Input<phi::DenseTensor>("X");
     auto input_data_type = framework::TransToProtoVarType(input->dtype());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -190,11 +190,11 @@ class FusedBiasDropoutResidualLnGradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input = ctx.Input<phi::DenseTensor>("X");
     auto input_data_type = framework::TransToProtoVarType(input->dtype());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.cc b/paddle/fluid/operators/fused/fused_bn_activation_op.cc
index e68be43eb7ec09853a7b586384210982f296e43b..88b11f1ef39c5840ca4031d9faa5acdb41469cde 100644
--- a/paddle/fluid/operators/fused/fused_bn_activation_op.cc
+++ b/paddle/fluid/operators/fused/fused_bn_activation_op.cc
@@ -156,7 +156,7 @@ void FusedBatchNormActOp::InferShape(framework::InferShapeContext *ctx) const {
   ctx->ShareLoD("X", "Y");
 }
 
-framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType(
+phi::KernelKey FusedBatchNormActOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
   // By default, the type of the scale, bias, mean,
@@ -187,11 +187,7 @@ framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType(
                     platform::errors::PreconditionNotMet(
                         "Variance input should be of float type"));
 
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  phi::DataLayout layout = phi::DataLayout::kAnyLayout;
-
-  return framework::OpKernelType(
-      input_data_type, ctx.GetPlace(), layout, library);
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
 void FusedBatchNormActOpMaker::Make() {
@@ -297,7 +293,7 @@ void FusedBatchNormActGradOp::InferShape(
   ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
 }
 
-framework::OpKernelType FusedBatchNormActGradOp::GetExpectedKernelType(
+phi::KernelKey FusedBatchNormActGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar(framework::GradVarName("Y"));
   if (var == nullptr) {
@@ -315,14 +311,8 @@ framework::OpKernelType FusedBatchNormActGradOp::GetExpectedKernelType(
         platform::errors::NotFound("Can not get the tensor value of Y@GRAD."));
   }
 
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  phi::DataLayout layout = phi::DataLayout::kAnyLayout;
-
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-      ctx.GetPlace(),
-      layout,
-      library);
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.h b/paddle/fluid/operators/fused/fused_bn_activation_op.h
index b71812db9d3d382335967eb196a2281f9143a8a0..78ba849eaaacb2f0d4caecfcec54927eb8106bd8 100644
--- a/paddle/fluid/operators/fused/fused_bn_activation_op.h
+++ b/paddle/fluid/operators/fused/fused_bn_activation_op.h
@@ -33,7 +33,7 @@ class FusedBatchNormActOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
 
@@ -43,7 +43,7 @@ class FusedBatchNormActGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
index 08f7087b48d01968f049c04d50003c7202ac15d2..58a950f9238e095e10245e19bad9455ff08f0733 100644
--- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
+++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
@@ -134,7 +134,7 @@ void FusedBatchNormAddActOp::InferShape(
   ctx->ShareLoD("X", "Y");
 }
 
-framework::OpKernelType FusedBatchNormAddActOp::GetExpectedKernelType(
+phi::KernelKey FusedBatchNormAddActOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
   // By default, the type of the scale, bias, mean,
@@ -152,11 +152,7 @@ framework::OpKernelType FusedBatchNormAddActOp::GetExpectedKernelType(
           ctx.Input<phi::DenseTensor>("Bias")->dtype()),
       platform::errors::InvalidArgument("Bias input should be of float type"));
 
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  phi::DataLayout layout = phi::DataLayout::kAnyLayout;
-
-  return framework::OpKernelType(
-      input_data_type, ctx.GetPlace(), layout, library);
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
 void FusedBatchNormAddActOpMaker::Make() {
@@ -255,7 +251,7 @@ void FusedBatchNormAddActGradOp::InferShape(
   ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
 }
 
-framework::OpKernelType FusedBatchNormAddActGradOp::GetExpectedKernelType(
+phi::KernelKey FusedBatchNormAddActGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar(framework::GradVarName("Y"));
   if (var == nullptr) {
@@ -273,14 +269,8 @@ framework::OpKernelType FusedBatchNormAddActGradOp::GetExpectedKernelType(
         platform::errors::NotFound("Can not get the tensor value of Y@GRAD."));
   }
 
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  phi::DataLayout layout = phi::DataLayout::kAnyLayout;
-
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-      ctx.GetPlace(),
-      layout,
-      library);
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 }  // namespace operators
 
diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h
index bdb1f2f35444c5459ca5d1ea6fc4b1f1d939889d..2d20a880e77728f8945281592d2e3623b2ce0dca 100644
--- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h
+++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h
@@ -33,7 +33,7 @@ class FusedBatchNormAddActOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
 
@@ -43,7 +43,7 @@ class FusedBatchNormAddActGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
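The paired .h/.cc hunks for the fused batch-norm ops show the shape of the change for operators whose GetExpectedKernelType is declared in a header: the declaration and the out-of-line definition must migrate in the same patch, or the override no longer matches its base. A minimal sketch of the paired form, using a hypothetical BarOp (signatures as in the hunks above):

// bar_op.h (sketch)
class BarOp : public paddle::framework::OperatorWithKernel {
 protected:
  phi::KernelKey GetExpectedKernelType(
      const paddle::framework::ExecutionContext& ctx) const override;
};

// bar_op.cc (sketch)
phi::KernelKey BarOp::GetExpectedKernelType(
    const paddle::framework::ExecutionContext& ctx) const {
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                        ctx.GetPlace());
}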
diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
index 8c81a646fdebb108ffa39211f96659f8f4f0cae7..2e7152ecb294e3277da52d68bf56ee3d69d56fe8 100644
--- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
+++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
@@ -172,14 +172,14 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(ctx.Input<phi::DenseTensor>("X")->dtype(),
                       ctx.Input<phi::DenseTensor>("Y")->dtype(),
                       platform::errors::InvalidArgument(
                           "The element's type of input should be the same."));
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -389,11 +389,11 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
index 4f8c4d12d6b58b8a64d81404fc51312932631344..232321c65bf2b5bd2e04e81a1c7f621941b027e3 100644
--- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
@@ -103,7 +103,7 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto inputs = ctx.MultiInput<phi::DenseTensor>("Embs");
     auto input_data_type = framework::proto::VarType::Type(0);
@@ -119,7 +119,7 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
       PADDLE_THROW(platform::errors::PreconditionNotMet(
           "All Inputs of fused_embedding_eltwise_layernorm OP are Empty!"));
     }
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index 11b9044bc547287aaa78f7daac7ff9dc54ef7b01..bec18220e9afdb949769e2dd7f025409b1f6857c 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -169,11 +169,11 @@ void FusedEmbeddingFCLSTMOp::InferShape(
   ctx->ShareLoD("Ids", "XX");
 }
 
-framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
+phi::KernelKey FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
+  return phi::KernelKey(
       OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"),
-      ctx.device_context());
+      ctx.GetPlace());
 }
 
 void FusedEmbeddingFCLSTMOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h
index 19039ec55946d35528afb565e75e36452cb912b5..29db2e9961532cb5f4ddb239b936219044ece445 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h
@@ -25,7 +25,7 @@ class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
index bbb5ce50c90cac85cdd47db28891f61509e35bcc..a5f20ffadc105326d65a9d0341dd07fa4804388d 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
@@ -72,10 +72,10 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
@@ -141,10 +141,10 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc
index 3bf039829ac3d6133f1d2ba31a60ce22eba55104..b6edf5ed4482d4e11a6781c87cb9fbfbf7900a5d 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cc
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -120,10 +120,10 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
     context->ShareLoD("X", "Out");
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -344,11 +344,11 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input = ctx.Input<phi::DenseTensor>("X");
     auto input_data_type = framework::TransToProtoVarType(input->dtype());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
index b4fc1b57d8f15e09ce4c008dbc123666b9c510da..187eb4fc07ea2f3dbc422ab8c618190a4c75bfec 100644
--- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
+++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
@@ -144,12 +144,10 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
@@ -318,12 +316,10 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc
index e1be5afa0bd6868229d0fd2ccf657f7bdacf0e51..ca9edb682b1c98924e72b8f5f56ab63ca19bfdfc 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc
@@ -166,22 +166,24 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "TimeStep") {
       VLOG(10) << "var_name:" << var_name << " need not to transform";
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc
index 89d2275e06d34c0def78f74aa96b72dedc5a5f2e..5448578c2617e68baca8e46c92c2fd9d9d2726ab 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc
@@ -133,22 +133,24 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "TimeStep") {
       VLOG(10) << "var_name:" << var_name << " need not to transform";
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index 95c82c72efde154eec69caec3bd1d7f6cd8d471a..79ad83ab1477866bc220776596240232367464a7 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -93,7 +93,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
     auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
     auto input_data_type = framework::proto::VarType::Type(0);
@@ -109,10 +109,10 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
         1,
         platform::errors::InvalidArgument(
             "All Inputs of fused_seqpool_cvm OP are Empty!"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
-    // return framework::OpKernelType(framework::proto::VarType::FP32,
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
+    // return phi::KernelKey(framework::proto::VarType::FP32,
     //                                ctx.device_context());
-    // return framework::OpKernelType(
+    // return phi::KernelKey(
     //   OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
 };
 
@@ -210,11 +210,11 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
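fused_multi_transformer above (like expand, gather, gaussian_random, and fill_constant elsewhere in this patch) rewrites GetKernelTypeForVar the same way: "leave this auxiliary input where it is" used to be spelled by echoing expected_kernel_type back, and is now an explicit ALL_BACKEND key, so PrepareData performs no device copy for inputs such as "TimeStep", "Axis", or the shape tensors. A consolidated sketch, taken directly from the hunks with only the comments added:

phi::KernelKey GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const override {
  if (var_name == "TimeStep") {
    // Backend::ALL_BACKEND marks the variable as device-agnostic,
    // so the data-preparation pass skips the cross-device transform.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
  }
  // Default path: describe the tensor as it currently is; only the
  // dtype is taken from the expected key.
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}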
diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc
index 9df22199106726bc1f3922459bf6e9066b4d6ef5..7b737d6885610ada165a7b4df25dda59eb305fe9 100644
--- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc
+++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc
@@ -71,11 +71,10 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_group_op.cc b/paddle/fluid/operators/fused/fusion_group_op.cc
index 36b97ea7b12bc5d1358aba9bb60704f9f02d27ae..362819d97ff95df19278048646f8474cf1aaa87c 100644
--- a/paddle/fluid/operators/fused/fusion_group_op.cc
+++ b/paddle/fluid/operators/fused/fusion_group_op.cc
@@ -76,10 +76,10 @@ class FusionGroupOp : public framework::OperatorWithKernel {
   }
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   platform::CUDAPlace(0));
+    return phi::KernelKey(framework::proto::VarType::FP32,
+                          platform::CUDAPlace(0));
   };
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index fc7804f9c4e8c9fc9acbe2596e7e887dc7e6ae2b..b7e4fa54938ae7781a0a2bbc16f0af167959fb25 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -147,10 +147,10 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "XX");
 }
 
-framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
+phi::KernelKey FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-  return framework::OpKernelType(data_type, ctx.GetPlace());
+  return phi::KernelKey(data_type, ctx.GetPlace());
 }
 
 void FusionGRUOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.h b/paddle/fluid/operators/fused/fusion_gru_op.h
index 94bf38068d0dd395e741a925e71c5a1a0f4fb3c1..e811df655099d89527330a83f4be05504c3c53f5 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.h
+++ b/paddle/fluid/operators/fused/fusion_gru_op.h
@@ -25,7 +25,7 @@ class FusionGRUOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index c526fdc18428c7636fb20ff679d63c826174d258..57d40f1cae86ef4993e4afd169ef01f94a89b25f 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -170,10 +170,10 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "XX");
 }
 
-framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
+phi::KernelKey FusionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-  return framework::OpKernelType(data_type, ctx.GetPlace());
+  return phi::KernelKey(data_type, ctx.GetPlace());
 }
 
 void FusionLSTMOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.h b/paddle/fluid/operators/fused/fusion_lstm_op.h
index 93f8eb981bbd9689e311f4d471a1b3e3649c5f41..c62060d7c225cdb115c79d52f6a378f6b1a915fe 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.h
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.h
@@ -25,7 +25,7 @@ class FusionLSTMOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
index bab06f55be8569400e516927020916e4334e35f1..154b0366eebcfc31a233bc8aefccbaf0a3104811 100644
--- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
+++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
@@ -99,10 +99,10 @@ void FusionRepeatedFCReluOp::InferShape(
   ctx->ShareLoD("X", /*->*/ "Out");
 }
 
-framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
+phi::KernelKey FusionRepeatedFCReluOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionRepeatedFCReluOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
index 16025bf5181b65dc94082be1d437add2c8ea756e..62eae8f7c0525c4923bdee2a622bcf3fe8191824 100644
--- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
+++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h
@@ -25,7 +25,7 @@ class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
index c9166919636bf78ad90e25d65e6ad434a63b3cdf..e9428aea006e78bb760525fbcf604de97bde4fcb 100644
--- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
@@ -88,10 +88,10 @@ void FusionSeqConvEltAddReluOp::InferShape(
   ctx->ShareLoD("X", "Out");
 }
 
-framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
+phi::KernelKey FusionSeqConvEltAddReluOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionSeqConvEltAddReluOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h
index 96f231f9a3cd56f3ea4ce61325923a2821fe3f7e..42e0c57b1133aa79c3a0a887ae1e75b34b0302af 100644
--- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h
+++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h
@@ -25,7 +25,7 @@ class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index dd5b3c0073f3cedef2ef282a088b348db17d7aa6..86eb7053f88e89a9996af6d844641294df1b638f 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -102,10 +102,10 @@ void FusionSeqExpandConcatFCOp::InferShape(
   ctx->ShareLoD("X", "Out", 0);
 }
 
-framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
+phi::KernelKey FusionSeqExpandConcatFCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionSeqExpandConcatFCOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h
index 495de5f233445043b65723504115de4976721b68..7438b6c7174873230bb5640352f03c2368ca7791 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h
@@ -25,7 +25,7 @@ class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
index f2f7801d7c2c5c029859be16c7c05fe7803ef426..9fe789e3102a2f62df45ddb9ef7bf22c2b3270d0 100644
--- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
@@ -68,10 +68,10 @@ void FusionSeqPoolConcatOp::InferShape(
   }
 }
 
-framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
+phi::KernelKey FusionSeqPoolConcatOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionSeqPoolConcatOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h
index 2e2d6e07dc7e5f36c14d3bec6ea458bd185418d4..5761330a76614b92ac4e3d1b49fbb19799b4b9ba 100644
--- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h
+++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.h
@@ -25,7 +25,7 @@ class FusionSeqPoolConcatOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc
index e3953f9e6abc040c9a80e721d39113dc247aef95..f9ee16eb8109b052ffc2a80e5fadc1491f747f27 100644
--- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc
@@ -67,10 +67,10 @@ void FusionSeqPoolCVMConcatOp::InferShape(
   ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
 }
 
-framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
+phi::KernelKey FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionSeqPoolCVMConcatOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h
index b9d7d0dfc340ec45f98f6a5686746df1db89b240..6d45ad4cb96ddd2627590fe7e6a3ef0a3120a922 100644
--- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h
+++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h
@@ -25,7 +25,7 @@ class FusionSeqPoolCVMConcatOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
index 8d7f792f3c25b19ad6966f5dd315b1f6638a0383..67fcc6527416c5cb5d6bf4eb7cc7b6f9825404cc 100644
--- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
+++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
@@ -63,10 +63,10 @@ void FusionSquaredMatSubOp::InferShape(
   ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
 }
 
-framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
+phi::KernelKey FusionSquaredMatSubOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 void FusionSquaredMatSubOpMaker::Make() {
 
diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
index fc6a54fd9eb03b0bed5376ddf9b219015ea498b7..41bde97c4bdb09bd6573ed947aaad5811dbe7f9a 100644
--- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
+++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h
@@ -26,7 +26,7 @@ class FusionSquaredMatSubOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
 };
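multi_gru, which follows, is the one op in this section that pins a non-default backend: it moves to the three-argument KernelKey(Backend, DataLayout, DataType) constructor, with phi::Backend::ONEDNN subsuming the old LibraryType::kMKLDNN tag, and the proto data type crossing into the phi enum via TransToPhiDataType. Since the hunk is split across two lines in the garbled text, here it is in one piece (reassembled from the diff, with a comment added):

phi::KernelKey MultiGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Backend::ONEDNN replaces the old LibraryType::kMKLDNN marker.
  return phi::KernelKey(
      phi::Backend::ONEDNN,
      phi::DataLayout::ONEDNN,
      phi::TransToPhiDataType(
          OperatorWithKernel::IndicateVarDataType(ctx, "X")));
}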
phi::TransToPhiDataType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"))); } void MultiGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/multi_gru_op.h b/paddle/fluid/operators/fused/multi_gru_op.h index 1846d819600f9604719a5513cff5439589882150..956fcce59c1313e0483510577d00a32992a75c7c 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.h +++ b/paddle/fluid/operators/fused/multi_gru_op.h @@ -28,7 +28,7 @@ class MultiGRUOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op.cc b/paddle/fluid/operators/fused/resnet_basic_block_op.cc index b449ca3bbe8da4f6d8f617a75dab8a504266f90b..d17e6c9872a029f41bf35830f46b3e7709c581e5 100644 --- a/paddle/fluid/operators/fused/resnet_basic_block_op.cc +++ b/paddle/fluid/operators/fused/resnet_basic_block_op.cc @@ -219,7 +219,7 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -247,10 +247,7 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Bias input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -545,21 +542,15 @@ class ResNetBasicBlockGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { PADDLE_ENFORCE_NOT_NULL( ctx.InputVar(framework::GradVarName("Y")), platform::errors::NotFound( "Can not find Y@GRAD in the execution context.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc index 4b46dc76b260e67e2bed190a6bdd05adec860c77..05aa019a5a488fee55d49e8a32d015905f808019 100644 --- a/paddle/fluid/operators/fused/resnet_unit_op.cc +++ b/paddle/fluid/operators/fused/resnet_unit_op.cc @@ -200,7 +200,7 @@ class ResNetUnitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, @@ -217,10 +217,7 @@ class ResNetUnitOp : public framework::OperatorWithKernel { ctx.Input("BiasX")->dtype()), platform::errors::InvalidArgument( "Bias input should be of float type")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - input_data_type, 
ctx.GetPlace(), layout, library); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -392,21 +389,15 @@ class ResNetUnitGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const { PADDLE_ENFORCE_NOT_NULL( ctx.InputVar(framework::GradVarName("Y")), platform::errors::NotFound( "Can not find Y@GRAD in the execution context.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 4907153a11874a25dc782ca31a65914100c2257e..4b85dee9a270be1d8a29ff869c8430170f445708 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -32,21 +32,22 @@ class GatherOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Axis") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -55,21 +56,23 @@ class GatherGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "Axis") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git 
diff --git a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
index d98721cfff300c108eff3f0688eaca07ff2c1fa0..84f5479a61dcdf3acd1705a188bec8ef9f717004 100644
--- a/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
+++ b/paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
@@ -23,9 +23,9 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
  protected:
   using BatchSizeLikeOp::BatchSizeLikeOp;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc
index 0f81d7fec31845414008935d301152370cbcb18f..03c1c4dd64c4c8d354861842c02fca6739b9d427 100644
--- a/paddle/fluid/operators/gaussian_random_op.cc
+++ b/paddle/fluid/operators/gaussian_random_op.cc
@@ -51,22 +51,24 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.device_context().GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
index b28c8bdc1a297891945626b5f4d009f7d41e77cc..207589edd591a6a5ca0cac35f2848e11ebac87b5 100644
--- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2
+++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
@@ -254,7 +254,7 @@ paddle::small_vector outputs {
 {% macro get_expected_kernel(op) %}
 {% set kernel = op["kernel"] %}
-framework::OpKernelType GetExpectedKernelType(
+phi::KernelKey GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const override {
 {%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
 {% if kernel["data_type"]["candidates"] | length == 1 %}
@@ -273,7 +273,7 @@ framework::OpKernelType GetExpectedKernelType(
 }
 {% endif %}
 {% endif %}
-  return framework::OpKernelType(data_type, ctx.GetPlace());
+  return phi::KernelKey(data_type, ctx.GetPlace());
 }
 {% endmacro %}
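The generator-template hunk above now emits phi::KernelKey(data_type, ctx.GetPlace()) where data_type is still a legacy proto::VarType value, which implies a dtype conversion shim between the two enum families; the real one is phi::TransToPhiDataType, used at the top of this section. Below is a guessed, minimal mock of such a mapping, with made-up enumerators, purely to illustrate the shape of the conversion.

// Hypothetical stand-in for the proto -> phi dtype shim; not Paddle's code.
#include <stdexcept>

enum class ProtoVarType { FP16, FP32, INT64 };
enum class PhiDataType { FLOAT16, FLOAT32, INT64 };

PhiDataType TransToPhiDataTypeSketch(ProtoVarType t) {
  switch (t) {
    case ProtoVarType::FP32: return PhiDataType::FLOAT32;
    case ProtoVarType::FP16: return PhiDataType::FLOAT16;
    case ProtoVarType::INT64: return PhiDataType::INT64;
  }
  throw std::invalid_argument("unmapped proto dtype");
}

int main() {
  // A KernelKey built from a proto dtype would convert through such a table.
  return TransToPhiDataTypeSketch(ProtoVarType::FP32) == PhiDataType::FLOAT32
             ? 0
             : 1;
}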
diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
index 658352d844d9a21267a7da946a50aaace0591e89..7df3a292e50e197063ea1d6550db6bc94ea381a5 100644
--- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
@@ -47,11 +47,10 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/graph_khop_sampler_op.cc b/paddle/fluid/operators/graph_khop_sampler_op.cc
index 4702d66c3ccb3ff1b9aab68f429f6a6a288882d0..1cb5ac3c3071cab347cebbec02f332a76850bf22 100644
--- a/paddle/fluid/operators/graph_khop_sampler_op.cc
+++ b/paddle/fluid/operators/graph_khop_sampler_op.cc
@@ -90,11 +90,10 @@ class GraphKhopSamplerOP : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/graph_reindex_op.cc b/paddle/fluid/operators/graph_reindex_op.cc
index 7bdd1708b68993f7086cc1911350e64779566e16..c24af3f16d7c6ad5bdd9af133f447e6ac9af433f 100644
--- a/paddle/fluid/operators/graph_reindex_op.cc
+++ b/paddle/fluid/operators/graph_reindex_op.cc
@@ -25,11 +25,10 @@ class GraphReindexOP : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/graph_sample_neighbors_op.cc b/paddle/fluid/operators/graph_sample_neighbors_op.cc
index 14f17f77dcb6f34f29c91af4f1a7238d0f87b8c1..0e7a1c97b7912f74fdade39de358560c70604ee2 100644
--- a/paddle/fluid/operators/graph_sample_neighbors_op.cc
+++ b/paddle/fluid/operators/graph_sample_neighbors_op.cc
@@ -25,11 +25,10 @@ class GraphSampleNeighborsOP : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc
index 9e57884c1412c7a341c4a146f5242f421f371cc5..afdbaf0ca7729e6fff508127e1f0bdd77e383311 100644
--- a/paddle/fluid/operators/graph_send_recv_op.cc
+++ b/paddle/fluid/operators/graph_send_recv_op.cc
@@ -25,11 +25,10 @@ class GraphSendRecvOP : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -43,11 +42,11 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/graph_send_ue_recv_op.cc b/paddle/fluid/operators/graph_send_ue_recv_op.cc
index 561c7e06f0b3752ab2262d15d17efc33286a20ac..2a252bcf70368cd53bf2b3f597fb281f5e52f148 100644
--- a/paddle/fluid/operators/graph_send_ue_recv_op.cc
+++ b/paddle/fluid/operators/graph_send_ue_recv_op.cc
@@ -25,11 +25,10 @@ class GraphSendUERecvOP : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };
 
@@ -45,11 +44,11 @@ class GraphSendUERecvGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc
index 7331c792ea568d6672f0afc3b8db632f06e84b37..90e15ef273456db255f7e3164fe2144c22f720c5 100644
--- a/paddle/fluid/operators/group_norm_op.cc
+++ b/paddle/fluid/operators/group_norm_op.cc
@@ -114,7 +114,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     const auto *var = ctx.InputVar(framework::GradVarName("Y"));
 
@@ -132,8 +132,8 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "Input(Y@GRAD) phi::DenseTensor of "
                           "GroupNormGradOp should not be null"));
-    return framework::OpKernelType(framework::TransToProtoVarType(t->dtype()),
-                                   ctx.GetPlace());
+    return phi::KernelKey(framework::TransToProtoVarType(t->dtype()),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc
index 1c10692d15fadd90c7a6f6273187987a8f33537a..ed7dfa03494053839e11a56867c9fab0eb333cfd 100644
--- a/paddle/fluid/operators/gru_op.cc
+++ b/paddle/fluid/operators/gru_op.cc
@@ -305,11 +305,11 @@ class GRUGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(weight_grad_name, weight_dims);
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Hidden")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Hidden")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc
index 8e05454f1aefc72a6f386257e2e96d5f21d08a1d..7bd104472fe558682d73eecdf19028362671258a 100644
--- a/paddle/fluid/operators/gru_unit_op.cc
+++ b/paddle/fluid/operators/gru_unit_op.cc
@@ -270,11 +270,11 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(weight_grad_name, weight_dims);
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Hidden")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Hidden")),
+                          ctx.device_context().GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 7255abcb7b4b63f96008d182f597c30e1724f55c..e1de4a9a4d312c693f07a264c708197afba4a5ac 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -66,10 +66,10 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -213,10 +213,10 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/identity_loss_op.cc b/paddle/fluid/operators/identity_loss_op.cc
index bc9986c7ffea16f4054a1271f9a53b1d05725752..76e7f8a733e40d760d61d4504051e4223cfb499e 100644
--- a/paddle/fluid/operators/identity_loss_op.cc
+++ b/paddle/fluid/operators/identity_loss_op.cc
@@ -27,11 +27,10 @@ class IdentityLossOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          platform::CPUPlace());
   }
 };
 
@@ -59,11 +58,11 @@ class IdentityLossGradOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, platform::CPUPlace());
+    return phi::KernelKey(input_data_type, platform::CPUPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc
index e2274d87c43f356cbc3c70b1a853a16d3e9b1363..a2fdd53e032118611571c995f29ba497640996ae 100644
--- a/paddle/fluid/operators/imag_op.cc
+++ b/paddle/fluid/operators/imag_op.cc
@@ -58,12 +58,12 @@ class ImagGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto dtype = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     auto complex_dtype = framework::ToComplexType(dtype);
-    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
+    return phi::KernelKey(complex_dtype, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/increment_op.cc b/paddle/fluid/operators/increment_op.cc
index 342ef41d4173d94e0aba13cbfee46e4e8f20a79c..5fbde4f449bc47184fd0612f167a658351ef752b 100644
--- a/paddle/fluid/operators/increment_op.cc
+++ b/paddle/fluid/operators/increment_op.cc
@@ -39,11 +39,12 @@ class IncrementOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
     // IncrementOp kernel's device type is decided by input tensor place
-    kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
+    kt.set_backend(
+        phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
     return kt;
   }
 };
 
diff --git a/paddle/fluid/operators/index_add_op.cc b/paddle/fluid/operators/index_add_op.cc
index b856e479fba5238befb744a9bb4a7a20af204a76..da3b720ae3ebeca858718ae367c1059107a71965 100644
--- a/paddle/fluid/operators/index_add_op.cc
+++ b/paddle/fluid/operators/index_add_op.cc
@@ -26,10 +26,10 @@ class IndexAddOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
 
@@ -79,11 +79,11 @@ class IndexAddGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
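The increment_op hunk above is the one place in this stretch where the key is mutated rather than rebuilt: the op first delegates to the base-class GetExpectedKernelType and then overrides only the backend via set_backend(phi::TransToPhiBackend(place)), both of which appear verbatim in the hunk. Here is a minimal sketch of that pattern, with mock types standing in for phi::KernelKey and phi::Place.

// Mock-up of the "override only the backend" pattern; not the real phi types.
#include <cassert>

enum class Backend { CPU, GPU };

struct Place { bool is_gpu; };

Backend TransToPhiBackendSketch(const Place& p) {
  return p.is_gpu ? Backend::GPU : Backend::CPU;
}

class KernelKeySketch {
 public:
  explicit KernelKeySketch(Backend b) : backend_(b) {}
  void set_backend(Backend b) { backend_ = b; }
  Backend backend() const { return backend_; }

 private:
  Backend backend_;
};

int main() {
  KernelKeySketch kt(Backend::CPU);  // what the base class decided
  Place x_place{true};               // input "X" lives on the GPU
  kt.set_backend(TransToPhiBackendSketch(x_place));
  assert(kt.backend() == Backend::GPU);  // kernel follows the input's device
  return 0;
}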
diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc
index a80324d5d303a7ad5685c18ee64d669a7f3e95c4..5acc9f1bd13c23b2fbc8755a4cf1243aec79ced3 100644
--- a/paddle/fluid/operators/inplace_abn_op.cc
+++ b/paddle/fluid/operators/inplace_abn_op.cc
@@ -30,7 +30,7 @@ class InplaceABNOp : public paddle::operators::BatchNormOp {
   using paddle::operators::BatchNormOp::BatchNormOp;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     // By default, the type of the scale, bias, mean,
@@ -61,11 +61,7 @@ class InplaceABNOp : public paddle::operators::BatchNormOp {
                       platform::errors::InvalidArgument(
                           "Variance input should be of float type"));
 
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
-
-    return framework::OpKernelType(
-        input_data_type, ctx.GetPlace(), layout, library);
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -135,7 +131,7 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     const auto* var = ctx.InputVar(framework::GradVarName("Y"));
     auto input_data_type = framework::TransToProtoVarType(
@@ -154,11 +150,8 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
       PADDLE_THROW(
           platform::errors::InvalidArgument("gradient variable of Y is empty"));
     }
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
 
-    return framework::OpKernelType(
-        input_data_type, ctx.GetPlace(), layout, library);
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc
index c9f33799c9e1049794029e09854d49c9f3f7997d..289df565b88d7d634144735526b9302ee8907104 100644
--- a/paddle/fluid/operators/instance_norm_op.cc
+++ b/paddle/fluid/operators/instance_norm_op.cc
@@ -29,7 +29,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
+phi::KernelKey InstanceNormOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
   // By default, the type of the scale, bias, mean,
@@ -54,7 +54,7 @@ framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
                           "Bias input should be of float type"));
   }
 
-  return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
 void InstanceNormOpMaker::Make() {
@@ -98,7 +98,7 @@ NCHW `[batch, in_channels, in_height, in_width]`
 )DOC");
 }
 
-framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
+phi::KernelKey InstanceNormGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar(framework::GradVarName("Y"));
   if (var == nullptr) {
@@ -115,11 +115,11 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
     PADDLE_THROW(
         platform::errors::InvalidArgument("gradient variable of Y is empty"));
   }
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
-framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
+phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar("DY");
   if (var == nullptr) {
@@ -136,8 +136,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
     PADDLE_THROW(
         platform::errors::InvalidArgument("gradient variable of Y is empty"));
   }
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                        ctx.GetPlace());
 }
 
 DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
diff --git a/paddle/fluid/operators/instance_norm_op.h b/paddle/fluid/operators/instance_norm_op.h
index 05e2bde9739247645430aafda5cfd6581355fe55..9a885e47e40a02787c26f89710e83991cd061c00 100644
--- a/paddle/fluid/operators/instance_norm_op.h
+++ b/paddle/fluid/operators/instance_norm_op.h
@@ -29,7 +29,7 @@ class InstanceNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override;
 };
 
@@ -38,7 +38,7 @@ class InstanceNormGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override;
 };
 
@@ -47,7 +47,7 @@ class InstanceNormDoubleGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override;
 };
 
diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc
index c1b2ae3ea531b8699f36bf321620de4404eed925..999e6df67c3beb70f0732fdf8f772e2ed61ad663 100644
--- a/paddle/fluid/operators/interpolate_op.cc
+++ b/paddle/fluid/operators/interpolate_op.cc
@@ -337,18 +337,18 @@ class InterpolateOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
 #ifdef PADDLE_WITH_MKLDNN
-    if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+    if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
         (tensor.layout() != phi::DataLayout::ONEDNN)) {
       auto attrs = Attrs();
       auto ar = paddle::framework::AttrReader(attrs);
@@ -357,16 +357,17 @@ class InterpolateOp : public framework::OperatorWithKernel {
       // Some models may have intentionally set "AnyLayout" for pool
       // op. Treat this as NCHW (default data_format value)
       if (dl != phi::DataLayout::kAnyLayout) {
-        return framework::OpKernelType(
-            expected_kernel_type.data_type_, tensor.place(), dl);
+        return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
       }
     }
 #endif
     if (var_name == "SizeTensor" || var_name == "Scale") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
 
@@ -589,22 +590,24 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
     }
   }
 
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "SizeTensor" || var_name == "Scale") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
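InterpolateOp::GetKernelTypeForVar above preserves the oneDNN special case: when the kernel expects the ONEDNN layout but the incoming tensor is still in a plain layout, the variable's key reports the layout named by the data_format attribute, treating kAnyLayout as NCHW, so the layout transform is planned from the right source layout. A condensed, compilable sketch of that branch follows, assuming simplified enums in place of phi::DataLayout.

// Condensed sketch of the oneDNN branch; enums and the attribute lookup are
// simplified stand-ins, not Paddle's real API.
#include <cassert>
#include <string>

enum class DataLayout { NCHW, NHWC, ONEDNN, ANY };

DataLayout LayoutFromAttr(const std::string& data_format) {
  return data_format == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW;
}

// Decide which layout the variable's key should report.
DataLayout VarLayout(DataLayout expected, DataLayout tensor_layout,
                     const std::string& data_format_attr) {
  if (expected == DataLayout::ONEDNN && tensor_layout != DataLayout::ONEDNN) {
    DataLayout dl = LayoutFromAttr(data_format_attr);
    // "AnyLayout" attrs are treated as NCHW, the default data_format.
    if (dl != DataLayout::ANY) return dl;
  }
  return tensor_layout;  // default: trust the tensor's own layout
}

int main() {
  // Entering a oneDNN region with an NHWC model: the attribute wins.
  assert(VarLayout(DataLayout::ONEDNN, DataLayout::NCHW, "NHWC") ==
         DataLayout::NHWC);
  // Outside oneDNN, the tensor's layout is passed through unchanged.
  assert(VarLayout(DataLayout::NCHW, DataLayout::NHWC, "NHWC") ==
         DataLayout::NHWC);
  return 0;
}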
"X"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -461,18 +461,19 @@ class InterpolateV2Op : public framework::OperatorWithKernel { // Some models may have intentionally set "AnyLayout" for pool // op. Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif if (var_name == "OutSize" || var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -692,23 +693,25 @@ class InterpolateV2OpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "OutSize" || var_name == "SizeTensor" || var_name == "Scale") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index f03051e2a519ea16c18507c8d8b6cc3da0ba331b..8f68ef13e4a4f5be6a852ded4f25f60631deb03c 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -47,7 +47,7 @@ class OverflowOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { int dtype = -1; auto *x_var = ctx.InputVar("X"); @@ -65,8 +65,8 @@ class OverflowOp : public framework::OperatorWithKernel { "The input type mismatch, the type of Input(X) must be Tensor or " "SelectedRows, please check your input.")); } - return 
framework::OpKernelType(framework::proto::VarType::Type(dtype), - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::Type(dtype), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 9a06fd369f8824c8a82a2886d4072403c9c474c9..e45e686dd0eeace04e9f241b69b0fb840f006490 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -24,10 +24,10 @@ class KLDivLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -104,11 +104,11 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc index 707d9a47006f20ff6eb79a4fcb4283385c88c368..6349ec65a9646ea9b13f6d262c9a7f566fe39de5 100644 --- a/paddle/fluid/operators/kron_op.cc +++ b/paddle/fluid/operators/kron_op.cc @@ -29,26 +29,23 @@ class KronOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; @@ -110,27 +107,24 @@ class KronGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto out_grad_name = framework::GradVarName("Out"); - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( 
+ phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - if (framework::IsComplexType(expected_kernel_type.data_type_)) { + const phi::KernelKey& expected_kernel_type) const override { + if (framework::IsComplexType(expected_kernel_type.dtype())) { // only promote inputs’s types when contains complex input - return framework::OpKernelType( - framework::TransToProtoVarType(tensor.dtype()), - tensor.place(), - tensor.layout()); + return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); } else { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } } }; diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 461d77f324bcf6f75ed6a1e58fc74a0f38d333cc..062e33f26610cc109395d7084ab2e167f9e943aa 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -101,7 +101,7 @@ class LayerNormOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -113,7 +113,7 @@ class LayerNormOp : public framework::OperatorWithKernel { } // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -203,7 +203,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { const auto *var = ctx.InputVar(framework::GradVarName("Y")); PADDLE_ENFORCE_NOT_NULL( @@ -218,14 +218,8 @@ class LayerNormGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL( t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found.")); - framework::LibraryType library = framework::LibraryType::kPlain; - phi::DataLayout layout = phi::DataLayout::kAnyLayout; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace(), - layout, - library); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc index fbb091a78a5058741b33a04fbdeeabdaa6b0648e..ffae23c7025379a9d17e0e4435282965b3418c03 100644 --- a/paddle/fluid/operators/limit_by_capacity_op.cc +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -35,7 +35,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // the dtype of the expert_count and capacity should be same as int64 auto expert_count_dtype = @@ -54,7 +54,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel { framework::proto::VarType::INT64, platform::errors::InvalidArgument("The dtype of the expert_count and " "capacity should be same as int64")); - return framework::OpKernelType(expert_count_dtype, ctx.GetPlace()); + return 
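The kron hunks above (and the matmul ones below) keep the complex-promotion rule: when the promoted kernel dtype is complex, GetKernelTypeForVar reports the tensor's own dtype, so the mismatch against the expected key triggers a real-to-complex cast for that input alone; otherwise the var key mirrors the expected dtype and no cast is inserted. A small sketch of the rule, with mock dtypes standing in for framework::IsComplexType and phi::DataType:

// Mock-up of the per-input complex-promotion decision; not the real phi API.
#include <cassert>

enum class DType { FLOAT32, COMPLEX64 };

bool IsComplex(DType t) { return t == DType::COMPLEX64; }

// Returns the dtype recorded in the variable's key. A cast is later inserted
// iff this differs from the kernel's expected dtype.
DType VarDType(DType expected, DType tensor_dtype) {
  if (IsComplex(expected)) {
    return tensor_dtype;  // keep the real dtype -> schedules real->complex cast
  }
  return expected;  // mirror the expected dtype -> no cast for this input
}

int main() {
  // complex64 kernel with a float32 input: the var key keeps FLOAT32, so the
  // mismatch with COMPLEX64 promotes this input to complex.
  assert(VarDType(DType::COMPLEX64, DType::FLOAT32) == DType::FLOAT32);
  // real kernel: the key mirrors the expectation and nothing is cast.
  assert(VarDType(DType::FLOAT32, DType::FLOAT32) == DType::FLOAT32);
  return 0;
}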
diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc
index fbb091a78a5058741b33a04fbdeeabdaa6b0648e..ffae23c7025379a9d17e0e4435282965b3418c03 100644
--- a/paddle/fluid/operators/limit_by_capacity_op.cc
+++ b/paddle/fluid/operators/limit_by_capacity_op.cc
@@ -35,7 +35,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // the dtype of the expert_count and capacity should be same as int64
     auto expert_count_dtype =
@@ -54,7 +54,7 @@ class LimitByCapacityOp : public framework::OperatorWithKernel {
         framework::proto::VarType::INT64,
         platform::errors::InvalidArgument("The dtype of the expert_count and "
                                           "capacity should be same as int64"));
-    return framework::OpKernelType(expert_count_dtype, ctx.GetPlace());
+    return phi::KernelKey(expert_count_dtype, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc
index 64fe6562a6c7d9c0d17a194fc9f9142af4b563e9..26f90851d566a5cde730ceda468c2ed25a43b9ab 100644
--- a/paddle/fluid/operators/linear_chain_crf_op.cc
+++ b/paddle/fluid/operators/linear_chain_crf_op.cc
@@ -298,9 +298,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of computation kernel of linear_chain_crf
   // is determined by its input "Emission".
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
         platform::CPUPlace());
   }
 
@@ -343,12 +343,11 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of output of the linear_chain_crf_grad
   // operator is determined by its input: gradients of LogLikelihood.
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(
-            ctx, framework::GradVarName("LogLikelihood")),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("LogLikelihood")),
+                          platform::CPUPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc
index d9dcfbed5967fa5b7135f270c059b784d8237530..e3fade6d6120fc43599a5febbc74f5e6905852d9 100644
--- a/paddle/fluid/operators/linspace_op.cc
+++ b/paddle/fluid/operators/linspace_op.cc
@@ -28,22 +28,24 @@ class LinspaceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (platform::is_xpu_place(tensor.place())) {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
-    return expected_kernel_type;
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
 
diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index 78c06e8c24a000933bcb6b39dc0fe4eb5d2760ec..5f03e1304b69d00a38e1558b699ece7a9080634d 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -27,11 +27,9 @@ class LoadCombineOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = framework::OpKernelType(
-        framework::proto::VarType::FP32, ctx.GetPlace());
-    return kt;
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/load_combine_op.h b/paddle/fluid/operators/load_combine_op.h
index 16e53dbead02e5db945d05c9e885418d46f38cf8..258275f403dcaf6dd2bc5f1c31537008d1d7994b 100644
--- a/paddle/fluid/operators/load_combine_op.h
+++ b/paddle/fluid/operators/load_combine_op.h
@@ -116,14 +116,15 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
       // Get data from fin to tensor
       paddle::framework::DeserializeFromStream(*buffer, tensor, dev_ctx);
 
-      auto in_dtype = framework::TransToProtoVarType(tensor->dtype());
-      auto out_dtype =
-          load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+      auto in_dtype = tensor->dtype();
+      auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
 
       if (in_dtype != out_dtype) {
         // convert to float16 tensor
-        auto in_kernel_type = framework::OpKernelType(in_dtype, place);
-        auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+        auto in_kernel_type =
+            phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype);
+        auto out_kernel_type =
+            phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype);
         phi::DenseTensor fp16_tensor;
         // copy LoD info to the new tensor
         fp16_tensor.set_lod(tensor->lod());
diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc
index 0c66dbd36568f7f844880305b91a52d67e5fbb34..434c0db2b8faa7295cd7c974368c361e84968ed8 100644
--- a/paddle/fluid/operators/load_op.cc
+++ b/paddle/fluid/operators/load_op.cc
@@ -26,11 +26,9 @@ class LoadOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {}
 
  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = framework::OpKernelType(
-        framework::proto::VarType::FP32, ctx.GetPlace());
-    return kt;
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
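In load_combine_op.h above, the load_as_fp16 path now builds two phi::KernelKeys that pin the layout to ALL_LAYOUT and differ only in dtype, so the subsequent conversion fires purely on a dtype mismatch. The following self-contained sketch mirrors that decision; the types are mocks, not the real phi definitions.

// Sketch of the load_as_fp16 conversion decision; mock types throughout.
#include <cassert>

enum class DType { FLOAT32, FLOAT16 };
enum class Layout { ALL_LAYOUT };

struct KeySketch {
  Layout layout;
  DType dtype;
};

bool NeedsCast(DType in_dtype, bool load_as_fp16) {
  DType out_dtype = load_as_fp16 ? DType::FLOAT16 : in_dtype;
  KeySketch in_key{Layout::ALL_LAYOUT, in_dtype};
  KeySketch out_key{Layout::ALL_LAYOUT, out_dtype};
  return in_key.dtype != out_key.dtype;  // only the dtype can differ here
}

int main() {
  assert(NeedsCast(DType::FLOAT32, /*load_as_fp16=*/true));    // fp32 -> fp16
  assert(!NeedsCast(DType::FLOAT16, /*load_as_fp16=*/true));   // already fp16
  assert(!NeedsCast(DType::FLOAT32, /*load_as_fp16=*/false));  // no conversion
  return 0;
}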
phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { // convert to float16 tensor - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor fp16_tensor; // copy LoD info to the new tensor fp16_tensor.set_lod(tensor->lod()); diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 374bb8920fbbd5e5304d7a04c6d9738084b30c7a..502afbf0c77f10b440b023680c3522ce8479a9e9 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -62,20 +62,19 @@ class LoDResetOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; @@ -202,11 +201,11 @@ class LoDResetGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/log_softmax_op.cc b/paddle/fluid/operators/log_softmax_op.cc index 99da0b08af75b9e1012edcdf1fa70b9e6bd34d7a..eb3ee5b7cd9be2dac0b1bd5e21a4d57cf221da98 100644 --- a/paddle/fluid/operators/log_softmax_op.cc +++ b/paddle/fluid/operators/log_softmax_op.cc @@ -29,11 +29,11 @@ class LogSoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -86,11 +86,11 @@ class LogSoftmaxGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, 
framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc index 5e5e25a56dbca16e2ffd5e45f2ded2e54191399e..171ee209ebd0ef7b4d985014f907fae442b0ed36 100644 --- a/paddle/fluid/operators/logspace_op.cc +++ b/paddle/fluid/operators/logspace_op.cc @@ -28,9 +28,9 @@ class LogspaceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/lookup_table_dequant_op.cc b/paddle/fluid/operators/lookup_table_dequant_op.cc index e0ca707ffa70d279f514bab31fed284674d9a697..09636f600a6611c261eed16dd9c23eade868bf62 100644 --- a/paddle/fluid/operators/lookup_table_dequant_op.cc +++ b/paddle/fluid/operators/lookup_table_dequant_op.cc @@ -85,10 +85,10 @@ class LookupTableDequantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 8ad3966a1d236f96cbab75d3031c20d405eab32f..6bb9f9ee19e42c4266c65f6a3db23062181fc9bc 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -67,10 +67,10 @@ class LookupTableOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; @@ -191,11 +191,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index 84f8c6cf6492a26f0495a051ec70623e501967df..3af95c484f9f21d0e655446f59adf710045b9df3 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -67,10 +67,10 @@ class LookupTableV2Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.device_context().GetPlace()); } }; @@ -135,11 +135,11 @@ class 
LookupTableV2OpGrad : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index ce31108aa54485e16f61f27b7dbbe595691488d0..5a6ed730477a1845ae7674c4de4aa0ed3082533e 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -222,18 +222,18 @@ class LRNOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("MidOut", x_dim);
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
 #ifdef PADDLE_WITH_MKLDNN
-    if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+    if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
         (tensor.layout() != phi::DataLayout::ONEDNN)) {
       auto attrs = Attrs();
       auto ar = paddle::framework::AttrReader(attrs);
@@ -242,13 +242,12 @@ class LRNOp : public framework::OperatorWithKernel {
       // Some models may have intentionally set "AnyLayout" for lrn
       // op. Treat this as NCHW (default data_format value)
       if (dl != phi::DataLayout::kAnyLayout) {
-        return framework::OpKernelType(
-            expected_kernel_type.data_type_, tensor.place(), dl);
+        return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
       }
     }
 #endif
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };

@@ -346,18 +345,18 @@ class LRNOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
 #ifdef PADDLE_WITH_MKLDNN
-    if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+    if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
         (tensor.layout() != phi::DataLayout::ONEDNN)) {
       auto attrs = Attrs();
       auto ar = paddle::framework::AttrReader(attrs);
@@ -366,13 +365,12 @@ class LRNOpGrad : public framework::OperatorWithKernel {
       // Some models may have intentionally set "AnyLayout" for lrn
       // op. Treat this as NCHW (default data_format value)
       if (dl != phi::DataLayout::kAnyLayout) {
-        return framework::OpKernelType(
-            expected_kernel_type.data_type_, tensor.place(), dl);
+        return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
       }
     }
 #endif
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };

diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc
index b7310ed475994cd5b6941f24a0dc77c6368181b0..7250cf65e488ed6c5626c5149c33081625f19d70 100644
--- a/paddle/fluid/operators/lstm_op.cc
+++ b/paddle/fluid/operators/lstm_op.cc
@@ -135,11 +135,10 @@ class LSTMOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.device_context().GetPlace());
   }
 };

@@ -304,11 +303,10 @@ class LSTMGradOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc
index dc36b3431d48912b9dfbb60f217bc0098ad6c1b3..63cf07e35b7cb9c0eb50cddb4acc450597375efc 100644
--- a/paddle/fluid/operators/lstmp_op.cc
+++ b/paddle/fluid/operators/lstmp_op.cc
@@ -143,11 +143,10 @@ class LSTMPOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          ctx.device_context().GetPlace());
   }
 };

@@ -388,11 +387,11 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"),
-        ctx.device_context());
+        ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc
index b02a2fe13a2b0ad038a9d311e363c3431b15650d..bf19e28af0adc1cd18d8c13a660a468734dec88c 100644
--- a/paddle/fluid/operators/lstsq_op.cc
+++ b/paddle/fluid/operators/lstsq_op.cc
@@ -26,7 +26,7 @@ class LstsqOp : public framework::OperatorWithKernel {

  protected:
   // The output of lstsq is always complex-valued even for real-valued inputs
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     if (dtype != framework::proto::VarType::FP32 &&
@@ -34,7 +34,7 @@ class LstsqOp : public framework::OperatorWithKernel {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "unsupported data type: %s!", dtype));
     }
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    return phi::KernelKey(dtype, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc
index 923c14f3db0f6942cc692efb65aeaa17ff355a49..5a111f278b21a26bcb5da9585233c72930989d20 100644
--- a/paddle/fluid/operators/lu_op.cc
+++ b/paddle/fluid/operators/lu_op.cc
@@ -44,10 +44,10 @@ class LUOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };

@@ -105,10 +105,10 @@ class LUGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    return phi::KernelKey(dtype, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cc b/paddle/fluid/operators/margin_cross_entropy_op.cc
index 9e9ee9c561159cc6d2eb6e8b81d1b85cd0d96b28..5688ca2fc78ff9b45c4799d713274c07f34ad5b1 100644
--- a/paddle/fluid/operators/margin_cross_entropy_op.cc
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cc
@@ -26,11 +26,11 @@ class MarginCrossEntropyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
-        ctx.device_context());
+        ctx.device_context().GetPlace());
   }
 };

@@ -96,11 +96,11 @@ class MarginCrossEntropyOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Loss")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Loss")),
+                          ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/marker_op.cc b/paddle/fluid/operators/marker_op.cc
index 3de4f4451d07e287dc39fa4f1c92e4cb31c0d504..0cd3ccd686d2c38362b110516555167c087ff37b 100644
--- a/paddle/fluid/operators/marker_op.cc
+++ b/paddle/fluid/operators/marker_op.cc
@@ -30,10 +30,9 @@ class MarkerOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.GetPlace());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
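The hunks above already exhibit the two `phi::KernelKey` construction idioms this patch relies on: `phi::KernelKey(data_type, place)` for the common case, and `phi::KernelKey(place, layout, dtype)` when a per-variable layout must be carried along. Note the argument order is reversed relative to the old `framework::OpKernelType(data_type, place, layout)`. A minimal standalone sketch of the two idioms (the types below are simplified stand-ins for illustration only, not the real phi API):

```cpp
#include <iostream>

// Simplified stand-ins for the real phi types, for illustration only.
enum class DataType { UNDEFINED, FLOAT32 };
enum class DataLayout { ALL_LAYOUT, NCHW, NHWC };
struct Place { const char* name; };

struct KernelKey {
  // (dtype, place): layout defaults to ALL_LAYOUT, mirroring the common
  // phi::KernelKey(data_type, ctx.GetPlace()) calls in this patch.
  KernelKey(DataType dtype, Place place)
      : dtype_(dtype), layout_(DataLayout::ALL_LAYOUT), place_(place) {}
  // (place, layout, dtype): used when a var-specific layout must be kept,
  // mirroring phi::KernelKey(tensor.place(), tensor.layout(), ...).
  KernelKey(Place place, DataLayout layout, DataType dtype)
      : dtype_(dtype), layout_(layout), place_(place) {}
  DataType dtype_;
  DataLayout layout_;
  Place place_;
};

int main() {
  Place gpu{"GPU:0"};
  KernelKey expected(DataType::FLOAT32, gpu);                 // dtype-first idiom
  KernelKey for_var(gpu, DataLayout::NHWC, expected.dtype_);  // place-first idiom
  std::cout << (for_var.layout_ == DataLayout::NHWC) << "\n";  // prints 1
}
```

Overloading both orders this way keeps the common call sites one-liners while leaving the layout-sensitive call sites explicit.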
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index cba18b3cdb2261df50fd3a9cf13ca44c11d09ff8..b1c623b002503b9d874b059e3bbad2e1dc67e158 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -665,39 +665,36 @@ class MatMulOp : public framework::OperatorWithKernel {
     context->ShareLoD("X", "Out");
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      const phi::KernelKey &expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.dtype())) {
       // only promote inputs's types when contains complex input
-      return framework::OpKernelType(
-          framework::TransToProtoVarType(tensor.dtype()),
-          tensor.place(),
-          tensor.layout());
+      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
     } else {
 #ifdef PADDLE_WITH_MKLDNN
       // When matmul is first oneDNN op in a chain (there was some non oneDNN op
       // previously)
       // then we also need to rotate shape NHWC -> NCWH
-      if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+      if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
           (tensor.layout() != phi::DataLayout::ONEDNN) &&
           phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
               phi::DataLayout::kNHWC) {
-        return framework::OpKernelType(expected_kernel_type.data_type_,
-                                       tensor.place(),
-                                       phi::DataLayout::kNHWC);
+        return phi::KernelKey(tensor.place(),
+                              phi::DataLayout::kNHWC,
+                              expected_kernel_type.dtype());
       }
 #endif
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };

@@ -846,11 +843,11 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
     }
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index 0a76f43175dc4f153cc7b44915554397e0ba7073..c52fc08c91d5258d3c52b44234f74b3b3474b442 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -131,38 +131,35 @@ class MatMulV2Op : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      const phi::KernelKey& expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.dtype())) {
       // only promote inputs's types when contains complex input
-      return framework::OpKernelType(
-          framework::TransToProtoVarType(tensor.dtype()),
-          tensor.place(),
-          tensor.layout());
+      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
     } else {
 #ifdef PADDLE_WITH_MKLDNN
       // When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN
       // op previously) then we also need to rotate shape NHWC -> NCWH
-      if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
+      if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
           (tensor.layout() != phi::DataLayout::ONEDNN) &&
           phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
               phi::DataLayout::kNHWC) {
-        return framework::OpKernelType(expected_kernel_type.data_type_,
-                                       tensor.place(),
-                                       phi::DataLayout::kNHWC);
+        return phi::KernelKey(tensor.place(),
+                              phi::DataLayout::kNHWC,
+                              expected_kernel_type.dtype());
       }
 #endif
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };

@@ -195,26 +192,23 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      const phi::KernelKey& expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.dtype())) {
       // only promote inputs's types when contains complex input
-      return framework::OpKernelType(
-          framework::TransToProtoVarType(tensor.dtype()),
-          tensor.place(),
-          tensor.layout());
+      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
     } else {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };
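The matmul and matmul_v2 `GetKernelTypeForVar` overrides above share one rule: when the expected kernel dtype is complex, each input keeps its own dtype, so a real-valued operand is promoted by the kernel rather than cast up front. A compilable model of just that branch (simplified stand-in types, not the real framework API):

```cpp
#include <cassert>

// Illustrative stand-ins; not the real phi/fluid types.
enum class DataType { FLOAT32, COMPLEX64 };
bool IsComplexType(DataType t) { return t == DataType::COMPLEX64; }

struct TensorMeta {
  DataType dtype;
};

// Mirrors the matmul GetKernelTypeForVar rule in this patch: when the
// expected kernel dtype is complex, an input keeps its own dtype so a
// real-valued operand is not cast before promotion; otherwise the input
// is cast to the expected dtype as usual.
DataType DtypeForVar(const TensorMeta& tensor, DataType expected) {
  return IsComplexType(expected) ? tensor.dtype : expected;
}

int main() {
  TensorMeta real_input{DataType::FLOAT32};
  assert(DtypeForVar(real_input, DataType::COMPLEX64) == DataType::FLOAT32);
  assert(DtypeForVar(real_input, DataType::FLOAT32) == DataType::FLOAT32);
}
```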
diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc
index a63d3cb86f75c209798f9fa10aec3ab023ce47ac..16ca2cf09ec0b34a65157f45258f02983f15edd4 100644
--- a/paddle/fluid/operators/matrix_rank_op.cc
+++ b/paddle/fluid/operators/matrix_rank_op.cc
@@ -84,12 +84,10 @@ class MatrixRankOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc
index 0e75629f711a90e49311e36bfbaca29a4440472e..3728fbee53478ebad60a9c4f7132203510eeb053 100644
--- a/paddle/fluid/operators/mean_iou_op.cc
+++ b/paddle/fluid/operators/mean_iou_op.cc
@@ -40,9 +40,9 @@ class MeanIoUOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"),
         ctx.GetPlace());
   }

diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc
index 7715cf87731ec6fe53baadc53c16d6dad9acc2ad..0c628a46518b5704d065b598334852ec4251756d 100644
--- a/paddle/fluid/operators/mean_op.cc
+++ b/paddle/fluid/operators/mean_op.cc
@@ -59,11 +59,11 @@ class MeanGradOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc
index 82feee0f695db5b5e18968a934593af9ac700775..06af45d48506a0e8c38b86dc37bb926437788001 100644
--- a/paddle/fluid/operators/memcpy_d2h_op.cc
+++ b/paddle/fluid/operators/memcpy_d2h_op.cc
@@ -37,20 +37,19 @@ class MemcpyD2HOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          tensor.layout(),
+                          expected_kernel_type.dtype());
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/memcpy_h2d_op.cc b/paddle/fluid/operators/memcpy_h2d_op.cc
index 1426b23dc1b6640fde39e9d0df4dbc11914084a0..8d3fc63154a67b913d1ddef60351a16a09bd638a 100644
--- a/paddle/fluid/operators/memcpy_h2d_op.cc
+++ b/paddle/fluid/operators/memcpy_h2d_op.cc
@@ -38,20 +38,19 @@ class MemcpyH2DOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          tensor.layout(),
+                          expected_kernel_type.dtype());
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/memcpy_op.cc b/paddle/fluid/operators/memcpy_op.cc
index 66cf6a00b7af43be56df3f2715ae9ed8e502e7c8..f000a1cc0dbc4bcff454a84cf49488a83a85fd6c 100644
--- a/paddle/fluid/operators/memcpy_op.cc
+++ b/paddle/fluid/operators/memcpy_op.cc
@@ -54,20 +54,19 @@ class MemcpyOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          tensor.layout(),
+                          expected_kernel_type.dtype());
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/meshgrid_op.cc b/paddle/fluid/operators/meshgrid_op.cc
index 7921e8844c112029f8823f3c75a56f7d0d4d52e5..f813b9e341908fa1d15a949337f52c5ef560fc44 100644
--- a/paddle/fluid/operators/meshgrid_op.cc
+++ b/paddle/fluid/operators/meshgrid_op.cc
@@ -30,7 +30,7 @@ class MeshgridOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
     auto input_data_type = framework::proto::VarType::Type(0);
@@ -47,7 +47,7 @@ class MeshgridOp : public framework::OperatorWithKernel {
           "All Inputs of Meshgrid OP are Empty!"));
     }

-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

@@ -94,11 +94,11 @@ class MeshgridGradOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };
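The memcpy_d2h/memcpy_h2d/memcpy_op overrides above all pin their inputs to `phi::Backend::ALL_BACKEND`, which is the new way to say "accept this variable on whatever device it already lives on"; the old code achieved the same effect by echoing `expected_kernel_type.place_`. A small standalone model of how a data-transform pass can interpret that flag (illustrative only; the real logic lives in Paddle's PrepareData):

```cpp
#include <iostream>

// Minimal stand-in for phi::Backend, for illustration only.
enum class Backend { ALL_BACKEND, CPU, GPU };

// A data-transform check in the spirit of PrepareData: a variable whose
// kernel key says ALL_BACKEND is accepted on whichever device it already
// lives, so no device copy is scheduled for it.
bool NeedsDeviceCopy(Backend var_backend, Backend expected_backend) {
  if (expected_backend == Backend::ALL_BACKEND) return false;  // skip copy
  return var_backend != expected_backend;
}

int main() {
  std::cout << NeedsDeviceCopy(Backend::GPU, Backend::ALL_BACKEND) << "\n";  // 0
  std::cout << NeedsDeviceCopy(Backend::GPU, Backend::CPU) << "\n";          // 1
}
```

This matters for memcpy ops in particular: a pre-copy device transform would defeat the whole point of an explicit H2D/D2H operator.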
diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc
index f8e57adc703c12c577bb53fa9ea874acd6b8c605..25e32b51978437472165a94ed529f5b448a4d11d 100644
--- a/paddle/fluid/operators/metrics/accuracy_op.cc
+++ b/paddle/fluid/operators/metrics/accuracy_op.cc
@@ -24,10 +24,10 @@ class AccuracyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Out"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc
index 7529523becaf8512315c7bd4297a24b7c0c70d8a..8910e61e42f0e40d7c4b28fc1d1901b787b398c7 100644
--- a/paddle/fluid/operators/metrics/auc_op.cc
+++ b/paddle/fluid/operators/metrics/auc_op.cc
@@ -26,11 +26,11 @@ class AucOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc
index 30302ceb820d60c16dc6bfd8e6d99bdf3bd12aff..0652151320d819a9a9129a88b27a0b13c384b19b 100644
--- a/paddle/fluid/operators/metrics/precision_recall_op.cc
+++ b/paddle/fluid/operators/metrics/precision_recall_op.cc
@@ -143,11 +143,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/moe_op.cc b/paddle/fluid/operators/moe_op.cc
index 6832beeaa8e008435dc5c4cf70092cfd83e503da..186ac1fc434a7f204828b88c3e25f12ed40fbd63 100644
--- a/paddle/fluid/operators/moe_op.cc
+++ b/paddle/fluid/operators/moe_op.cc
@@ -27,10 +27,10 @@ class MoeOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index 02537512c9d6a827a417955e63d2bea064323f5e..8236bdd5993e551526fc94d1b21246787cc9133c 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -25,16 +25,16 @@ limitations under the License. */

 namespace paddle {
 namespace operators {

-using framework::OpKernelType;
+using phi::KernelKey;

 class MulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

@@ -103,10 +103,10 @@ class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc
index ba263427caa877fff13ea0d6610715b5cd0c23a4..c057d7673077acb4e1298ff45629a543914f5531 100644
--- a/paddle/fluid/operators/multiplex_op.cc
+++ b/paddle/fluid/operators/multiplex_op.cc
@@ -28,11 +28,10 @@ class MultiplexOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.device_context().GetPlace());
   }
 };

@@ -104,11 +103,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.device_context().GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc
index d57c5f18bdd032d252264c547603afa109604d86..f0bc985f3ea9ec78f752b1917f86b90f2eb76850 100644
--- a/paddle/fluid/operators/nanmedian_op.cc
+++ b/paddle/fluid/operators/nanmedian_op.cc
@@ -28,10 +28,10 @@ class NanmedianOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };

@@ -99,11 +99,11 @@ class NanmedianGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc
index b80de062796a0557f20f8218be5bca23d8532cd5..286c851278117979e34adff517c3727e8371e0f4 100644
--- a/paddle/fluid/operators/nce_op.cc
+++ b/paddle/fluid/operators/nce_op.cc
@@ -113,11 +113,10 @@ class NCEOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          platform::CPUPlace());
   }
 };

@@ -279,11 +278,10 @@ class NCEOpGrad : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+                          platform::CPUPlace());
   }
 };

diff --git a/paddle/fluid/operators/nop_op.cc b/paddle/fluid/operators/nop_op.cc
index 876468f8a7eacaf931e4a76ca0f78f18a4279207..709b1f4f1f020901d659b35b69b8329391faa64a 100644
--- a/paddle/fluid/operators/nop_op.cc
+++ b/paddle/fluid/operators/nop_op.cc
@@ -25,10 +25,9 @@ class NopOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   ctx.GetPlace());
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/number_count_op.cc b/paddle/fluid/operators/number_count_op.cc
index 29f0a5bf57fd4998d2fa09ccbb3603d4588e5da8..e636bc98bfca5a695de2de345e6589d68d23618a 100644
--- a/paddle/fluid/operators/number_count_op.cc
+++ b/paddle/fluid/operators/number_count_op.cc
@@ -28,7 +28,7 @@ class NumberCountOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // the dtype of the numbers should be same as int64
     auto number_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "numbers");
@@ -37,7 +37,7 @@ class NumberCountOp : public framework::OperatorWithKernel {
                       framework::proto::VarType::INT64,
                       platform::errors::InvalidArgument(
                           "The dtype of the number_dtype should be int64"));
-    return framework::OpKernelType(number_dtype, ctx.GetPlace());
+    return phi::KernelKey(number_dtype, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc
index 0cd6cab49eb11e641551fd6f981c6c1a701af916..ffb3081ca0ba903f1a4eab6ed6755f530ad4fd35 100644
--- a/paddle/fluid/operators/one_hot_op.cc
+++ b/paddle/fluid/operators/one_hot_op.cc
@@ -56,22 +56,23 @@ class OneHotOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "depth_tensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };

diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc
index f5b55fcf0275a2cf080eae22836c85c8293ff04b..a2ef01a89e5933d7f6d1673b94291ee341def5cf 100644
--- a/paddle/fluid/operators/one_hot_v2_op.cc
+++ b/paddle/fluid/operators/one_hot_v2_op.cc
@@ -29,22 +29,23 @@ class OneHotV2Op : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "depth_tensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc
index 262aa0fc350e2ce08e2c28889ab7ec5d54d2795e..aa78843724d4ef8252bf869ea13bf66fefad333f 100644
--- a/paddle/fluid/operators/optimizers/adadelta_op.cc
+++ b/paddle/fluid/operators/optimizers/adadelta_op.cc
@@ -24,10 +24,10 @@ class AdadeltaOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };
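The one_hot/one_hot_v2 pattern above is worth calling out: `depth_tensor` only feeds an attribute, so instead of returning `expected_kernel_type` unchanged (which, with a concrete backend, could schedule a needless transform), the new code rebuilds the key with `ALL_BACKEND`. A standalone sketch of that per-variable dispatch (simplified types, illustration only):

```cpp
#include <cassert>
#include <string>

// Simplified stand-ins mirroring the three fields the real phi::KernelKey
// exposes via backend()/layout()/dtype(). Illustration only.
enum class Backend { ALL_BACKEND, CPU, GPU };
enum class Layout { ALL_LAYOUT, NCHW };
enum class DType { INT32, FLOAT32 };
struct KernelKey {
  Backend backend;
  Layout layout;
  DType dtype;
};

// Mirrors the one_hot/one_hot_v2 rule: "depth_tensor" only feeds an
// attribute, so it keeps the expected layout/dtype but is pinned to
// ALL_BACKEND (exempt from device transform); every other input keeps
// its own place and layout.
KernelKey KeyForVar(const std::string& var, const KernelKey& expected,
                    Backend var_backend, Layout var_layout) {
  if (var == "depth_tensor") {
    return {Backend::ALL_BACKEND, expected.layout, expected.dtype};
  }
  return {var_backend, var_layout, expected.dtype};
}

int main() {
  KernelKey expected{Backend::GPU, Layout::ALL_LAYOUT, DType::INT32};
  assert(KeyForVar("depth_tensor", expected, Backend::CPU, Layout::NCHW).backend ==
         Backend::ALL_BACKEND);
  assert(KeyForVar("X", expected, Backend::CPU, Layout::NCHW).backend ==
         Backend::CPU);
}
```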
diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc
index 54643a39bcd4c6109bc6d02087ad67fb6f635dab..fc260c7e99af4215fd122120dfefdcd6f6a5ba22 100644
--- a/paddle/fluid/operators/optimizers/adagrad_op.cc
+++ b/paddle/fluid/operators/optimizers/adagrad_op.cc
@@ -29,10 +29,10 @@ class AdagradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index cf447bc593103565c20f8659714e16d6e11d196b..2a7dc7f3116189f6ec69b140c3b4e0c841c08137 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -23,23 +23,25 @@ class AdamOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const {
+      const phi::KernelKey &expected_kernel_type) const {
     if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
         var_name == "SkipUpdate") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     } else {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };

diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc
index 12429933e03d33e77b3a89a1041055d82dc1b15d..51397e210a264019e9dab5d5ca775a65d43a94b7 100644
--- a/paddle/fluid/operators/optimizers/adamax_op.cc
+++ b/paddle/fluid/operators/optimizers/adamax_op.cc
@@ -24,10 +24,10 @@ class AdamaxOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };
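Taken together, the adagrad/adam/adamax hunks above define the full recipe for migrating an optimizer op. Sketched below as a hypothetical `FooOptOp` (the class name is invented for illustration and assumes Paddle's framework headers, but every call in it appears elsewhere in this patch):

```cpp
// Hypothetical optimizer op following the migration recipe in this patch.
class FooOptOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Kernel dtype and place follow the parameter being updated.
    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return phi::KernelKey(dtype, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    // Scalar state such as the beta power accumulators may legitimately
    // live on the CPU; pinning it to ALL_BACKEND skips the device copy.
    if (var_name == "Beta1Pow" || var_name == "Beta2Pow") {
      return phi::KernelKey(phi::Backend::ALL_BACKEND,
                            expected_kernel_type.layout(),
                            expected_kernel_type.dtype());
    }
    return phi::KernelKey(
        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
  }
};
```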
diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
index 6c73439c625511425634419f919629dc904bab59..8ae9f86ac4ef240f38a1cc2d17f66a004a859f8f 100644
--- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
+++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
@@ -80,10 +80,10 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dims);
     ctx->SetOutputDim("MomentOut", param_dims);
   }
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc
index 2b4b1c1a109bd2f56dbc8400e927dbfce271c414..e8b719dc6250b7e5e38530275c27911902b8a682 100644
--- a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc
@@ -35,13 +35,15 @@ class DGCMomentumOp : public MomentumOp {
     return MomentumOp::InferShape(ctx);
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "current_step" || var_name == "nranks") {
       VLOG(10) << "var_name:" << var_name << " need not to transform";
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }

     return framework::OperatorWithKernel::GetKernelTypeForVar(

diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
index e32cf3625174265b0fd89c8b224428d5d299b46c..ad99133f3562f535b5b8e6085a4e2e8b404db02e 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
@@ -24,10 +24,10 @@ class DistributedFusedLambInitOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {}

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto dtype = framework::proto::VarType::FP32;  // dtype is not important
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    return phi::KernelKey(dtype, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
index d810f8df7370a75ff83d49b278cbbf0b4aa312e0..f7b8dacfc5aa5f891ba8d4c2d9194f02a60f7fea 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -24,17 +24,19 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext *ctx) const override {}

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto dtype = framework::proto::VarType::FP32;  // dtype is not important
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    return phi::KernelKey(dtype, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return expected_kernel_type;
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc
index f5710f2e7d8ebca9f7c17ed464ac271c24e3530a..a752517ea8de0423a2e2a213521f63f5194ff6c3 100644
--- a/paddle/fluid/operators/optimizers/dpsgd_op.cc
+++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc
@@ -72,10 +72,10 @@ class DpsgdOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dims);
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc
index 22be1f5ac685a67922b046262a1e16d27244bb06..d8110b5bbbe14d4751caf59140f56018e0625d94 100644
--- a/paddle/fluid/operators/optimizers/ftrl_op.cc
+++ b/paddle/fluid/operators/optimizers/ftrl_op.cc
@@ -69,11 +69,11 @@ class FTRLOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SquaredAccumOut", param_dim);
     ctx->SetOutputDim("LinearAccumOut", param_dim);
   }
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc
index df55ffa116a2fc1a587f243806b1e4e6bdf01dc8..c6c4397332280ef96a671de96f14e371bd107695 100644
--- a/paddle/fluid/operators/optimizers/lamb_op.cc
+++ b/paddle/fluid/operators/optimizers/lamb_op.cc
@@ -29,21 +29,23 @@ class LambOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const {
+      const phi::KernelKey &expected_kernel_type) const {
     if (var_name == "Beta1Pow" || var_name == "Beta2Pow") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     } else {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };

diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
index a5c641cc70a0a629a655adaa27c4265251ccf6de..b5b15fa09ea05fc9f00b2e64116e6862966e8ad8 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -133,11 +133,11 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
   }

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cc b/paddle/fluid/operators/optimizers/merged_adam_op.cc
index 867cfe0268c510130133dcb706cce1da2e2befb8..2be0d28a1a245d71a951aabc271dae924e078d71 100644
--- a/paddle/fluid/operators/optimizers/merged_adam_op.cc
+++ b/paddle/fluid/operators/optimizers/merged_adam_op.cc
@@ -23,23 +23,25 @@ class MergedAdamOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto param_dtype =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(param_dtype, ctx.GetPlace());
+    return phi::KernelKey(param_dtype, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string& var_name,
       const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
+      const phi::KernelKey& expected_kernel_type) const override {
     if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
         var_name == "SkipUpdate") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     } else {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
     }
   }
 };

diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
index 85b2f818fe137ec05159a88af85b37d67a40f4d3..17d31e35fdec2341aa6add96e1fabc7b6aee9c50 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -25,11 +25,11 @@ class MergedMomentumOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto param_dtype =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(param_dtype, ctx.GetPlace());
+    return phi::KernelKey(param_dtype, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h
index ad1ae550745dd25a928425a0bfc9b74976c002fc..316f742a2fd360e5e0ca729b7d13aab6ea4be80b 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.h
+++ b/paddle/fluid/operators/optimizers/momentum_op.h
@@ -114,11 +114,11 @@ class MomentumOp : public framework::OperatorWithKernel {
     }
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
index d3d45ad3c6ba6562ea9c3c0e8dc8acb65ba531ed..8def9c961f757047627a34e5ef60bce725e70e5a 100644
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
@@ -31,11 +31,11 @@ class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("StepOut", dim);
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
index 598b84415f9ec919225a692503d7dbdb6d97ed0b..076f5137cab92f7cc34aed27e95e63da783219f0 100644
--- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
+++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
@@ -72,10 +72,10 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("MomentOut", param_dim);
   }
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc
index 21b145ee49d7ca0ec1ec05c3be5a3c8416bdeb0c..d7e01aa07109ea030a6995684dc131750c6cd982 100644
--- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc
+++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc
@@ -52,10 +52,10 @@ class ProximalGDOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),
+                          ctx.GetPlace());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc
index b8883f22e9256dd6329e6d3df811c35f4e2289ef..ac445d30c31afe97aa3dad23646c789ddd6962b3 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.cc
+++ b/paddle/fluid/operators/optimizers/sgd_op.cc
@@ -28,7 +28,7 @@ class SGDOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
@@ -46,21 +46,18 @@ class SGDOp : public framework::OperatorWithKernel {
     }
     // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN

-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "LearningRate") {
-      return framework::OpKernelType(
-          framework::TransToProtoVarType(tensor.dtype()),
-          tensor.place(),
-          tensor.layout());
+      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };

diff --git a/paddle/fluid/operators/optimizers/sparse_momentum_op.h b/paddle/fluid/operators/optimizers/sparse_momentum_op.h
index 9eea5c11cb074d9680f399739cebc1ed473059c4..7ea3b29cfadf15a167eefaa8ee98f75a81a71020 100644
--- a/paddle/fluid/operators/optimizers/sparse_momentum_op.h
+++ b/paddle/fluid/operators/optimizers/sparse_momentum_op.h
@@ -176,11 +176,11 @@ class SparseMomentumOp : public framework::OperatorWithKernel {
     }
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
   }
 };
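The SGD override above encodes one more per-variable exception: `LearningRate` always keeps its own dtype rather than the kernel's, presumably so that, for example, an FP32 learning rate can drive a lower-precision parameter update without a cast. A compilable model of that rule (simplified types, illustration only):

```cpp
#include <cassert>
#include <string>

// Simplified dtype stand-in, illustration only.
enum class DType { FLOAT16, FLOAT32 };

// Mirrors SGDOp::GetKernelTypeForVar above: LearningRate is never cast to
// the kernel dtype; every other input is.
DType DtypeForVar(const std::string& var, DType tensor_dtype, DType expected) {
  return var == "LearningRate" ? tensor_dtype : expected;
}

int main() {
  // An FP16 parameter update driven by an FP32 learning rate: no cast.
  assert(DtypeForVar("LearningRate", DType::FLOAT32, DType::FLOAT16) ==
         DType::FLOAT32);
  assert(DtypeForVar("Grad", DType::FLOAT32, DType::FLOAT16) == DType::FLOAT16);
}
```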
framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index f457151b707e7e9a6faaca1de8504f7d39724d3a..0bfb02bc455e9136a6f83da3ca8a8c1e7b9c586f 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -30,7 +30,7 @@ class Pad3dOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN @@ -40,32 +40,31 @@ class Pad3dOp : public framework::OperatorWithKernel { ctx.Input("X") ->mem_desc() .data.format_desc.blocking.inner_nblks == 0) { - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::StringToDataLayout(data_format)); + return phi::KernelKey(tensor.place(), + phi::StringToDataLayout(data_format), + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -183,11 +182,11 @@ class Pad3dOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 28d264ba8e41fb2aa7f85429ca890c054d54b8d7..9b08bb3fc1e1c6c5451f979e8bbdc9ddbafdf259 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -62,11 +62,10 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey 
GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Y"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context().GetPlace()); } }; @@ -210,11 +209,10 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Y"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc index 2951091508dd6d82af7ea0ccf5817e088d1f42c9..fd23f5779397009e57270afe2a66d3cab9beaf84 100644 --- a/paddle/fluid/operators/pad_op.cc +++ b/paddle/fluid/operators/pad_op.cc @@ -32,10 +32,10 @@ class PadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -107,11 +107,11 @@ class PadOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/partial_concat_op.cc b/paddle/fluid/operators/partial_concat_op.cc index 01095b6d429b4a9a04f24cd3b9e9c3fbaae48355..a8a7d82e4627a5011ef9dffdc13ed7e0294eb07c 100644 --- a/paddle/fluid/operators/partial_concat_op.cc +++ b/paddle/fluid/operators/partial_concat_op.cc @@ -89,7 +89,7 @@ class PartialConcatOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -105,7 +105,7 @@ class PartialConcatOp : public framework::OperatorWithKernel { 1, platform::errors::InvalidArgument( "All Inputs of PartialSum OP are Empty!")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -138,11 +138,11 @@ class PartialConcatGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/partial_sum_op.cc b/paddle/fluid/operators/partial_sum_op.cc index 
6473f8d603789b5466a3fe1e133fbb054db7deba..a2255d8e07abf886158615916925044e875f52ee 100644 --- a/paddle/fluid/operators/partial_sum_op.cc +++ b/paddle/fluid/operators/partial_sum_op.cc @@ -91,7 +91,7 @@ class PartialSumOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto inputs = ctx.MultiInput("X"); auto input_data_type = framework::proto::VarType::Type(0); @@ -108,7 +108,7 @@ class PartialSumOp : public framework::OperatorWithKernel { 1, platform::errors::InvalidArgument( "All Inputs of PartialSum OP are Empty!")); - return framework::OpKernelType(input_data_type, platform::CPUPlace()); + return phi::KernelKey(input_data_type, platform::CPUPlace()); } }; @@ -141,11 +141,11 @@ class PartialSumGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index c160dc28bfda4c47f219c12ade32b93031b799b1..b03f2954d2c2053fc8e5e52feb2c1315d51a1999 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -42,7 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) { (src_tz[src_tz.size() - 2] % ksize[0] == 0)); } -framework::OpKernelType PoolOp::GetExpectedKernelType( +phi::KernelKey PoolOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -50,15 +50,15 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } -framework::OpKernelType PoolOp::GetKernelTypeForVar( +phi::KernelKey PoolOp::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); @@ -67,16 +67,15 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar( // Some models may have intentionally set "AnyLayout" for pool // op. 
Treat this as NCHW (default data_format value) if (dl != phi::DataLayout::kAnyLayout) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), dl); + return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype()); } } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } -framework::OpKernelType PoolOpGrad::GetExpectedKernelType( +phi::KernelKey PoolOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -84,26 +83,26 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } -framework::OpKernelType PoolOpGrad::GetKernelTypeForVar( +phi::KernelKey PoolOpGrad::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const { + const phi::KernelKey& expected_kernel_type) const { #ifdef PADDLE_WITH_MKLDNN - if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && + if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN)) { auto attrs = Attrs(); auto ar = paddle::framework::AttrReader(attrs); const std::string data_format = ar.Get("data_format"); - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::StringToDataLayout(data_format)); + return phi::KernelKey(tensor.place(), + phi::StringToDataLayout(data_format), + expected_kernel_type.dtype()); } #endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } void Pool2dOpMaker::Make() { diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index 9bb7572c103aeebf595781f5334ca8153ae6d4cd..a935c6b14fd5abd9f44823eb01b2d36cc4d57840 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -24,13 +24,13 @@ class PoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) const override; }; class PoolOpGrad : public framework::OperatorWithKernel { @@ -38,13 +38,13 @@ class PoolOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override; + const phi::KernelKey& expected_kernel_type) 
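// NOTE: for operators such as PoolOp and PoolOpGrad, whose hooks are
// declared here in pool_op.h and defined in pool_op.cc, both files must
// switch to phi::KernelKey in the same change or the definitions stop
// matching their `override` declarations. A minimal sketch of the paired
// form; MyOp is illustrative and not part of this patch:
//
//   // my_op.h
//   phi::KernelKey GetExpectedKernelType(
//       const framework::ExecutionContext& ctx) const override;
//
//   // my_op.cc
//   phi::KernelKey MyOp::GetExpectedKernelType(
//       const framework::ExecutionContext& ctx) const {
//     return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
//                           ctx.GetPlace());
//   }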
const override; }; class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 57aef714a05022bc8a30ff712b3d53327f6e5bf1..74b98069bf647fa153978790cf8934fc5968074c 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -36,11 +36,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context().GetPlace()); } }; @@ -49,11 +48,11 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index dc8a088ad2f5c456f367a8831b7066f45989d9cc..3f4d8125671e4b5d1a4352af82dd894db9438411 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -167,11 +167,10 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Score"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Score"), + ctx.device_context().GetPlace()); } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 8a2199e0231bf8d88895868e948717ef8b915516..5100b4f86989eeca60dea4ac02f57acd1807e93d 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -30,11 +30,11 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -93,11 +93,11 @@ class PReluGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc index ca291187b9cddb16ef6ed6b52ce3f83846ec6ddb..d1c455331b4e780ca7651c8e92deffb49e3ffc55 100644 --- a/paddle/fluid/operators/prroi_pool_op.cc +++ b/paddle/fluid/operators/prroi_pool_op.cc @@ -135,11 +135,10 @@ class PRROIPoolOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -161,11 +160,10 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/prune_gate_by_capacity_op.cc b/paddle/fluid/operators/prune_gate_by_capacity_op.cc index 14494f426d2d072b2d4f1f8ad95df14b11d5ff7f..388b65f3dd67436e10a744ce7a2ff3a76e2059b8 100644 --- a/paddle/fluid/operators/prune_gate_by_capacity_op.cc +++ b/paddle/fluid/operators/prune_gate_by_capacity_op.cc @@ -66,7 +66,7 @@ class PruneGateByCapacityOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto gate_idx_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "GateIdx"); @@ -82,7 +82,7 @@ class PruneGateByCapacityOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "The dtype of the gate_idx and expert_count should " "be same as int64")); - return framework::OpKernelType(gate_idx_data_type, ctx.device_context()); + return phi::KernelKey(gate_idx_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc b/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc index 046269a396ee0c1aadeda1621a7d57fa6194a497..e080f96e88ae703459bd038fef547fc44ac601d7 100644 --- a/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc +++ b/paddle/fluid/operators/pscore/distributed_lookup_table_op.cc @@ -78,9 +78,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc index 97391bc0e8b9de1416f19940a33129767bd74d93..1950991b7b563614de03e5d3c888c2c27e8cd4f2 100644 --- a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc +++ b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc @@ -51,9 +51,9 @@ class DistributedPushSparseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return 
framework::OpKernelType( + return phi::KernelKey( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/pscore/send_and_recv_op.cc b/paddle/fluid/operators/pscore/send_and_recv_op.cc index d3f1d17e7a3f18d3352d0ab1786fb196ee5c5e01..d252621116302c7ad543c25e4a7cebc9c5b53275 100644 --- a/paddle/fluid/operators/pscore/send_and_recv_op.cc +++ b/paddle/fluid/operators/pscore/send_and_recv_op.cc @@ -60,10 +60,10 @@ class SendAndRecvOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return phi::KernelKey(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc index 1222f97c091688d6d4f2f7404fd5a0c39aca7d5b..a8534179237e87b5e8fe863d01cda835edf1b065 100644 --- a/paddle/fluid/operators/psroi_pool_op.cc +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -83,11 +83,10 @@ class PSROIPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -96,11 +95,10 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pull_box_extended_sparse_op.cc b/paddle/fluid/operators/pull_box_extended_sparse_op.cc index 36ebc2ef67ed36dff9084a976aa44658a93766cc..7b949fa4338c72c1379fcd71866ff23e41779e9e 100644 --- a/paddle/fluid/operators/pull_box_extended_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_extended_sparse_op.cc @@ -72,10 +72,9 @@ class PullBoxExtendedSparseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -131,11 +130,11 @@ class PushBoxExtendedSparseOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, 
framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc index 14d8bacfa93e4a0606dabab0e46197b9fc418130..c58a176d5263558fc422ab2d08909930d5e1ca13 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -56,10 +56,9 @@ class PullBoxSparseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -119,11 +118,11 @@ class PushBoxSparseOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/pull_gpups_sparse_op.cc b/paddle/fluid/operators/pull_gpups_sparse_op.cc index 052c5d3c8b0e6be625e9a4f14b22cc578f7e6589..821cfdab6f10c17bf70fb24fa329ebf9d138d07d 100644 --- a/paddle/fluid/operators/pull_gpups_sparse_op.cc +++ b/paddle/fluid/operators/pull_gpups_sparse_op.cc @@ -64,10 +64,9 @@ class PullGpuPSSparseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -129,11 +128,11 @@ class PushGpuPSSparseOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/pull_sparse_op.cc b/paddle/fluid/operators/pull_sparse_op.cc index 5023a620af27e8dde4eb8e50eddd004bee6b2ffa..7dc9ae98e0e41cfd94e39bdfb19997ef016cc785 100644 --- a/paddle/fluid/operators/pull_sparse_op.cc +++ b/paddle/fluid/operators/pull_sparse_op.cc @@ -58,10 +58,9 @@ class PullSparseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -127,11 +126,11 @@ class PushSparseOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} 
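// NOTE: phi::KernelKey is keyed on a place rather than on a device context,
// so every old `ctx.device_context()` argument becomes a place. The two
// spellings used in this patch are equivalent:
//
//   return phi::KernelKey(data_type, ctx.device_context().GetPlace());
//   return phi::KernelKey(data_type, ctx.GetPlace());
//
// Most hunks here use the shorter second form.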
protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/pull_sparse_v2_op.cc b/paddle/fluid/operators/pull_sparse_v2_op.cc index c0c7c4e036fd0a35ceb252063f8a1c7e34de1506..88a0ac86c2532dfaaa340dbddb9b2ec41eebc640 100644 --- a/paddle/fluid/operators/pull_sparse_v2_op.cc +++ b/paddle/fluid/operators/pull_sparse_v2_op.cc @@ -51,10 +51,9 @@ class PullSparseV2Op : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -119,11 +118,11 @@ class PushSparseV2Op : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/push_dense_op.cc b/paddle/fluid/operators/push_dense_op.cc index 7ab49f2c2f2aad99f6d0ced1572ca02b93538b8b..e13d7574809205c00ac099eb54eed555fc969659 100644 --- a/paddle/fluid/operators/push_dense_op.cc +++ b/paddle/fluid/operators/push_dense_op.cc @@ -30,10 +30,9 @@ class PushDenseOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc index a24b234a05da7db40980dc6a011d538603e83e7c..d445dca2501e77f64119bbed15ebc6c356992f7a 100644 --- a/paddle/fluid/operators/pyramid_hash_op.cc +++ b/paddle/fluid/operators/pyramid_hash_op.cc @@ -225,10 +225,10 @@ class PyramidHashOP : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"), + ctx.GetPlace()); } }; @@ -465,10 +465,10 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "W"), ctx.GetPlace()); + return 
phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 7f9d472cb5284fed9f1d09086da8020af0fb45d7..f143bc3a5022d863a70e3c15950a2160e546ad1d 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -124,10 +124,10 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index c98e15fcff9a6e947d6b6700f817ef5ec5eb887c..83be35f998e945fcc6391ab975c894ea3b7a6b8e 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -19,13 +19,13 @@ namespace paddle { namespace operators { -framework::OpKernelType QuantOp::GetExpectedKernelType( +phi::KernelKey QuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.GetPlace(), + return phi::KernelKey( + phi::Backend::ONEDNN, phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + phi::TransToPhiDataType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"))); } void QuantOpMaker::Make() { diff --git a/paddle/fluid/operators/quantize_op.h b/paddle/fluid/operators/quantize_op.h index 46a0469c806e151edf9c80197c7ca8d3beff10eb..3426af2b3619d5e2d806b0b5c8f2669d89c95894 100644 --- a/paddle/fluid/operators/quantize_op.h +++ b/paddle/fluid/operators/quantize_op.h @@ -22,7 +22,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::OpKernelType; +using phi::KernelKey; class QuantOp : public framework::OperatorWithKernel { public: @@ -34,7 +34,7 @@ class QuantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index 9752f21b7eca7d68df1599b062d199978b6aec26..810680ea5d4233ae8f9c63fca600e4cb6574fe00 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -86,9 +86,9 @@ class RandintOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index 6736cb4c87c07f741af577597adf8f060b9948c2..11ba62197d7273d9e7b824a03c49283ace76a3ac 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -54,11 +54,10 @@ class RandomCropOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", phi::make_ddim(out_dim)); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/random_routing_op.cc b/paddle/fluid/operators/random_routing_op.cc index c20e1248073bace90dcbc965e6b4d1d58e46b047..320f5cd1cf236e5f38fa20fa61296fa53bf8d1e8 100644 --- a/paddle/fluid/operators/random_routing_op.cc +++ b/paddle/fluid/operators/random_routing_op.cc @@ -55,7 +55,7 @@ class RandomRoutingOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // the dtype of the gate_idx should be same as int64 const auto topk_idx_dtype = @@ -67,7 +67,7 @@ class RandomRoutingOp : public framework::OperatorWithKernel { const auto& topk_value_type = OperatorWithKernel::IndicateVarDataType(ctx, "TopK_Value"); - return framework::OpKernelType(topk_value_type, ctx.GetPlace()); + return phi::KernelKey(topk_value_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc index 78366efc53bf90f3232b63cfc8c01437a6f204c1..187b227f331707142b37d4d6b2c3b3e8c5c10c05 100644 --- a/paddle/fluid/operators/randperm_op.cc +++ b/paddle/fluid/operators/randperm_op.cc @@ -44,11 +44,11 @@ class RandpermOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = static_cast(ctx.Attr("dtype")); - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc index 
8a965034ac45acb9ebc63d4c41ca1f99a55bac46..08706bc70521dec7f5a98dc1e3ec1a2677669e0c 100644 --- a/paddle/fluid/operators/range_op.cc +++ b/paddle/fluid/operators/range_op.cc @@ -29,15 +29,17 @@ class RangeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (platform::is_xpu_place(tensor.place())) { - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/rank_attention_op.cc b/paddle/fluid/operators/rank_attention_op.cc index 80bd022aff340f4bfd26641614536b9cfd399852..afc3388f429d261d92fd9a873ef21c0ff1ae282a 100644 --- a/paddle/fluid/operators/rank_attention_op.cc +++ b/paddle/fluid/operators/rank_attention_op.cc @@ -79,11 +79,10 @@ class RankAttentionOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -118,11 +117,11 @@ class RankAttentionGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/read_file_op.cc b/paddle/fluid/operators/read_file_op.cc index 602f98dadbe14bc28034729cebd9ab29aa35f0c1..9b42a895a927ce62e4a51381dbe760dcaab1aafb 100644 --- a/paddle/fluid/operators/read_file_op.cc +++ b/paddle/fluid/operators/read_file_op.cc @@ -61,10 +61,10 @@ class ReadFileOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::UINT8, - platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::UINT8, + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc index 617c47530c9e310f8808cdb74712e0f3f7cf099d..94cdc2d658959287ca6aa52e63fa591050695ab3 100644 --- a/paddle/fluid/operators/real_op.cc +++ b/paddle/fluid/operators/real_op.cc @@ -58,12 +58,12 @@ class RealGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); auto 
complex_dtype = framework::ToComplexType(dtype); - return framework::OpKernelType(complex_dtype, ctx.GetPlace()); + return phi::KernelKey(complex_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 0cc7bf2898f8615035b1e83c4b223c41e98ad39a..ecf8119ed2a1917e0a291359f781cf4a67378d26 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -455,12 +455,12 @@ class ReduceGradKernel : public framework::OpKernel { phi::DenseTensor tmp_tensor; auto* pre_input = context.Input(framework::GradVarName("Out")); - auto in_kernel_type = framework::OpKernelType( - framework::TransToProtoVarType(pre_input->dtype()), - context.GetPlace()); - auto out_kernel_type = framework::OpKernelType( - static_cast(in_dtype), - context.GetPlace()); + auto in_kernel_type = + phi::KernelKey(framework::TransToProtoVarType(pre_input->dtype()), + context.GetPlace()); + auto out_kernel_type = + phi::KernelKey(static_cast(in_dtype), + context.GetPlace()); framework::TransDataType( in_kernel_type, out_kernel_type, *pre_input, &tmp_tensor); ComputeFromInput(&tmp_tensor, context); @@ -584,7 +584,7 @@ class ReduceOp : public framework::OperatorWithKernel { return true; } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -606,7 +606,7 @@ class ReduceOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "float16 can only be used on GPU or NPU or MLU place")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -615,10 +615,11 @@ class ReduceOpUseInputPlace : public ReduceOp { using ReduceOp::ReduceOp; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); - kt.place_ = ctx.Input("X")->place(); + phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx); + kt.set_backend( + phi::TransToPhiBackend(ctx.Input("X")->place())); return kt; } }; @@ -663,7 +664,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { int out_dtype = ctx.Attr("out_dtype"); auto input_data_type = @@ -679,7 +680,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { } // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 7f5a174952e81b942175008d58f050689aeeb2f3..cd695511d31f22994b344a96c149b012476e7a14 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -49,18 +49,17 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
framework::ExecutionContext& ctx) const { int in_dtype = ctx.Attr("out_dtype"); if (in_dtype >= 0) { - return framework::OpKernelType( + return phi::KernelKey( static_cast(in_dtype), ctx.GetPlace()); } - return framework::OpKernelType( - framework::OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(framework::OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h index 7b1b6bc831f0e4aa0cb00f0907deaf04b83cfa89..38d526778e1c79db2720c5001d778bc61afe768e 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h @@ -83,12 +83,10 @@ class ReduceSumGradKernel : public framework::OpKernel { phi::DenseTensor tmp_tensor; auto* pre_input = context.Input(framework::GradVarName("Out")); - auto in_kernel_type = framework::OpKernelType( - framework::TransToProtoVarType(pre_input->dtype()), - context.GetPlace()); - auto out_kernel_type = framework::OpKernelType( - static_cast(in_dtype), - context.GetPlace()); + auto in_kernel_type = phi::KernelKey(context.GetPlace(), + phi::DataLayout::ALL_LAYOUT, + pre_input->dtype()); + auto out_kernel_type = phi::KernelKey(in_dtype, context.GetPlace()); framework::TransDataType( in_kernel_type, out_kernel_type, *pre_input, &tmp_tensor); ComputeFromInput(&tmp_tensor, context); diff --git a/paddle/fluid/operators/repeat_interleave_op.cc b/paddle/fluid/operators/repeat_interleave_op.cc index aaef332bd0007abfef2506f6dab1d3b76dfd7644..44d022f4d5fbcec3d4a6cfe1a6a57b342a82b00b 100644 --- a/paddle/fluid/operators/repeat_interleave_op.cc +++ b/paddle/fluid/operators/repeat_interleave_op.cc @@ -86,10 +86,10 @@ class RepeatInterleaveOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -111,11 +111,11 @@ class RepeatInterleaveGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index d0cc991e959c70e1da89e995d444f727063af530..354a5d820ee28ea76bec16430b1d3f4cd05ea3cb 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -19,13 +19,13 @@ namespace paddle { namespace operators { -framework::OpKernelType ReQuantOp::GetExpectedKernelType( +phi::KernelKey ReQuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.GetPlace(), + return phi::KernelKey( + phi::Backend::ONEDNN, phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + phi::TransToPhiDataType( + 
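// NOTE: framework::LibraryType has no slot in phi::KernelKey; the oneDNN
// library choice is expressed through the backend instead. The quantize and
// requantize keys around this point therefore drop LibraryType::kMKLDNN and
// are built as:
//
//   return phi::KernelKey(phi::Backend::ONEDNN,
//                         phi::DataLayout::ONEDNN,
//                         phi::TransToPhiDataType(input_data_type));
//
// The TransToPhiDataType call matters: the backend-first constructor takes
// a phi::DataType, while the (dtype, place) convenience form takes the old
// proto::VarType::Type directly.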
OperatorWithKernel::IndicateVarDataType(ctx, "Input"))); } void ReQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/requantize_op.h b/paddle/fluid/operators/requantize_op.h index 5b2f0148f152992b808a75a907150c81e6bf3694..a53ea52394814e3398af0146f087f61bc5589443 100644 --- a/paddle/fluid/operators/requantize_op.h +++ b/paddle/fluid/operators/requantize_op.h @@ -22,7 +22,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::OpKernelType; +using phi::KernelKey; class ReQuantOp : public framework::OperatorWithKernel { public: @@ -34,7 +34,7 @@ class ReQuantOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override; }; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index e980aa66e7ca33467cfe216fbf04e3b5649d9c15..b4191fb46cf3f73ae1334c5e3396e3a3232682da 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -246,22 +246,24 @@ class ReshapeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "ShapeTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -359,12 +361,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -602,22 +603,24 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "ShapeTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + 
expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -630,22 +633,23 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "DDX"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "DDX"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "ShapeTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc index 93877aa8251cb781830e0429fb41ec57ce3407d6..07c3aac52078a570480a76de93ea711fb3637efb 100644 --- a/paddle/fluid/operators/reverse_op.cc +++ b/paddle/fluid/operators/reverse_op.cc @@ -28,12 +28,12 @@ class ReverseOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc index 3528cc957faf9cd5c6e7957e3ed1bc39e5aff70e..2f75d5aaf2244b8ea565b095b65c9e7f860843be 100644 --- a/paddle/fluid/operators/rnn_op.cc +++ b/paddle/fluid/operators/rnn_op.cc @@ -30,11 +30,10 @@ class RNNOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } }; @@ -116,11 +115,11 @@ class RNNGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index 4407fbf1a8c969a935ae60e374c3d6814d6039ff..1a08a01542b0c48e23d5b3222bc507039384afde 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -25,11 +25,10 @@ class ROIAlignOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -51,11 +50,10 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index e79975e6254eb5cf9dade98e0cba0729b2180b84..dadbd1115b4774f8edd6e28aee2fad11cc66a246 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -28,11 +28,10 @@ class ROIPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -53,11 +52,10 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index 823eb03aff6ce12b298ca9a40588a3e5c2c13d3c..53f6969695e8e0c1522dab863afc55f3b66b8c1d 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -27,10 +27,10 @@ class RReluOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc index eb4f1b88c6370d2809bfe1dbfbb03457da45e2bf..88d51eabaf94b49fa52f2c5cd70689fa5a5e8d8d 100644 --- a/paddle/fluid/operators/run_program_op.cc +++ b/paddle/fluid/operators/run_program_op.cc @@ -47,17 +47,18 @@ class RunProgramOp : public 
framework::OperatorWithKernel { * * Of course, the data type here is also not important. */ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; @@ -173,17 +174,18 @@ class RunProgramGradOp : public framework::OperatorWithKernel { protected: /* see [Why use single type kernel] */ - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc index ee9abf6f35400fe5571a455e6af682be41479f27..db9944ffb110dbe7cb7246335ae5d90fbf6f3a51 100644 --- a/paddle/fluid/operators/sample_logits_op.cc +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -177,12 +177,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits"); - framework::OpKernelType kt = - framework::OpKernelType(data_type, ctx.device_context()); - return kt; + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -234,13 +232,11 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("SampledLogits")); - framework::OpKernelType kt = - framework::OpKernelType(data_type, ctx.device_context()); - return kt; + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 6d4e844d03ed8be2c9b1e69c58cc8ce245559745..0263180a45a0d93a818cf7377c62e599729f5209 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -30,19 +30,19 @@ class SaveCombineOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const 
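// NOTE: GetKernelTypeForVar overrides that used to end with
// `return expected_kernel_type;` no longer forward the key untouched; they
// rebuild it with phi::Backend::ALL_BACKEND while keeping the expected
// layout and dtype, so that data preparation does not force the variable
// onto the expected kernel's device:
//
//   return phi::KernelKey(phi::Backend::ALL_BACKEND,
//                         expected_kernel_type.layout(),
//                         expected_kernel_type.dtype());
//
// range_op and reshape's ShapeTensor input earlier in this patch, and both
// run_program ops here, follow this pattern.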
framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } // TODO(lujun): The override here is just to bypass transform // in operator impl, which is not elegant enough. - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(tensor.place(), + phi::DataLayout::ALL_LAYOUT, + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/save_combine_op.h b/paddle/fluid/operators/save_combine_op.h index bf5e2a5e4d90f2c87528d43233c6d471f9fb0800..10acb286ee49ded8f971cb49e2231bd1b9a08822 100644 --- a/paddle/fluid/operators/save_combine_op.h +++ b/paddle/fluid/operators/save_combine_op.h @@ -99,12 +99,14 @@ void SaveCombineTensorKernel(const Context& dev_ctx, "The Tensor with Index (%d) to be saved is not initialized.", i)); // Serialize tensors one by one // Check types to see if a fp16 transformation is required - auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); - auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor.dtype(); + auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { auto place = dev_ctx.GetPlace(); - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor out; framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); // copy LoD info to the new tensor diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 179a18ba8d7ceed6c33e5b3da66486a58f01bce6..3af82952f4e297c8ae89ca55491431c691106ce1 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -30,10 +30,10 @@ class SaveOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/save_op.h b/paddle/fluid/operators/save_op.h index 7b78ac1ecea876c645a52018f4420067888e87c4..e33fc68f39faffa4f678c54c44fc030eddcba1f2 100644 --- a/paddle/fluid/operators/save_op.h +++ b/paddle/fluid/operators/save_op.h @@ -90,12 +90,14 @@ class SaveOpKernel : public framework::OpKernel { "Cannot open %s to save variables.", filename)); auto save_as_fp16 = ctx.Attr("save_as_fp16"); - auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); - auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + auto in_dtype = tensor.dtype(); + auto out_dtype = save_as_fp16 ? 
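// NOTE: framework::TransDataType now takes phi::KernelKey for its source
// and destination descriptors, so a dtype-only cast is expressed as two
// keys that differ only in dtype. A sketch following the save kernels in
// this patch; tensor and place are illustrative locals:
//
//   auto in_key = phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT,
//                                tensor.dtype());
//   auto out_key = phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT,
//                                 phi::DataType::FLOAT16);
//   phi::DenseTensor out;
//   framework::TransDataType(in_key, out_key, tensor, &out);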
phi::DataType::FLOAT16 : in_dtype; if (in_dtype != out_dtype) { - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); phi::DenseTensor out; framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); // copy LoD info to the new tensor diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 7416269e33dd2e6a43c67757cc6eff3153288010..2cfd09698652da1faa58409f7f1149312f890269 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -27,11 +27,11 @@ class ScaleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc index 93d57aedd8aff0f9b955228b75c7b0d3160b7e04..f6d1974968967710b16fa3d6635f2c0d58de32ba 100644 --- a/paddle/fluid/operators/seed_op.cc +++ b/paddle/fluid/operators/seed_op.cc @@ -26,10 +26,9 @@ class SeedOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::INT32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::INT32, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/segment_pool_op.cc b/paddle/fluid/operators/segment_pool_op.cc index 2cdc5746614bbe1492982f142c6ab2fc1e50bfba..c2199b70365b1ea48251f2222e3ae42da37bd35a 100644 --- a/paddle/fluid/operators/segment_pool_op.cc +++ b/paddle/fluid/operators/segment_pool_op.cc @@ -28,11 +28,10 @@ class SegmentPoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -113,11 +112,11 @@ class SegmentPoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index 117fc4ebe0c36030cd1836fbb739a2eb6cf1e708..63aef4a6282a4341a35b673e867e188075c85f1f 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ 
b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -121,11 +121,11 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc index b1223618eea0d528b5a55e7dc16a0ad8640b2f0f..a9e0b21b7bb1958a4418db3dc7962a9c1959d222 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc @@ -83,10 +83,10 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { ctx->ShareLoD("Y", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -166,11 +166,11 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc index 67573b543db5e640fcf2bf3f88c8ca7b1f29dd01..4a3100b14b4d2eabc585edbf42638c1236e6ce98 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc @@ -128,10 +128,10 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ "Out"); } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -238,11 +238,11 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index c38077986109942c24c568bd3bf3772186e093da..940d5caaaa73a68e6ce32568c5072d87ef571fc8 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc +++ 
b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -39,21 +39,22 @@ class SequenceMaskOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "depth_tensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc index 6957920131ceac06bd22680dab636eff08e32778..12c82250157282dd2bb58307318b4ca4f1c3ee68 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc @@ -134,10 +134,10 @@ class SequencePadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -246,11 +246,11 @@ class SequencePadGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 778b2f885494538a78a9a5273babd8340c6d881c..938b23a22a63c93427c27c1d4b99eca91cac376c 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -154,11 +154,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index 17961181fb36b4cc9918a57d679b8054726671e6..a626d487b532c2e88229a46b62ec72fbb07e448f 100644 --- 
a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -126,11 +126,10 @@ class SequenceScatterOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -146,11 +145,11 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index 9375cea85c78f5a58bb32ad1600ee003a01479d4..b7e2ff766f04819d3776682c42995f956f0084dc 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -56,11 +56,10 @@ class SequenceSliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -81,11 +80,11 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 80f13a51ab0b1249bbbca9e39ace4c219aee65a6..4089cdb9fcb3f886fc34618023768d14cc18a1e7 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -36,14 +36,15 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); phi::DataLayout layout_ = DataLayout::kAnyLayout; if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; @@ -120,14 +121,15 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } protected: - 
framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); phi::DataLayout layout_ = DataLayout::kAnyLayout; if (ctx.HasAttr("data_format")) { layout_ = phi::StringToDataLayout(ctx.Attr("data_format")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc index b19dfe40ed95e1dff22c33f5fb20cec240d63535..c57cd949512f6bcca7713d6c45765debae4d8ea1 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc @@ -100,10 +100,10 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc index fe91dd00d4f860cd507806a49b8c004b38e59048..bddad088fe375c870cca63dabed797d2ba0aa6d2 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc @@ -86,10 +86,10 @@ class SequenceUnpadOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; @@ -156,11 +156,11 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index d635feee58b586e4035d89ad5f1a78bd2f3c9819..19ce77b6b4d111e63c924c3c118ffa4870376e28 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -45,22 +45,24 @@ class SetValue : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor 
&tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StepsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -214,23 +216,25 @@ class SetValueGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto in_tensor = ctx.Input(framework::GradVarName("Out")); - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - in_tensor->place()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + in_tensor->place()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StepsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index 6849b4e42721e222b37f4894eee15325c73c7ebe..24d2f1104db5a949eb2a24f2d23ae562ab1fb21e 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -26,21 +26,21 @@ class ShapeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } protected: - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey &expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/shuffle_batch_op.cc b/paddle/fluid/operators/shuffle_batch_op.cc index 6eeec761120b04f46d0d2d0322e8f948506c4e54..d34c102d0ed544e0a1749f5dd72d5f3a6e8b4c3a 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cc +++ b/paddle/fluid/operators/shuffle_batch_op.cc @@ -55,18 +55,20 @@ class ShuffleBatchOp : public framework::OperatorWithKernel { } 
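// [Editor's note, illustrative only — not part of the patch] The hunks in this file repeat the two mechanical rewrites used throughout this diff. A minimal sketch of the mapping, assuming the phi::KernelKey overloads the new code relies on:
//
//   // Before: legacy proto dtype + device context.
//   return framework::OpKernelType(data_type, ctx.device_context());
//   // After, overload taking the legacy proto dtype plus a place:
//   return phi::KernelKey(data_type, ctx.GetPlace());
//   // After, fully phi-native overload (place, layout, phi::DataType):
//   return phi::KernelKey(ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT,
//                         phi::TransToPhiDataType(data_type));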
protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "Seed") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } return framework::OperatorWithKernel::GetKernelTypeForVar( var_name, tensor, expected_kernel_type); @@ -123,11 +125,11 @@ class ShuffleBatchOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index 7e98514cde37097845a327bd99e6e1d9c6dc3f76..b72d3557b65f0a677776e1829382430a52289f6c 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -35,11 +35,11 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -89,11 +89,11 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc index 5c5343bf4248c49f76c88dbf70bb5fae621534e9..536e878c6fc091febcec3c8ec0d1dca7f59144dd 100644 --- a/paddle/fluid/operators/similarity_focus_op.cc +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -74,11 +74,10 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc index 
094e87f384bcd44098b0ece821f5d86a102e4f3c..695807a4c3c6e529c4b8222ef410253ca69b2d09 100644 --- a/paddle/fluid/operators/size_op.cc +++ b/paddle/fluid/operators/size_op.cc @@ -25,17 +25,19 @@ class SizeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto dtype = framework::proto::VarType::FP32; // dtype is not important - return framework::OpKernelType(dtype, ctx.GetPlace()); + return phi::KernelKey(dtype, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index a418719907872dc37041ef4c98ec7249cf29c4a6..426eec0b0ea8384d6a5a4affe1e543033f61dc99 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -132,7 +132,7 @@ class SliceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *in_var = ctx.InputVar("Input"); if (in_var->IsType<phi::DenseTensor>()) { @@ -144,9 +144,8 @@ class SliceOp : public framework::OperatorWithKernel { "The tensor Input (Input) of Slice op is not initialized.")); // NOTE: cuda pinned tensor needs to copy its data to target place if (platform::is_cuda_pinned_place(in_tensor.place())) { - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()), + ctx.GetPlace()); } #ifdef PADDLE_WITH_MKLDNN @@ -162,33 +161,37 @@ class SliceOp : public framework::OperatorWithKernel { // created, so in that scenario a fallback is needed if (ctx.Input<phi::DenseTensor>("Input") ->mem_desc() - .data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + .data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); + } } #endif - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor.dtype()), in_tensor.place()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor.dtype()), + in_tensor.place()); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); }
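// [Editor's note, annotation only — not part of the patch] In the branch above
// (and in the StartsTensorList/EndsTensorList branch that follows), the old
// code returned expected_kernel_type unchanged to mean "skip data transform
// for this auxiliary input". The new phi::KernelKey spells that out:
// phi::Backend::ALL_BACKEND is a wildcard backend, so the key comparison in
// PrepareData sees no place mismatch and the host-side start/end index
// tensors are never copied to the kernel's device.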
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -322,7 +325,7 @@ class SliceOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); @@ -335,28 +338,32 @@ class SliceOpGrad : public framework::OperatorWithKernel { // created, so in that scenario a fallback is needed if (ctx.Input(framework::GradVarName("Out")) ->mem_desc() - .data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + .data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); + } } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index bc11f53e009353092d353babcfa3725544b82276..99383363e65eb9804277192793584502b8a35fae 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -31,7 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. 
std::string data_format = ctx.Attr("data_format"); @@ -48,7 +48,8 @@ class SoftmaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; @@ -116,7 +117,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. std::string data_format = ctx.Attr("data_format"); @@ -132,7 +133,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { PADDLE_THROW(platform::errors::InvalidArgument( "float16 can only be used on GPU/NPU/XPU/MLU and custom place")); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); + return phi::KernelKey( + ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type)); } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index a2ca77cc606038c082fa1010f923815a91d45f74..df142f3350c0e7e4f31ba0f525e1f39a091a1b95 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -218,11 +218,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + return phi::KernelKey( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace()); } }; @@ -310,11 +309,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Loss")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index 0d4af9c0ce94ae5cbe74160361ce999bfcd83bc3..ed9c82c34feb9f74313099db4a66db25df0dcf3e 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -204,11 +204,11 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/sparse_attention_op.cc b/paddle/fluid/operators/sparse_attention_op.cc index 48dc3d782481dbcc1a007079146960a6adff1f06..26dfc0fbbc64d60fed01627c2bb87dbdd07cc513 
100644 --- a/paddle/fluid/operators/sparse_attention_op.cc +++ b/paddle/fluid/operators/sparse_attention_op.cc @@ -122,11 +122,11 @@ class SparseAttentionOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "Q", "K"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -169,11 +169,11 @@ class SparseAttentionOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 372e31aa9af6308d65a6d4615a38f1e0d19493ed..85bd8676652736466b84858fa135254813711acd 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -25,9 +25,9 @@ class SpectralNormOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; @@ -143,9 +143,9 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index fc7e8a869e3ef82e22b00650881fcfd4f2d9eff3..47f6306acbe80e97a0a2f9582676923069fa7ad7 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -108,7 +108,7 @@ class SplitOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); @@ -120,25 +120,27 @@ class SplitOp : public framework::OperatorWithKernel { // 16(depending on which blocking format is used) submemory cannot be // created, so in that scenario a fallback is needed const auto x_md = ctx.Input("X")->mem_desc(); - if (x_md.data.format_desc.blocking.inner_nblks == 0) - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::ONEDNN, - framework::LibraryType::kMKLDNN); + if (x_md.data.format_desc.blocking.inner_nblks == 0) { + return phi::KernelKey(phi::Backend::ONEDNN, + phi::DataLayout::ONEDNN, + phi::TransToPhiDataType(input_data_type)); + } } #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, 
ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "AxisTensor" || var_name == "SectionsTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/squared_l2_distance_op.cc b/paddle/fluid/operators/squared_l2_distance_op.cc index dc1848b3ee1247b227a460566618f72e8d20c1db..f1ed2d3ee6813585e637dc9fefe25a5d7231a5b5 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.cc +++ b/paddle/fluid/operators/squared_l2_distance_op.cc @@ -200,9 +200,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"), ctx.GetPlace()); } diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 7b023bcdf662cccfdfe89f9a2074c6a04bbfad33..115901d3ee2ee56ee3260dd821bcbc17d7b6166d 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -120,11 +120,11 @@ class SqueezeOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; @@ -139,11 +139,11 @@ class SqueezeGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc index d30320f9952ee3457090b4555e947bc8b8a717a0..9cc78eb300a7e2e2cc4849b0efd6846d2c67ce14 100644 --- a/paddle/fluid/operators/stack_op.cc +++ b/paddle/fluid/operators/stack_op.cc @@ -31,11 +31,11 @@ class StackOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/stft_op.cc b/paddle/fluid/operators/stft_op.cc index 
986911a1391456b270ba0ab1ce5b1b93010a14d9..8c9507bc89ceda77616f4ee0ac36e8b47daa705a 100644 --- a/paddle/fluid/operators/stft_op.cc +++ b/paddle/fluid/operators/stft_op.cc @@ -79,10 +79,10 @@ class StftOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(in_dtype, ctx.GetPlace()); + return phi::KernelKey(in_dtype, ctx.GetPlace()); } }; @@ -140,12 +140,12 @@ class StftGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto in_dtype = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); const auto kernel_dtype = framework::ToRealType(in_dtype); - return framework::OpKernelType(kernel_dtype, ctx.GetPlace()); + return phi::KernelKey(kernel_dtype, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index c08f214ab58bcb96ce861d7fb38cce065a9ca964..fffd99ae76b34042ff1587d114907dcc6440cfe9 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -31,7 +31,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *in_var = ctx.InputVar("Input"); auto is_in_var_array = in_var->IsType(); @@ -50,35 +50,37 @@ class StridedSliceOp : public framework::OperatorWithKernel { string::to_string(tensor.place()))); } } - return framework::OpKernelType( + return phi::KernelKey( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + ctx.GetPlace()); } // NOTE: cuda pinned tensor need to copy its data to target place auto in_tensor = ctx.Input("Input"); if (platform::is_cuda_pinned_place(in_tensor->place())) { - return framework::OpKernelType( - framework::TransToProtoVarType(in_tensor->dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(in_tensor->dtype()), + ctx.GetPlace()); } - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - in_tensor->place()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + in_tensor->place()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor" || var_name == "StridesTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StridesTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return 
phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -164,26 +166,30 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string &var_name, const phi::DenseTensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { + const phi::KernelKey &expected_kernel_type) const override { if (var_name == "StartsTensor" || var_name == "EndsTensor" || var_name == "StridesTensor") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || var_name == "StridesTensorList") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc index f1a7688372adc785159017123adf7a32ba5c92f7..35128b0085687e8b8e9f5d27d2cd03c364d2928e 100644 --- a/paddle/fluid/operators/string/faster_tokenizer_op.cc +++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc @@ -469,19 +469,19 @@ class FasterTokenizerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::INT64, - paddle::platform::CPUPlace()); + return phi::KernelKey(framework::proto::VarType::INT64, + paddle::platform::CPUPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + const phi::KernelKey& expected_kernel_type) const override { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + tensor.layout(), + expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 098167cb69d7a85b6f49bdc041ae4fb040e7f364..a4902a85fcba7b5ae173b173d8fe75cada079f44 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -30,7 +30,7 @@ class SumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto x_vars = ctx.MultiInputVar("X"); auto x_vars_name = ctx.InputNames("X"); @@ -87,27 +87,24 @@ class SumOp : 
public framework::OperatorWithKernel { } // NOTE(jiahongyu): The code above was originally enclosed by PADDLE_WITH_MKLDNN - return framework::OpKernelType(data_type, ctx.GetPlace()); + return phi::KernelKey(data_type, ctx.GetPlace()); } else if (x_vars[0]->IsType<phi::SelectedRows>()) { for (auto& var : x_vars) { auto& value = var->Get<phi::SelectedRows>().value(); if (value.IsInitialized()) { - return framework::OpKernelType( - framework::TransToProtoVarType(value.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(value.dtype()), + ctx.GetPlace()); } } // if input sparse vars are not initialized, use a default kernel type. - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) { for (auto& x_var : x_vars) { auto& array = x_var->Get<framework::LoDTensorArray>(); for (auto& each : array) { if (each.numel() != 0 && each.IsInitialized()) { - return framework::OpKernelType( - framework::TransToProtoVarType(each.dtype()), - ctx.device_context()); + return phi::KernelKey(framework::TransToProtoVarType(each.dtype()), + ctx.GetPlace()); } } } diff --git a/paddle/fluid/operators/tdm_child_op.cc b/paddle/fluid/operators/tdm_child_op.cc index c91f0b989e3acd19bac291048cbfb9e722bb8c64..0ec2c1e85bf838ab7eaf2a265033d809305d59ad 100644 --- a/paddle/fluid/operators/tdm_child_op.cc +++ b/paddle/fluid/operators/tdm_child_op.cc @@ -102,10 +102,10 @@ class TDMChildOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/tdm_sampler_op.cc b/paddle/fluid/operators/tdm_sampler_op.cc index 7480c103941a56da298aac32e15d449e3b85824d..66e9728d88f7333dcb897b9979fa44aaf4a50162 100644 --- a/paddle/fluid/operators/tdm_sampler_op.cc +++ b/paddle/fluid/operators/tdm_sampler_op.cc @@ -118,10 +118,10 @@ class TDMSamplerOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index bad44798680530430d67d2246944ca2a4fd77bd8..fdb78f9da326e359e655f6d75c12603873c57baa 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -74,11 +74,10 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { // Explicitly set that the data type of computation kernel of // teacher_student_sigmoid_loss // is determined by its input "X".
- framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -186,11 +185,10 @@ class TeacherStudentSigmoidLossGradientOp // Explicitly set that the data type of computation kernel of // teacher_student_sigmoid_loss // is determined by its input "X". - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index 119fcf4f49bc53cc02ab603625fbb21cc7d4cf1b..32fc06f57872a192e78dfa3cbc4428f8b1d93190 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -28,10 +28,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -120,11 +120,11 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel { } } - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/tile_op.cc b/paddle/fluid/operators/tile_op.cc index 172e96737061da99839fdd132c8216b3b8b59137..9ea804b24438c15f999831ec9266b7f3ba17a026 100644 --- a/paddle/fluid/operators/tile_op.cc +++ b/paddle/fluid/operators/tile_op.cc @@ -29,22 +29,23 @@ class TileOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), 
tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; @@ -121,22 +122,24 @@ class TileGradOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } - framework::OpKernelType GetKernelTypeForVar( + phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { + const phi::KernelKey& expected_kernel_type) const override { if (var_name == "repeat_times_tensor" || var_name == "RepeatTimes") { - return expected_kernel_type; + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); } - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); } }; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 42d5433792d14c141db3b34e98d0475326942e86..22eb23c93f3f9f394fb0a608506fd0c635d99f6a 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -65,15 +65,10 @@ class TopkOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - phi::DataLayout layout_ = phi::DataLayout::kAnyLayout; - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context(), - layout_, - library_); + return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.GetPlace()); } }; @@ -128,11 +123,11 @@ class TopkOpGrad : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - return framework::OpKernelType(data_type, ctx.device_context()); + return phi::KernelKey(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/transfer_layout_op.cc b/paddle/fluid/operators/transfer_layout_op.cc index 5bba1c225a58822d76add9166a494b64d8f11670..a197546b357bb11b5e8a2a8aa051dd1c65451ee4 100644 --- a/paddle/fluid/operators/transfer_layout_op.cc +++ b/paddle/fluid/operators/transfer_layout_op.cc @@ -42,7 +42,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - framework::OpKernelType GetExpectedKernelType( + phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { // kernel's device type is decided by input tensor place auto *in = ctx.InputVar("X"); @@ -59,14 +59,16 @@ class TransferLayoutOp : public framework::OperatorWithKernel { in_tensor->IsInitialized() ? 
diff --git a/paddle/fluid/operators/transfer_layout_op.cc b/paddle/fluid/operators/transfer_layout_op.cc
index 5bba1c225a58822d76add9166a494b64d8f11670..a197546b357bb11b5e8a2a8aa051dd1c65451ee4 100644
--- a/paddle/fluid/operators/transfer_layout_op.cc
+++ b/paddle/fluid/operators/transfer_layout_op.cc
@@ -42,7 +42,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     // kernel's device type is decided by input tensor place
     auto *in = ctx.InputVar("X");
@@ -59,14 +59,16 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
         in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();
     // dtype is not important
-    return framework::OpKernelType(framework::proto::VarType::FP32, place);
+    return phi::KernelKey(framework::proto::VarType::FP32, place);
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return expected_kernel_type;
+      const phi::KernelKey &expected_kernel_type) const override {
+    return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                          expected_kernel_type.layout(),
+                          expected_kernel_type.dtype());
   }
 };
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index e81c619db4348e38c4d63ebf489cc8cfa05e032e..d49cbad1147e91790750ed15b26bc6df8accbc8e 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -97,12 +97,13 @@ class TransposeOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     auto &data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
@@ -192,13 +193,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
@@ -229,13 +231,14 @@ class Transpose2Op : public TransposeOp {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "X");
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
@@ -333,14 +336,15 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx,
                                                 framework::GradVarName("Out"));
     std::string data_format = ctx.Attr<std::string>("data_format");
     phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
+    return phi::KernelKey(
+        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
   }
 };
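
Note: when a layout must be pinned (the transpose family reads it from the
data_format attribute), the fully specified constructor is used, and the dtype
has to be converted by hand: KernelKey(Place, DataLayout, DataType) takes a
phi::DataType while IndicateVarDataType still returns a proto VarType. The
conversion step, isolated as a sketch:

  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  phi::DataLayout layout =
      phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  // OpKernelType(data_type, place, layout) becomes:
  return phi::KernelKey(
      ctx.GetPlace(), layout, phi::TransToPhiDataType(data_type));
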
diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc
index 525dd17c39bb9b8bec05416af824d55e642417c1..0e78aa20faa323ebae8c6fabcf7b38da15a81481 100644
--- a/paddle/fluid/operators/tree_conv_op.cc
+++ b/paddle/fluid/operators/tree_conv_op.cc
@@ -151,11 +151,11 @@ class TreeConvOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };
@@ -215,11 +215,11 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc
index 62dc419fd026278c42305a28ad6d2e2cd75bc06d..66e4c3a57890e3f143476d468c04a6cef6fc04eb 100644
--- a/paddle/fluid/operators/triangular_solve_op.cc
+++ b/paddle/fluid/operators/triangular_solve_op.cc
@@ -23,10 +23,10 @@ class TriangularSolveOp : public framework::OperatorWithKernel {
 public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/tril_indices_op.cc b/paddle/fluid/operators/tril_indices_op.cc
index bae34fa5f563523173e3f20dfe7e4a405c8911d0..4631900e3b3936b9d399083f2860005b7fc33af0 100644
--- a/paddle/fluid/operators/tril_indices_op.cc
+++ b/paddle/fluid/operators/tril_indices_op.cc
@@ -27,9 +27,9 @@ class TrilIndicesOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/triu_indices_op.cc b/paddle/fluid/operators/triu_indices_op.cc
index d02b54f6083f8b80cabe21d284f42ffee17ebb2e..8167cb3e3f3909e678c9b8591e3805ee8d8edb09 100644
--- a/paddle/fluid/operators/triu_indices_op.cc
+++ b/paddle/fluid/operators/triu_indices_op.cc
@@ -24,9 +24,9 @@ class TriuIndicesOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
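
Note: ops that materialize a fresh tensor (the *_indices ops above, the random
ops below) take their dtype from an int attribute rather than from an input.
Both spellings in this patch funnel into the new KernelKey(const int& dtype,
Place) constructor, which converts through TransToPhiDataType; the cast to the
proto enum is only for readability, since the enum converts back to int at the
call site:

  // tril_indices / triu_indices style:
  return phi::KernelKey(
      framework::proto::VarType::Type(ctx.Attr<int>("dtype")), ctx.GetPlace());

  // truncated_gaussian_random / uniform_random style:
  return phi::KernelKey(
      static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
      ctx.GetPlace());
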
diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc
index 1d29a9c518976ee7696b7555ee58235afb3e7409..c5a4a1268fd7dedcd16935b997839b3edef133c8 100644
--- a/paddle/fluid/operators/truncated_gaussian_random_op.cc
+++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc
@@ -31,15 +31,11 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    phi::DataLayout layout{phi::DataLayout::kAnyLayout};
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context(),
-        layout,
-        library);
+        ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
index a9191c09ea1c6710d8f2046344886eb1d03bdb29..6d4206d3430e9a99addae9d52358e83c19d0ef9e 100644
--- a/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
+++ b/paddle/fluid/operators/uniform_random_batch_size_like_op.cc
@@ -23,9 +23,9 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
 protected:
   using BatchSizeLikeOp::BatchSizeLikeOp;

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/uniform_random_inplace_op.cc b/paddle/fluid/operators/uniform_random_inplace_op.cc
index 09870c8401e496dfb248d79e8e1305d40919eccc..d43d1cd1252ae8c3d31e9f873578b0efe59a9468 100644
--- a/paddle/fluid/operators/uniform_random_inplace_op.cc
+++ b/paddle/fluid/operators/uniform_random_inplace_op.cc
@@ -57,10 +57,10 @@ class UniformRandomInplaceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index 7ba22baff99b9516f3a95601944e37226646497e..e2605332ccef92baa29a28b8f9ecaf3ffbc740ef 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -136,22 +136,24 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
+    return phi::KernelKey(
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "ShapeTensorList" || var_name == "ShapeTensor") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
diff --git a/paddle/fluid/operators/unique_consecutive_op.cc b/paddle/fluid/operators/unique_consecutive_op.cc
index 97cd31141da2bd96220a856291a2d1c1caec20d6..d57a9ceacf4ca1aa0e3e37711281542a1dd9186a 100644
--- a/paddle/fluid/operators/unique_consecutive_op.cc
+++ b/paddle/fluid/operators/unique_consecutive_op.cc
@@ -26,10 +26,10 @@ class UniqueConsecutiveOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/unique_op.cc b/paddle/fluid/operators/unique_op.cc
index c99f60ca873b1cad1124cd3ebf776cc4cff94e93..5484a16ca6bd4d1111b82254e3f588f809917d0e 100644
--- a/paddle/fluid/operators/unique_op.cc
+++ b/paddle/fluid/operators/unique_op.cc
@@ -98,18 +98,17 @@ class UniqueOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // Return CPUPlace when Attr("is_sorted") is false. Because it means
     // that fluid.layers.unique is called, but there is no cuda kernel.
     if (!ctx.Attr<bool>("is_sorted")) {
-      return framework::OpKernelType(
-          OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-          platform::CPUPlace());
+      return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                            platform::CPUPlace());
     } else {
       // new version paddle.unique is called.
-      return framework::OpKernelType(
-          OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+      return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                            ctx.GetPlace());
     }
   }
 };
diff --git a/paddle/fluid/operators/unique_with_counts_op.cc b/paddle/fluid/operators/unique_with_counts_op.cc
index 6e60078f6ab48604aa776e2bb46cd3a6334c62c4..3726fd978bd935ffe4ebcf69dda909f762e532d1 100644
--- a/paddle/fluid/operators/unique_with_counts_op.cc
+++ b/paddle/fluid/operators/unique_with_counts_op.cc
@@ -44,11 +44,10 @@ class UniqueWithCountsOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        platform::CPUPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          platform::CPUPlace());
   }
 };
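
Note: pinning a kernel to CPU survives the migration unchanged in spirit.
Passing platform::CPUPlace() makes TransToPhiBackend yield the CPU backend, so
GPU-resident inputs are still transferred to the host before the CPU-only
kernel runs, just as with the old OpKernelType:

  // unique with is_sorted == false has no CUDA kernel, so force CPU.
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                        platform::CPUPlace());
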
diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc
index 92e20820132265480eed5fbf96f40a01a6134b02..6eb6b81eb4d3d6c5ab09081d48d0da2b5d275e18 100644
--- a/paddle/fluid/operators/unpool_op.cc
+++ b/paddle/fluid/operators/unpool_op.cc
@@ -148,11 +148,10 @@ int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) {
 class UnpoolOp : public framework::OperatorWithKernel {
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

 public:
@@ -161,11 +160,10 @@ class UnpoolOp : public framework::OperatorWithKernel {
 class Unpool3dOp : public framework::OperatorWithKernel {
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

 public:
@@ -204,11 +202,10 @@ class Unpool3dOpGradMaker : public framework::SingleGradOpMaker {
 class UnpoolOpGrad : public framework::OperatorWithKernel {
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

 public:
@@ -217,11 +214,10 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
 class Unpool3dOpGrad : public framework::OperatorWithKernel {
 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                          ctx.GetPlace());
   }

 public:
diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc
index d092c03a56398488e44ab6bce5162b66568e607f..5c6816a171fbc5db3fe66788506c567941feeaba 100644
--- a/paddle/fluid/operators/unsqueeze_op.cc
+++ b/paddle/fluid/operators/unsqueeze_op.cc
@@ -144,23 +144,24 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::TransToProtoVarType(
-            ctx.Input<phi::DenseTensor>("X")->type()),
-        ctx.device_context());
+    return phi::KernelKey(framework::TransToProtoVarType(
+                              ctx.Input<phi::DenseTensor>("X")->type()),
+                          ctx.GetPlace());
   }

-  framework::OpKernelType GetKernelTypeForVar(
+  phi::KernelKey GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
+      const phi::KernelKey &expected_kernel_type) const override {
     if (var_name == "AxesTensor" || var_name == "AxesTensorList") {
-      return expected_kernel_type;
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
     }
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
   }
 };
@@ -225,11 +226,11 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }

-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out")),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc
index 3f09b20068975e6d4a2600240645774875f85709..93811358c7e50ffd5a93e853485341df47b0a8ea 100644
--- a/paddle/fluid/operators/warpctc_op.cc
+++ b/paddle/fluid/operators/warpctc_op.cc
@@ -28,15 +28,10 @@ class WarpCTCOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
-        ctx.GetPlace(),
-        layout_,
-        library_);
+    return phi::KernelKey(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
   }
 };
@@ -146,11 +141,11 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Loss")),
-                                   ctx.GetPlace());
+    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Loss")),
+                          ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/where_index_op.cc b/paddle/fluid/operators/where_index_op.cc
index 52448b08c5e110d00c0ebda1921d03e906e162b1..2b19b62595eec599187d543279bb349694fb7890 100644
--- a/paddle/fluid/operators/where_index_op.cc
+++ b/paddle/fluid/operators/where_index_op.cc
@@ -25,10 +25,10 @@ class WhereIndexOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
-  framework::OpKernelType GetExpectedKernelType(
+  phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Condition");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    return phi::KernelKey(data_type, ctx.GetPlace());
   }
 };
diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h
index 89ae772f30be4d9c6b459044a7c04f8e145999e4..ad0a83546eb9dabfbb32e1e6a9de0a5fb0bdaf8d 100644
--- a/paddle/phi/core/kernel_factory.h
+++ b/paddle/phi/core/kernel_factory.h
@@ -26,6 +26,7 @@
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/type_defs.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/utils/flat_hash_map.h"
 #include "paddle/utils/small_vector.h"
@@ -53,10 +54,27 @@ class KernelKey {
   KernelKey(Backend backend, DataLayout layout, DataType dtype)
       : backend_(backend), layout_(layout), dtype_(dtype) {}

+  explicit KernelKey(Place place)
+      : backend_(TransToPhiBackend(place)),
+        layout_(DataLayout::ALL_LAYOUT),
+        dtype_(DataType::ALL_DTYPE) {}
+
+  explicit KernelKey(const int& dtype, Place place)
+      : backend_(TransToPhiBackend(place)),
+        layout_(DataLayout::ALL_LAYOUT),
+        dtype_(phi::TransToPhiDataType(dtype)) {}
+
+  explicit KernelKey(Place place, DataLayout layout, DataType dtype)
+      : backend_(TransToPhiBackend(place)), layout_(layout), dtype_(dtype) {}
+
   Backend backend() const { return backend_; }
   DataLayout layout() const { return layout_; }
   DataType dtype() const { return dtype_; }

+  void set_backend(const Backend& backend) { backend_ = backend; }
+  void set_layout(const DataLayout& layout) { layout_ = layout; }
+  void set_dtype(const DataType& dtype) { dtype_ = dtype; }
+
   struct Hash {
     // Note: Now the number of bits we need does not exceed 32 bits, so there is
     // no need to use 64 bits. If needed in the future, it can be expanded,
diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h
index 6879c6206564cb392d503a592adc3a8dd506f504..edb841aeb1caaf2edb24430e167c49868b906581 100644
--- a/paddle/phi/core/utils/data_type.h
+++ b/paddle/phi/core/utils/data_type.h
@@ -125,6 +125,7 @@ enum ProtoDataType {
   FP16 = 4,
   FP32 = 5,
   FP64 = 6,
+  RAW = 17,
   UINT8 = 20,
   INT8 = 21,
   BF16 = 22,
@@ -163,6 +164,8 @@ inline DataType TransToPhiDataType(const int& dtype) {
       return DataType::BOOL;
     case ProtoDataType::PSTRING:
       return DataType::PSTRING;
+    case ProtoDataType::RAW:
+      return DataType::ALL_DTYPE;
     default:
       return DataType::UNDEFINED;
   }
@@ -198,6 +201,8 @@ inline int TransToProtoVarType(const DataType& dtype) {
       return ProtoDataType::BOOL;
     case DataType::PSTRING:
       return ProtoDataType::PSTRING;
+    case DataType::UNDEFINED:
+      return ProtoDataType::RAW;
     default:
       PADDLE_THROW(phi::errors::Unimplemented(
           "Unsupported data type `%s` when casting it into "
diff --git a/paddle/phi/kernels/impl/searchsorted_kernel_impl.h b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
index e3cd6f5828d04849729bde72dbb044b31f488c57..6c0891e59bcb98366c4d97627d894ab6911ef008 100644
--- a/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
+++ b/paddle/phi/kernels/impl/searchsorted_kernel_impl.h
@@ -147,7 +147,7 @@ class SearchSortedFunctor {
 };

 template <typename Visitor>
-static void VisitDataType(DataType type, Visitor visitor) {
+void VisitDataTypeForSearchSorted(DataType type, Visitor visitor) {
   if (type == DataType::FLOAT32) {
     visitor.template apply<float>();
   } else if (type == DataType::FLOAT64) {
@@ -178,13 +178,13 @@ void SearchsortedKernel(const Context& ctx,
     int* out_data = out->data<int>();
     SearchSortedFunctor<Context, int> functor(
         ctx, &sorted_sequence, &value, right, out_data);
-    VisitDataType(value.dtype(), functor);
+    VisitDataTypeForSearchSorted(value.dtype(), functor);
   } else {
     ctx.template Alloc<int64_t>(out);
     int64_t* out_data = out->data<int64_t>();
     SearchSortedFunctor<Context, int64_t> functor(
         ctx, &sorted_sequence, &value, right, out_data);
-    VisitDataType(value.dtype(), functor);
+    VisitDataTypeForSearchSorted(value.dtype(), functor);
   }
 }
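
Note: the three added constructors are what keep the per-operator diffs above
small. A compact illustration of where each one is exercised (place stands for
any phi::Place; the concrete values are illustrative):

  // 1) Place only: backend from the place, layout and dtype left wide open.
  //    Used by CustomOperator, whose kernel dtype is RAW / ALL_DTYPE.
  phi::KernelKey k1(place);

  // 2) proto VarType value + Place: dtype goes through TransToPhiDataType,
  //    so the newly mapped proto RAW (17) round-trips to DataType::ALL_DTYPE.
  phi::KernelKey k2(framework::proto::VarType::FP32, place);

  // 3) Place + layout + dtype: the fully specified form (the transpose ops
  //    and the data_layout_transform test).
  phi::KernelKey k3(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32);

The VisitDataType rename in searchsorted_kernel_impl.h is presumably needed
because kernel_factory.h now pulls in paddle/phi/core/utils/data_type.h, whose
own visitor template would otherwise be ambiguous with the file-local one.
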