未验证 提交 4383494f 编写于 作者: H HongyuJia 提交者: GitHub

[Unify KernelKey] change OpKernelType->KernelKey (#49138)

* execute use kernel_key first

* change OpKernelType->KernelKey

* fix py3 compile error, remove redundant header files

* fix build_strategy_test

* fix DataType::RAW

* fix custom_type test: operator_test.cc

* fix transform place

* fix backends_are_same_class

* try fix place TransDataDevice

* support all KernelKey

* fix TransformData

* fix place_are_same_class

* fix merge

* fix test_params_no_grad

* fix specific place of GetExpectedKernelType

* fix specific place of GetExpectedKernelType

* fix GetKernelTypeForVar

* fix dtype error

* fix fetch_v2

* change GetKernelTypeForVar

* fix interpreter

* fix typo error

* polish codes

* polish codes

* polish codes

* fix conflict
上级 723ceed9
...@@ -422,9 +422,9 @@ class CustomOperator : public OperatorWithKernel { ...@@ -422,9 +422,9 @@ class CustomOperator : public OperatorWithKernel {
* The RAW type is used here as the data type, indicating that * The RAW type is used here as the data type, indicating that
* it can only be determined at runtime. * it can only be determined at runtime.
*/ */
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace()); return phi::KernelKey(ctx.GetPlace());
} }
/** /**
...@@ -432,13 +432,13 @@ class CustomOperator : public OperatorWithKernel { ...@@ -432,13 +432,13 @@ class CustomOperator : public OperatorWithKernel {
* Because the kernel data type is RAW, we should skip the cast for * Because the kernel data type is RAW, we should skip the cast for
* data type difference when PrepareData. * data type difference when PrepareData.
*/ */
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const override { const phi::KernelKey& expected_kernel_type) const override {
return OpKernelType(expected_kernel_type.data_type_, return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.place_, tensor.layout(),
tensor.layout()); expected_kernel_type.dtype());
} }
}; };
......
...@@ -47,15 +47,17 @@ class TestOpWithKernel : public OperatorWithKernel { ...@@ -47,15 +47,17 @@ class TestOpWithKernel : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) { if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel"; VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0)); return phi::KernelKey(phi::Backend::GPU,
phi::DataLayout::ALL_LAYOUT,
phi::DataType::FLOAT32);
} else { } else {
VLOG(3) << "use default kernel"; VLOG(3) << "use default kernel";
return OpKernelType(proto::VarType::FP32, return phi::KernelKey(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("input")->place()); ctx.Input<phi::DenseTensor>("input")->place());
} }
} }
}; };
......
...@@ -50,13 +50,14 @@ void CastDataLayout::apply() { ...@@ -50,13 +50,14 @@ void CastDataLayout::apply() {
} }
} }
void TransDataLayout(const OpKernelType& kernel_type_for_var, void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in, const phi::DenseTensor& in,
phi::DenseTensor* out) { phi::DenseTensor* out,
const phi::Place& place) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::places_are_same_class(kernel_type_for_var.place_, backends_are_same_class(kernel_type_for_var.backend(),
expected_kernel_type.place_), expected_kernel_type.backend()),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place.")); "TransDataLayout only support DataLayout transform on same place."));
...@@ -72,21 +73,20 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -72,21 +73,20 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
auto src_dim = in.dims(); auto src_dim = in.dims();
std::vector<int64_t> dst_dim; std::vector<int64_t> dst_dim;
auto axis = GetAxis(kernel_type_for_var.data_layout_, auto axis =
expected_kernel_type.data_layout_); GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
dst_dim.resize(axis.size()); dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) { for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]]; dst_dim[i] = src_dim[axis[i]];
} }
out->Resize(phi::make_ddim(dst_dim)); out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(expected_kernel_type.place_, in.dtype()); out->mutable_data(place, in.dtype());
framework::VisitDataType( framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
framework::TransToProtoVarType(in.dtype()), CastDataLayout(pool.Get(place), axis, in, out));
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.layout());
} }
} // namespace framework } // namespace framework
......
...@@ -54,10 +54,11 @@ struct CastDataLayout { ...@@ -54,10 +54,11 @@ struct CastDataLayout {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var, void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in, const phi::DenseTensor& in,
phi::DenseTensor* out); phi::DenseTensor* out,
const phi::Place& place);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -24,22 +24,16 @@ TEST(DataTransform, DataLayoutFunction) { ...@@ -24,22 +24,16 @@ TEST(DataTransform, DataLayoutFunction) {
in.set_layout(phi::DataLayout::kNHWC); in.set_layout(phi::DataLayout::kNHWC);
auto kernel_nhwc = auto kernel_nhwc =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, phi::KernelKey(place, phi::DataLayout::kNHWC, phi::DataType::FLOAT32);
place,
phi::DataLayout::kNHWC,
paddle::framework::LibraryType::kPlain);
auto kernel_ncwh = auto kernel_ncwh =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, phi::KernelKey(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32);
place,
phi::DataLayout::kNCHW,
paddle::framework::LibraryType::kPlain);
paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out, place);
EXPECT_TRUE(out.layout() == phi::DataLayout::kNCHW); EXPECT_TRUE(out.layout() == phi::DataLayout::kNCHW);
EXPECT_TRUE(out.dims() == phi::make_ddim({2, 2, 3, 1})); EXPECT_TRUE(out.dims() == phi::make_ddim({2, 2, 3, 1}));
TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out); paddle::framework::TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out, place);
EXPECT_TRUE(in.layout() == phi::DataLayout::kNHWC); EXPECT_TRUE(in.layout() == phi::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == phi::make_ddim({2, 3, 1, 2})); EXPECT_TRUE(in.dims() == phi::make_ddim({2, 3, 1, 2}));
......
...@@ -36,16 +36,17 @@ static void PassTensorData(phi::DenseTensor *from, phi::DenseTensor *to) { ...@@ -36,16 +36,17 @@ static void PassTensorData(phi::DenseTensor *from, phi::DenseTensor *to) {
*from = phi::DenseTensor(); *from = phi::DenseTensor();
} }
void TransformData(const OpKernelType &expected_kernel_type, void TransformData(const phi::KernelKey &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const phi::KernelKey &kernel_type_for_var,
const phi::DenseTensor &input_tensor, const phi::DenseTensor &input_tensor,
phi::DenseTensor *output_tensor) { phi::DenseTensor *output_tensor,
const phi::Place &place) {
bool transformed = false; bool transformed = false;
phi::DenseTensor in; phi::DenseTensor in;
in.ShareDataWith(input_tensor); in.ShareDataWith(input_tensor);
phi::DenseTensor out; phi::DenseTensor out;
const DataLayout lin = kernel_type_for_var.data_layout_; const DataLayout lin = kernel_type_for_var.layout();
const DataLayout lout = expected_kernel_type.data_layout_; const DataLayout lout = expected_kernel_type.layout();
// do layout transform // do layout transform
if (NeedTransformLayout(lout, lin)) { if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -79,43 +80,42 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -79,43 +80,42 @@ void TransformData(const OpKernelType &expected_kernel_type,
} else { } else {
// Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel // Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel
// Do transform via ONEDNN lib // Do transform via ONEDNN lib
PADDLE_ENFORCE( PADDLE_ENFORCE(lin == DataLayout::ONEDNN && lout != DataLayout::ONEDNN,
kernel_type_for_var.data_layout_ == DataLayout::ONEDNN && platform::errors::InvalidArgument(
expected_kernel_type.data_layout_ != DataLayout::ONEDNN, "TransDataLayoutFromOneDNN only supports "
platform::errors::InvalidArgument( "transform from ONEDNN to non-ONEDNN"));
"TransDataLayoutFromOneDNN only supports "
"transform from ONEDNN to non-ONEDNN"));
phi::funcs::TransDataLayoutFromOneDNN( phi::funcs::TransDataLayoutFromOneDNN(
kernel_type_for_var.data_layout_, lin,
phi::OneDNNContext::tls().get_cur_paddle_data_layout(), phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
in, in,
&out, &out,
expected_kernel_type.place_); place);
} }
} else { } else {
// Case3 - transfrom between Non-ONEDNN OPKernels // Case3 - transfrom between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(
kernel_type_for_var, expected_kernel_type, in, &out, place);
} }
#else #else
// Case3 - transfrom between Non-ONEDNN OPKernels // Case3 - transfrom between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out, place);
#endif #endif
transformed = true; transformed = true;
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
// do data type transform // do data type transform
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) { if (NeedTransformDataType(expected_kernel_type, kernel_type_for_var)) {
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out); TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
transformed = true; transformed = true;
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
// do device transform // do device transform
if (!platform::is_same_place(kernel_type_for_var.place_, if (kernel_type_for_var.backend() != phi::Backend::ALL_BACKEND &&
expected_kernel_type.place_)) { !platform::is_same_place(in.place(), place)) {
TransDataDevice(in, expected_kernel_type.place_, &out); TransDataDevice(in, place, &out);
transformed = true; transformed = true;
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
......
...@@ -33,10 +33,11 @@ namespace framework { ...@@ -33,10 +33,11 @@ namespace framework {
class OpKernelType; class OpKernelType;
class Variable; class Variable;
void TransformData(const OpKernelType &expected_kernel_type, void TransformData(const phi::KernelKey &expected_kernel_type,
const OpKernelType &kernel_type_for_var, const phi::KernelKey &kernel_type_for_var,
const phi::DenseTensor &input_tensor, const phi::DenseTensor &input_tensor,
phi::DenseTensor *out); phi::DenseTensor *out,
const phi::Place &place);
/** /**
* Set OutVar from InVar, except the tensor is shared with `tensor` * Set OutVar from InVar, except the tensor is shared with `tensor`
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -226,6 +227,11 @@ extern inline bool IsComplexType(const proto::VarType::Type& type) { ...@@ -226,6 +227,11 @@ extern inline bool IsComplexType(const proto::VarType::Type& type) {
type == proto::VarType::COMPLEX128); type == proto::VarType::COMPLEX128);
} }
extern inline bool IsComplexType(const phi::DataType& type) {
return (type == phi::DataType::COMPLEX64 ||
type == phi::DataType::COMPLEX128);
}
extern proto::VarType::Type PromoteTypesIfComplexExists( extern proto::VarType::Type PromoteTypesIfComplexExists(
const proto::VarType::Type type_a, const proto::VarType::Type type_b); const proto::VarType::Type type_a, const proto::VarType::Type type_b);
......
...@@ -129,19 +129,18 @@ struct CastDataType { ...@@ -129,19 +129,18 @@ struct CastDataType {
} }
}; };
void TransDataType(const OpKernelType& kernel_type_for_var, void TransDataType(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in, const phi::DenseTensor& in,
phi::DenseTensor* out) { phi::DenseTensor* out) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(in.dtype(),
framework::TransToProtoVarType(in.dtype()), kernel_type_for_var.dtype(),
kernel_type_for_var.data_type_, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "The src dtype(%s) of input tensor and kernel_type(%s) "
"The src dtype(%s) of input tensor and kernel_type(%s) " "are not conststent.",
"are not conststent.", DataTypeToString(in.dtype()),
DataTypeToString(framework::TransToProtoVarType(in.dtype())), DataTypeToString(kernel_type_for_var.dtype())));
DataTypeToString(kernel_type_for_var.data_type_))); auto dst_type = framework::TransToProtoVarType(expected_kernel_type.dtype());
auto dst_type = expected_kernel_type.data_type_;
TransDataType(in, dst_type, out); TransDataType(in, dst_type, out);
} }
......
...@@ -28,8 +28,8 @@ class OpKernelType; ...@@ -28,8 +28,8 @@ class OpKernelType;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>; using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
void TransDataType(const OpKernelType& kernel_type_for_var, void TransDataType(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in, const phi::DenseTensor& in,
phi::DenseTensor* out); phi::DenseTensor* out);
void TransDataType(const phi::DenseTensor& in, void TransDataType(const phi::DenseTensor& in,
......
...@@ -19,47 +19,26 @@ limitations under the License. */ ...@@ -19,47 +19,26 @@ limitations under the License. */
TEST(DataTypeTransform, CPUTransform) { TEST(DataTypeTransform, CPUTransform) {
auto place = paddle::platform::CPUPlace(); auto place = paddle::platform::CPUPlace();
auto kernel_fp16 = auto kernel_fp16 = phi::KernelKey(
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16, place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
place,
phi::DataLayout::kAnyLayout, auto kernel_bf16 = phi::KernelKey(
paddle::framework::LibraryType::kPlain); place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BFLOAT16);
auto kernel_bf16 = auto kernel_fp32 = phi::KernelKey(
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BF16, place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
place,
phi::DataLayout::kAnyLayout, auto kernel_fp64 = phi::KernelKey(
paddle::framework::LibraryType::kPlain); place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
auto kernel_fp32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_int32 = auto kernel_int32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32, phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_int64 = auto kernel_int64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64, phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bool = auto kernel_bool =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL, phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
// data type transform from float32 // data type transform from float32
{ {
......
...@@ -24,41 +24,24 @@ TEST(DataTypeTransform, GPUTransform) { ...@@ -24,41 +24,24 @@ TEST(DataTypeTransform, GPUTransform) {
.GetAllocator(gpu_place, context.stream()) .GetAllocator(gpu_place, context.stream())
.get()); .get());
context.PartialInitWithAllocator(); context.PartialInitWithAllocator();
auto kernel_fp16 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16, auto kernel_fp16 = phi::KernelKey(
gpu_place, gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain); auto kernel_fp32 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
auto kernel_fp32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32, auto kernel_fp64 = phi::KernelKey(
gpu_place, gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain); auto kernel_int32 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
auto kernel_fp64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64, auto kernel_int64 = phi::KernelKey(
gpu_place, gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain); auto kernel_bool = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
auto kernel_int32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_int64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bool =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
// data type transform from float32 // data type transform from float32
{ {
......
...@@ -50,10 +50,10 @@ class SumOpWithKernel : public OperatorWithKernel { ...@@ -50,10 +50,10 @@ class SumOpWithKernel : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override {} void InferShape(framework::InferShapeContext *ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext &ctx) const override { const ExecutionContext &ctx) const override {
return OpKernelType(proto::VarType::FP32, return phi::KernelKey(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("X")->place()); ctx.Input<phi::DenseTensor>("X")->place());
} }
}; };
......
...@@ -84,9 +84,9 @@ class InferShapeUtilsTestOp : public OperatorWithKernel { ...@@ -84,9 +84,9 @@ class InferShapeUtilsTestOp : public OperatorWithKernel {
public: public:
using OperatorWithKernel::OperatorWithKernel; using OperatorWithKernel::OperatorWithKernel;
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); return phi::KernelKey(proto::VarType::FP32, ctx.GetPlace());
} }
}; };
......
...@@ -27,22 +27,26 @@ namespace paddle { ...@@ -27,22 +27,26 @@ namespace paddle {
namespace framework { namespace framework {
namespace interpreter { namespace interpreter {
bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, bool DataTranferHelper::apply(
const OpKernelType& expected_kernel_key, const phi::KernelKey& kernel_type_for_var,
const std::string& var_name, const framework::OpKernelType& expected_kernel_key,
std::string* new_var_name, const phi::DenseTensor* tensor,
std::vector<OpFuncNode>* op_func_nodes, const std::string& var_name,
bool use_local_scope, std::string* new_var_name,
bool is_fetch_v2, std::vector<OpFuncNode>* op_func_nodes,
bool skip_run) { bool use_local_scope,
bool is_fetch_v2,
bool skip_run) {
bool is_transferred = false; bool is_transferred = false;
auto* src_var_name = &var_name; auto* src_var_name = &var_name;
// 1. layout transform // 1. layout transform
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) { if (need_layout_transform(
kernel_type_for_var,
TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
auto op = TransferLayout(*src_var_name, auto op = TransferLayout(*src_var_name,
new_var_name, new_var_name,
kernel_type_for_var.data_layout_, kernel_type_for_var.layout(),
expected_kernel_key.data_layout_, expected_kernel_key.data_layout_,
var_scope_, var_scope_,
scope_, scope_,
...@@ -56,13 +60,16 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -56,13 +60,16 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
is_transferred = true; is_transferred = true;
} }
// 2. dype transform // 2. dype transform
if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) { if (need_dtype_transform(
auto op = TransferDtype(*src_var_name, kernel_type_for_var,
new_var_name, TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
kernel_type_for_var.data_type_, auto op = TransferDtype(
expected_kernel_key.data_type_, *src_var_name,
var_scope_, new_var_name,
scope_); framework::TransToProtoVarType(kernel_type_for_var.dtype()),
expected_kernel_key.data_type_,
var_scope_,
scope_);
if (op) { if (op) {
RunAndConstructOpFuncNode( RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run); op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
...@@ -72,8 +79,9 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -72,8 +79,9 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
is_transferred = true; is_transferred = true;
} }
// 3. device transform // 3. device transform
if (need_device_transform(kernel_type_for_var, expected_kernel_key)) { if (need_device_transform(
auto src_place = kernel_type_for_var.place_; kernel_type_for_var, tensor, expected_kernel_key.place_)) {
auto src_place = tensor->place();
auto dst_place = expected_kernel_key.place_; auto dst_place = expected_kernel_key.place_;
auto op = TransferDevice( auto op = TransferDevice(
...@@ -526,11 +534,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -526,11 +534,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
auto kernel_type_for_var = auto kernel_type_for_var =
static_cast<const framework::OperatorWithKernel*>(op_base) static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar( ->GetKernelTypeForVar(
var_name_item.first, *tensor_in, expected_kernel_key); var_name_item.first,
*tensor_in,
framework::TransOpKernelTypeToPhiKernelKey(
expected_kernel_key));
// apply data transform // apply data transform
is_transferred = is_transferred =
data_transfer_helper.apply(kernel_type_for_var, data_transfer_helper.apply(kernel_type_for_var,
expected_kernel_key, expected_kernel_key,
tensor_in,
var_name, var_name,
&new_var_name, &new_var_name,
new_op_func_nodes, new_op_func_nodes,
......
...@@ -34,8 +34,9 @@ class DataTranferHelper { ...@@ -34,8 +34,9 @@ class DataTranferHelper {
Scope* local_scope) Scope* local_scope)
: place_(place), var_scope_(var_scope), scope_(local_scope) {} : place_(place), var_scope_(var_scope), scope_(local_scope) {}
bool apply(const OpKernelType& kernel_type_for_var, bool apply(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_key, const framework::OpKernelType& expected_kernel_key,
const phi::DenseTensor* tensor,
const std::string& var_name, const std::string& var_name,
std::string* new_var_name, std::string* new_var_name,
std::vector<OpFuncNode>* new_op_func_nodes, std::vector<OpFuncNode>* new_op_func_nodes,
...@@ -79,28 +80,28 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, ...@@ -79,28 +80,28 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
framework::Scope* local_scope, framework::Scope* local_scope,
bool skip_run = false); bool skip_run = false);
inline bool need_device_transform(const OpKernelType& kernel_type_for_var, inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_key) { const phi::DenseTensor* tensor,
auto& src_place = kernel_type_for_var.place_; const phi::Place& expected_place) {
auto& dst_place = expected_kernel_key.place_; if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND ||
if (platform::is_same_place(src_place, dst_place) || platform::is_same_place(tensor->place(), expected_place) ||
(platform::is_cuda_pinned_place(src_place) && (platform::is_cuda_pinned_place(tensor->place()) &&
platform::is_cpu_place(dst_place))) { platform::is_cpu_place(expected_place))) {
return false; return false;
} }
return true; return true;
} }
inline bool need_dtype_transform(const OpKernelType& kernel_type_for_var, inline bool need_dtype_transform(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_key) { const phi::KernelKey& expected_kernel_key) {
return framework::NeedTransformDataType(kernel_type_for_var, return framework::NeedTransformDataType(kernel_type_for_var,
expected_kernel_key); expected_kernel_key);
} }
inline bool need_layout_transform(const OpKernelType& kernel_type_for_var, inline bool need_layout_transform(const phi::KernelKey& kernel_type_for_var,
const OpKernelType& expected_kernel_key) { const phi::KernelKey& expected_kernel_key) {
return framework::NeedTransformLayout(kernel_type_for_var.data_layout_, return framework::NeedTransformLayout(kernel_type_for_var.layout(),
expected_kernel_key.data_layout_); expected_kernel_key.layout());
} }
std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name, std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
......
...@@ -730,8 +730,8 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -730,8 +730,8 @@ bool BuildOpFuncList(const platform::Place& place,
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
auto exec_ctx = ExecutionContext( auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key = auto expected_kernel_key = framework::TransPhiKernelKeyToOpKernelType(
op_with_kernel->GetExpectedKernelType(exec_ctx); op_with_kernel->GetExpectedKernelType(exec_ctx));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (op_with_kernel->CanCUDNNBeUsed(exec_ctx, if (op_with_kernel->CanCUDNNBeUsed(exec_ctx,
expected_kernel_key.data_type_)) { expected_kernel_key.data_type_)) {
...@@ -741,6 +741,10 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -741,6 +741,10 @@ bool BuildOpFuncList(const platform::Place& place,
VLOG(4) << "expected_kernel_key : " << expected_kernel_key; VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// change device by the device_guard() // change device by the device_guard()
ApplyDeviceGuard(op, place, &expected_kernel_key); ApplyDeviceGuard(op, place, &expected_kernel_key);
if (platform::places_are_same_class(exec_ctx.GetPlace(),
expected_kernel_key.place_)) {
expected_kernel_key.place_ = exec_ctx.GetPlace();
}
// step 2. select op kernel // step 2. select op kernel
auto run_phi_kernel = false; auto run_phi_kernel = false;
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -108,15 +109,32 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) { ...@@ -108,15 +109,32 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
return ret; return ret;
} }
inline bool NeedTransformDataType(const OpKernelType& l, inline bool NeedTransformDataType(const phi::KernelKey& l,
const OpKernelType& r) { const phi::KernelKey& r) {
return (l.data_type_ != r.data_type_); return l.dtype() != phi::DataType::ALL_DTYPE &&
r.dtype() != phi::DataType::ALL_DTYPE && l.dtype() != r.dtype();
} }
inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) { inline bool backends_are_same_class(const phi::Backend& l,
return (!platform::places_are_same_class(l.place_, r.place_)) || const phi::Backend& r) {
(l.data_type_ != r.data_type_) || if (l == phi::Backend::ALL_BACKEND || r == phi::Backend::ALL_BACKEND) {
NeedTransformLayout(l.data_layout_, r.data_layout_); return true;
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t num_backends = static_cast<size_t>(phi::Backend::NUM_BACKENDS);
if (static_cast<size_t>(l) > num_backends &&
static_cast<size_t>(r) > num_backends) {
return phi::TransToPhiPlace(l).GetDeviceType() ==
phi::TransToPhiPlace(r).GetDeviceType();
}
#endif
return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r);
}
inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) {
return !backends_are_same_class(l.backend(), r.backend()) ||
NeedTransformDataType(l, r) ||
NeedTransformLayout(l.layout(), r.layout());
} }
} // namespace framework } // namespace framework
......
...@@ -214,9 +214,10 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -214,9 +214,10 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(InferShapeContext* ctx) const override {} void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::FP32, ctx.device_context()); return phi::KernelKey(proto::VarType::FP32,
ctx.device_context().GetPlace());
} }
}; };
...@@ -275,12 +276,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel { ...@@ -275,12 +276,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(InferShapeContext* ctx) const override {} void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::FP32, return phi::KernelKey(phi::Backend::GPUDNN,
platform::CUDAPlace(0), phi::DataLayout::ALL_LAYOUT,
DataLayout::kAnyLayout, phi::DataType::FLOAT32);
framework::LibraryType::kCUDNN);
} }
}; };
......
...@@ -1380,8 +1380,7 @@ bool OperatorWithKernel::SupportXPU() const { ...@@ -1380,8 +1380,7 @@ bool OperatorWithKernel::SupportXPU() const {
#endif #endif
} }
bool OperatorWithKernel::SupportsMKLDNN( bool OperatorWithKernel::SupportsMKLDNN(const phi::DataType data_type) const {
const proto::VarType::Type data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_)); phi::TransToPhiKernelName(type_));
auto has_phi_kernel = auto has_phi_kernel =
...@@ -1389,8 +1388,7 @@ bool OperatorWithKernel::SupportsMKLDNN( ...@@ -1389,8 +1388,7 @@ bool OperatorWithKernel::SupportsMKLDNN(
phi_kernels.end(), phi_kernels.end(),
[data_type](phi::KernelKeyMap::const_reference kern_pair) { [data_type](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::ONEDNN && return kern_pair.first.backend() == phi::Backend::ONEDNN &&
kern_pair.first.dtype() == kern_pair.first.dtype() == data_type;
framework::TransToPhiDataType(data_type);
}); });
if (has_phi_kernel) { if (has_phi_kernel) {
return true; return true;
...@@ -1406,25 +1404,22 @@ bool OperatorWithKernel::SupportsMKLDNN( ...@@ -1406,25 +1404,22 @@ bool OperatorWithKernel::SupportsMKLDNN(
[data_type](OpKernelMap::const_reference kern_pair) { [data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) && return platform::is_cpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ == LibraryType::kMKLDNN && kern_pair.first.library_type_ == LibraryType::kMKLDNN &&
kern_pair.first.data_type_ == data_type; kern_pair.first.data_type_ == TransToProtoVarType(data_type);
}); });
} }
} }
} }
bool OperatorWithKernel::SupportsCUDNN( bool OperatorWithKernel::SupportsCUDNN(const phi::DataType data_type) const {
const proto::VarType::Type data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_)); phi::TransToPhiKernelName(type_));
paddle::experimental::DataType phi_data_type = auto has_phi_kernel =
framework::TransToPhiDataType(data_type); std::any_of(phi_kernels.begin(),
auto has_phi_kernel = std::any_of( phi_kernels.end(),
phi_kernels.begin(), [data_type](phi::KernelKeyMap::const_reference kern_pair) {
phi_kernels.end(), return kern_pair.first.backend() == phi::Backend::GPUDNN &&
[phi_data_type](phi::KernelKeyMap::const_reference kern_pair) { kern_pair.first.dtype() == data_type;
return kern_pair.first.backend() == phi::Backend::GPUDNN && });
kern_pair.first.dtype() == phi_data_type;
});
if (has_phi_kernel) { if (has_phi_kernel) {
return true; return true;
} else { } else {
...@@ -1433,13 +1428,15 @@ bool OperatorWithKernel::SupportsCUDNN( ...@@ -1433,13 +1428,15 @@ bool OperatorWithKernel::SupportsCUDNN(
return false; return false;
} else { } else {
auto& op_kernels = op_kernel_iter->second; auto& op_kernels = op_kernel_iter->second;
proto::VarType::Type fluid_data_type =
framework::TransToProtoVarType(data_type);
return std::any_of( return std::any_of(
op_kernels.begin(), op_kernels.begin(),
op_kernels.end(), op_kernels.end(),
[data_type](OpKernelMap::const_reference kern_pair) { [fluid_data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_) && return platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ == LibraryType::kCUDNN && kern_pair.first.library_type_ == LibraryType::kCUDNN &&
kern_pair.first.data_type_ == data_type; kern_pair.first.data_type_ == fluid_data_type;
}); });
} }
} }
...@@ -1509,14 +1506,19 @@ bool OperatorWithKernel::SupportsKernelType( ...@@ -1509,14 +1506,19 @@ bool OperatorWithKernel::SupportsKernelType(
} }
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { phi::DataType data_type) const {
return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") && return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
platform::is_cpu_place(ctx.GetPlace()) && platform::is_cpu_place(ctx.GetPlace()) &&
this->SupportsMKLDNN(data_type); this->SupportsMKLDNN(data_type);
} }
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
return this->CanMKLDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
}
bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { phi::DataType data_type) const {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") && bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
paddle::platform::is_gpu_place(ctx.GetPlace()); paddle::platform::is_gpu_place(ctx.GetPlace());
...@@ -1528,7 +1530,7 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, ...@@ -1528,7 +1530,7 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
if (use_cudnn && data_type == framework::proto::VarType::BF16) { if (use_cudnn && data_type == phi::DataType::BFLOAT16) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
platform::DnnVersion(), platform::DnnVersion(),
8100, 8100,
...@@ -1540,6 +1542,11 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx, ...@@ -1540,6 +1542,11 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
return use_cudnn && this->SupportsCUDNN(data_type); return use_cudnn && this->SupportsCUDNN(data_type);
} }
bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
return this->CanCUDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to " "The default InferShape function of OperatorWithKernel is not allowed to "
...@@ -1839,8 +1846,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1839,8 +1846,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (need_prepare_data_) { if (need_prepare_data_) {
transfer_scope = PrepareData( transfer_scope =
scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx); PrepareData(scope,
framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
} }
} }
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
...@@ -1960,7 +1971,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1960,7 +1971,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto expected_kernel_key = this->GetExpectedKernelType(ctx); phi::KernelKey phi_kernel_key = this->GetExpectedKernelType(ctx);
auto expected_kernel_key =
framework::TransPhiKernelKeyToOpKernelType(phi_kernel_key);
// NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function // NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function
// GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and // GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and
...@@ -2063,6 +2076,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -2063,6 +2076,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
} }
} }
} }
if (platform::places_are_same_class(expected_kernel_key.place_,
ctx.GetPlace())) {
expected_kernel_key.place_ = ctx.GetPlace();
}
VLOG(3) << "op type:" << type_ VLOG(3) << "op type:" << type_
<< ", expected_kernel_key:" << expected_kernel_key; << ", expected_kernel_key:" << expected_kernel_key;
return expected_kernel_key; return expected_kernel_key;
...@@ -2333,9 +2352,10 @@ void OperatorWithKernel::HandleComplexGradToRealGrad( ...@@ -2333,9 +2352,10 @@ void OperatorWithKernel::HandleComplexGradToRealGrad(
Scope* OperatorWithKernel::PrepareData( Scope* OperatorWithKernel::PrepareData(
const Scope& scope, const Scope& scope,
const OpKernelType& expected_kernel_key, const phi::KernelKey& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars, std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const { RuntimeContext* ctx,
const phi::Place& place) const {
Scope* new_scope = nullptr; Scope* new_scope = nullptr;
const std::unordered_set<std::string>* no_buffer_ins = nullptr; const std::unordered_set<std::string>* no_buffer_ins = nullptr;
...@@ -2378,7 +2398,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2378,7 +2398,7 @@ Scope* OperatorWithKernel::PrepareData(
// has to be created and registered // has to be created and registered
if ((tensor_in->layout() == DataLayout::ONEDNN) && if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) && (var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) && (expected_kernel_key.layout() != DataLayout::ONEDNN) &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() == (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
DataLayout::kNHWC) && DataLayout::kNHWC) &&
(tensor_in->dims().size() >= 3)) { (tensor_in->dims().size() >= 3)) {
...@@ -2411,35 +2431,33 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2411,35 +2431,33 @@ Scope* OperatorWithKernel::PrepareData(
auto kernel_type_for_var = auto kernel_type_for_var =
GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key); GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key);
bool need_trans_dtype = bool need_trans_dtype =
kernel_type_for_var.data_type_ != expected_kernel_key.data_type_; NeedTransformDataType(expected_kernel_key, kernel_type_for_var);
bool need_trans_layout = NeedTransformLayout( bool need_trans_layout = NeedTransformLayout(
kernel_type_for_var.data_layout_, expected_kernel_key.data_layout_); kernel_type_for_var.layout(), expected_kernel_key.layout());
if (!need_trans_dtype && !need_trans_layout) { if (!need_trans_dtype && !need_trans_layout) {
if (!run_phi_kernel_ && if (!run_phi_kernel_ &&
platform::places_are_same_class(kernel_type_for_var.place_, backends_are_same_class(kernel_type_for_var.backend(),
expected_kernel_key.place_)) { expected_kernel_key.backend())) {
continue; continue;
} }
} }
std::unique_ptr<OpKernelType> new_expected_kernel_key = nullptr; std::unique_ptr<phi::KernelKey> new_expected_kernel_key = nullptr;
if (run_phi_kernel_ && in_def != nullptr && if (run_phi_kernel_ && in_def != nullptr &&
in_def->backend != phi::Backend::ALL_BACKEND) { in_def->backend != phi::Backend::ALL_BACKEND) {
auto tensor_backend = phi::TransToPhiBackend(tensor_in->place()); auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if ((in_def->backend != tensor_backend && if ((in_def->backend != tensor_backend &&
(in_def->backend != phi::Backend::GPUDNN || !(in_def->backend == phi::Backend::GPUDNN &&
tensor_backend != phi::Backend::GPU) && tensor_backend == phi::Backend::GPU) &&
(in_def->backend != phi::Backend::KPS || !(in_def->backend == phi::Backend::KPS &&
tensor_backend != phi::Backend::XPU) && tensor_backend == phi::Backend::XPU) &&
(in_def->backend != phi::Backend::ONEDNN || !(in_def->backend == phi::Backend::ONEDNN &&
tensor_backend != phi::Backend::CPU)) || tensor_backend == phi::Backend::CPU)) ||
tensor_in->place().GetType() == AllocationType::GPUPINNED) { tensor_in->place().GetType() == AllocationType::GPUPINNED) {
new_expected_kernel_key = std::make_unique<OpKernelType>( new_expected_kernel_key =
expected_kernel_key.data_type_, std::make_unique<phi::KernelKey>(in_def->backend,
phi::TransToPhiPlace(in_def->backend), expected_kernel_key.layout(),
expected_kernel_key.data_layout_, expected_kernel_key.dtype());
expected_kernel_key.library_type_,
expected_kernel_key.customized_type_value_);
} }
} }
...@@ -2474,14 +2492,18 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2474,14 +2492,18 @@ Scope* OperatorWithKernel::PrepareData(
enable_cache_transfer_scope_ = false; enable_cache_transfer_scope_ = false;
if (!run_by_executor_) { if (!run_by_executor_) {
if (new_expected_kernel_key) { if (new_expected_kernel_key) {
if ((platform::is_gpu_place(kernel_type_for_var.place_) || if (kernel_type_for_var.backend() == phi::Backend::GPU ||
platform::is_gpu_place(new_expected_kernel_key->place_))) { kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
new_expected_kernel_key->backend() == phi::Backend::GPU ||
new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
new_scope = TryCreateTransferScope( new_scope = TryCreateTransferScope(
kernel_type_for_var, *new_expected_kernel_key, &scope); kernel_type_for_var, *new_expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true; enable_cache_transfer_scope_ = true;
} }
} else if ((platform::is_gpu_place(kernel_type_for_var.place_) || } else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
platform::is_gpu_place(expected_kernel_key.place_))) { kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
expected_kernel_key.backend() == phi::Backend::GPU ||
expected_kernel_key.backend() == phi::Backend::GPUDNN) {
new_scope = TryCreateTransferScope( new_scope = TryCreateTransferScope(
kernel_type_for_var, expected_kernel_key, &scope); kernel_type_for_var, expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true; enable_cache_transfer_scope_ = true;
...@@ -2523,11 +2545,15 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2523,11 +2545,15 @@ Scope* OperatorWithKernel::PrepareData(
// Do transfer // Do transfer
phi::DenseTensor out; phi::DenseTensor out;
TransformData(new_expected_kernel_key ? *new_expected_kernel_key TransformData(
: expected_kernel_key, new_expected_kernel_key ? *new_expected_kernel_key
kernel_type_for_var, : expected_kernel_key,
*tensor_in, kernel_type_for_var,
&out); *tensor_in,
&out,
new_expected_kernel_key
? phi::TransToPhiPlace(new_expected_kernel_key->backend())
: place);
SetTensorToVariable(*var, out, trans_var); SetTensorToVariable(*var, out, trans_var);
} }
}; };
...@@ -2818,30 +2844,29 @@ proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes( ...@@ -2818,30 +2844,29 @@ proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes(
return target_type; return target_type;
} }
OpKernelType OperatorWithKernel::GetExpectedKernelType( phi::KernelKey OperatorWithKernel::GetExpectedKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace()); return phi::KernelKey(IndicateDataType(ctx), ctx.GetPlace());
} }
OpKernelType OperatorWithKernel::GetKernelTypeForVar( phi::KernelKey OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const { const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// When the op is first oneDNN op (there was some non oneDNN op // When the op is first oneDNN op (there was some non oneDNN op
// previously) // previously)
// then we also need to rotate shape NHWC -> NCWH // then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) &&
phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) { phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_, return phi::KernelKey(
tensor.place(), tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
phi::DataLayout::kNHWC);
} }
#endif #endif
return OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
......
...@@ -638,16 +638,22 @@ class OperatorWithKernel : public OperatorBase { ...@@ -638,16 +638,22 @@ class OperatorWithKernel : public OperatorBase {
bool SupportXPU() const override; bool SupportXPU() const override;
bool SupportsMKLDNN(proto::VarType::Type data_type) const; bool SupportsMKLDNN(phi::DataType data_type) const;
bool SupportsCUDNN(proto::VarType::Type data_type) const; bool SupportsCUDNN(phi::DataType data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type, bool SupportsKernelType(const OpKernelType& kernel_type,
const ExecutionContext& exe_ctx) const; const ExecutionContext& exe_ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
phi::DataType data_type) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const; proto::VarType::Type data_type) const;
bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
phi::DataType data_type) const;
bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx, bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const; proto::VarType::Type data_type) const;
...@@ -665,14 +671,15 @@ class OperatorWithKernel : public OperatorBase { ...@@ -665,14 +671,15 @@ class OperatorWithKernel : public OperatorBase {
const std::string& name1, const std::string& name1,
const std::string& name2) const; const std::string& name2) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const;
// change this to public so that in dygraph mode we can call it to check if we // change this to public so that in dygraph mode we can call it to check if we
// need transform data // need transform data
virtual OpKernelType GetKernelTypeForVar( virtual phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const; const phi::KernelKey& expected_kernel_type) const;
platform::Place GetExecutionPlace( platform::Place GetExecutionPlace(
const platform::Place& platform) const override { const platform::Place& platform) const override {
...@@ -734,9 +741,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -734,9 +741,10 @@ class OperatorWithKernel : public OperatorBase {
* transfered_inplace_vars is a output vector. * transfered_inplace_vars is a output vector.
*/ */
Scope* PrepareData(const Scope& scope, Scope* PrepareData(const Scope& scope,
const OpKernelType& expected_kernel_key, const phi::KernelKey& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars, std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const; RuntimeContext* ctx,
const phi::Place& place) const;
void TransferInplaceVarsBack(const Scope& scope, void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars, const std::vector<std::string>& inplace_vars,
......
...@@ -127,14 +127,10 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -127,14 +127,10 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
int sub_type = ctx.Attr<int>("kernel_sub_type"); return phi::KernelKey(
return OpKernelType(proto::VarType::FP32, ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kPlain,
sub_type);
} }
}; };
...@@ -256,16 +252,6 @@ TEST(OpKernel, all) { ...@@ -256,16 +252,6 @@ TEST(OpKernel, all) {
// kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called. // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
attr = op_desc.mutable_attrs()->Add();
attr->set_name("kernel_sub_type");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
op2->Run(scope, cpu_place);
// kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
} }
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(
...@@ -339,11 +325,11 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel { ...@@ -339,11 +325,11 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
auto data_type = auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "phi::DenseTensor"); OperatorWithKernel::IndicateVarDataType(ctx, "phi::DenseTensor");
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
...@@ -361,11 +347,11 @@ class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel { ...@@ -361,11 +347,11 @@ class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
auto data_type = auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows"); OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
class IndicateSelectedRowsDataTypeTestProtoMaker class IndicateSelectedRowsDataTypeTestProtoMaker
...@@ -383,10 +369,10 @@ class IndicateOtherDataTypeTest : public OperatorWithKernel { ...@@ -383,10 +369,10 @@ class IndicateOtherDataTypeTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
...@@ -597,10 +583,10 @@ class OpUnusedVarTest : public OperatorWithKernel { ...@@ -597,10 +583,10 @@ class OpUnusedVarTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
return OpKernelType( return phi::KernelKey(
proto::VarType::FP32, ctx.GetPlace(), phi::DataLayout::kAnyLayout); ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
} }
}; };
......
...@@ -34,12 +34,13 @@ global_transfer_scope_key() { ...@@ -34,12 +34,13 @@ global_transfer_scope_key() {
return *x; return *x;
} }
Scope* TryCreateTransferScope(OpKernelType type0, Scope* TryCreateTransferScope(const phi::KernelKey& type0,
OpKernelType type1, const phi::KernelKey& type1,
const Scope* scope) { const Scope* scope) {
Scope* new_scope{nullptr}; Scope* new_scope{nullptr};
size_t infer_cache_key = size_t infer_cache_key =
CombineHash(OpKernelType::Hash()(type0), OpKernelType::Hash()(type1)); CombineHash(static_cast<size_t>(phi::KernelKey::Hash()(type0)),
static_cast<size_t>(phi::KernelKey::Hash()(type1)));
infer_cache_key = infer_cache_key =
CombineHash(infer_cache_key, std::hash<const Scope*>()(scope)); CombineHash(infer_cache_key, std::hash<const Scope*>()(scope));
......
...@@ -39,8 +39,8 @@ static size_t CombineHash(size_t seed, size_t a) { ...@@ -39,8 +39,8 @@ static size_t CombineHash(size_t seed, size_t a) {
return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2); return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
} }
Scope* TryCreateTransferScope(OpKernelType type0, Scope* TryCreateTransferScope(const phi::KernelKey& type0,
OpKernelType type1, const phi::KernelKey& type1,
const Scope* scope); const Scope* scope);
} // namespace framework } // namespace framework
......
...@@ -23,7 +23,8 @@ if(WITH_XPU) ...@@ -23,7 +23,8 @@ if(WITH_XPU)
scalar scalar
int_array int_array
var_helper var_helper
profiler) profiler
place)
else() else()
cc_library( cc_library(
prepared_operator prepared_operator
...@@ -40,7 +41,8 @@ else() ...@@ -40,7 +41,8 @@ else()
scalar scalar
int_array int_array
var_helper var_helper
profiler) profiler
place)
endif() endif()
cc_library( cc_library(
layer layer
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/imperative/var_helper.h" #include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -39,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -39,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attr, const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr, const framework::AttributeMap* default_attr,
const std::string op_type, const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr, const phi::KernelKey* op_kernel_key = nullptr,
const phi::ArgumentMappingFn* arg_map_fn = nullptr, const phi::ArgumentMappingFn* arg_map_fn = nullptr,
const phi::KernelSignature* default_kernel_signature = nullptr) const phi::KernelSignature* default_kernel_signature = nullptr)
: var_map_in_(in), : var_map_in_(in),
...@@ -47,7 +48,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -47,7 +48,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
attrs_(attr), attrs_(attr),
default_attrs_(default_attr), default_attrs_(default_attr),
op_type_(op_type), op_type_(op_type),
op_kernel_type_(op_kernel_type), op_kernel_key_(op_kernel_key),
arg_map_fn_(arg_map_fn), arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature) {} default_kernel_signature_(default_kernel_signature) {}
...@@ -250,8 +251,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -250,8 +251,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
bool IsRunMKLDNNKernel() const override { bool IsRunMKLDNNKernel() const override {
return (op_kernel_type_ && return (op_kernel_key_ &&
(op_kernel_type_->data_layout_ == phi::DataLayout::ONEDNN)); (op_kernel_key_->layout() == phi::DataLayout::ONEDNN));
} }
paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize> paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
...@@ -497,7 +498,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -497,7 +498,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attrs_; const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_; const framework::AttributeMap* default_attrs_;
const std::string op_type_; const std::string op_type_;
const framework::OpKernelType* op_kernel_type_; const phi::KernelKey* op_kernel_key_;
// arg_map_fn_ and default_kernel_signature_ may be nullptr // arg_map_fn_ and default_kernel_signature_ may be nullptr
const phi::ArgumentMappingFn* arg_map_fn_; const phi::ArgumentMappingFn* arg_map_fn_;
const phi::KernelSignature* default_kernel_signature_; const phi::KernelSignature* default_kernel_signature_;
......
...@@ -519,8 +519,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -519,8 +519,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
*/ */
auto prepared_op = auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs); PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr = auto tmp_ins_ptr = PrepareData<VarType>(
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type()); *op_kernel, ins, prepared_op.kernel_key(), prepared_op.place());
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
prepared_op.Run(ins, outs, attrs, default_attrs); prepared_op.Run(ins, outs, attrs, default_attrs);
} else { } else {
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h" #include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
...@@ -75,7 +76,8 @@ template <typename VarType> ...@@ -75,7 +76,8 @@ template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData( std::shared_ptr<NameVarMap<VarType>> PrepareData(
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
const framework::OpKernelType& expected_kernel_key) { const phi::KernelKey& expected_kernel_key,
const phi::Place& place) {
std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr; std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr;
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (size_t i = 0; i < name_pair.second.size(); ++i) { for (size_t i = 0; i < name_pair.second.size(); ++i) {
...@@ -85,7 +87,8 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData( ...@@ -85,7 +87,8 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) { if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) {
auto kernel_type_for_var = op.GetKernelTypeForVar( auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key); name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { if (!framework::NeedTransform(kernel_type_for_var,
expected_kernel_key)) {
continue; continue;
} else { } else {
VLOG(3) << "Transform Variable " << GetNameFromVar(template_var) VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
...@@ -111,10 +114,10 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData( ...@@ -111,10 +114,10 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var; (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
} else { } else {
phi::DenseTensor out; phi::DenseTensor out;
TransformData( framework::TransformData(
expected_kernel_key, kernel_type_for_var, *tensor, &out); expected_kernel_key, kernel_type_for_var, *tensor, &out, place);
if (NeedTransformDataType(kernel_type_for_var, if (framework::NeedTransformDataType(kernel_type_for_var,
expected_kernel_key)) { expected_kernel_key)) {
// To avoid NameVarMap copy construction overhead in general // To avoid NameVarMap copy construction overhead in general
// scenarios, if inplace transformed, return original input // scenarios, if inplace transformed, return original input
// directly // directly
...@@ -149,7 +152,7 @@ class PreparedOp { ...@@ -149,7 +152,7 @@ class PreparedOp {
public: public:
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const phi::KernelKey& kernel_key,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn, const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature* default_kernel_signature,
...@@ -157,7 +160,7 @@ class PreparedOp { ...@@ -157,7 +160,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const phi::KernelKey& kernel_key,
const phi::ArgumentMappingFn* arg_map_fn, const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature, phi::KernelSignature&& kernel_signature,
...@@ -200,12 +203,14 @@ class PreparedOp { ...@@ -200,12 +203,14 @@ class PreparedOp {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
const framework::OpKernelType& kernel_type() const { return kernel_type_; } const phi::KernelKey& kernel_key() const { return kernel_key_; }
const phi::Place& place() const { return dev_ctx_->GetPlace(); }
private: private:
const framework::OperatorBase& op_; const framework::OperatorBase& op_;
const framework::RuntimeContext& ctx_; const framework::RuntimeContext& ctx_;
framework::OpKernelType kernel_type_; phi::KernelKey kernel_key_;
framework::OperatorWithKernel::OpKernelFunc func_; framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_; platform::DeviceContext* dev_ctx_;
// NOTE(chenweihang): Similar op members are used to adapt to // NOTE(chenweihang): Similar op members are used to adapt to
......
...@@ -92,15 +92,15 @@ TEST(test_var_helper, eager_var_helper) { ...@@ -92,15 +92,15 @@ TEST(test_var_helper, eager_var_helper) {
ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerVariable>(egr_tensor))); ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerVariable>(egr_tensor)));
ASSERT_TRUE(GetDataType<egr::EagerVariable>(egr_tensor) == ASSERT_TRUE(GetDataType<egr::EagerVariable>(egr_tensor) ==
framework::proto::VarType::FP32); framework::proto::VarType::FP32);
GetCachedValue<egr::EagerVariable>( GetCachedValue<egr::EagerVariable>(egr_tensor,
egr_tensor, phi::KernelKey(phi::Backend::CPU,
framework::OpKernelType(framework::proto::VarType::FP32, phi::DataLayout::ALL_LAYOUT,
platform::CPUPlace())); phi::DataType::FLOAT32));
SetCachedValue<egr::EagerVariable>( SetCachedValue<egr::EagerVariable>(egr_tensor,
egr_tensor, phi::KernelKey(phi::Backend::CPU,
framework::OpKernelType(framework::proto::VarType::FP32, phi::DataLayout::ALL_LAYOUT,
platform::CPUPlace()), phi::DataType::FLOAT32),
egr_tensor2); egr_tensor2);
ASSERT_ANY_THROW(GetPlace<egr::EagerVariable>(egr_tensor2)); ASSERT_ANY_THROW(GetPlace<egr::EagerVariable>(egr_tensor2));
ASSERT_ANY_THROW(SetType<egr::EagerVariable>( ASSERT_ANY_THROW(SetType<egr::EagerVariable>(
egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY)); egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY));
......
...@@ -172,7 +172,8 @@ TEST(test_prepare_op, test_prepare_data) { ...@@ -172,7 +172,8 @@ TEST(test_prepare_op, test_prepare_data) {
PrepareData<imperative::VarBase>( PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), dynamic_cast<framework::OperatorWithKernel&>(*op),
ins, ins,
prepared_op.kernel_type()); prepared_op.kernel_key(),
gpu_place);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
...@@ -229,7 +230,8 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { ...@@ -229,7 +230,8 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
PrepareData<imperative::VarBase>( PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op), dynamic_cast<framework::OperatorWithKernel&>(*op),
ins, ins,
prepared_op.kernel_type()); prepared_op.kernel_key(),
cpu_place);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
......
...@@ -239,35 +239,31 @@ template void SetDataLayout<VariableWrapper>( ...@@ -239,35 +239,31 @@ template void SetDataLayout<VariableWrapper>(
/* CheckCachedKey */ /* CheckCachedKey */
template <typename VarType> template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> var, bool CheckCachedKey(std::shared_ptr<VarType> var, const phi::KernelKey &key) {
const paddle::framework::OpKernelType &key) {
return GetVariableWrapper(var)->hasCacheKey(key); return GetVariableWrapper(var)->hasCacheKey(key);
} }
template <> template <>
bool CheckCachedKey<egr::EagerVariable>( bool CheckCachedKey<egr::EagerVariable>(
std::shared_ptr<egr::EagerVariable> tensor, std::shared_ptr<egr::EagerVariable> tensor, const phi::KernelKey &key) {
const paddle::framework::OpKernelType &key) {
// TODO(jiabin): Support this later // TODO(jiabin): Support this later
// VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is
// equal to self: " << key == key. // equal to self: " << key == key.
return false; return false;
} }
template bool CheckCachedKey<VarBase>( template bool CheckCachedKey<VarBase>(std::shared_ptr<VarBase> var,
std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key); const phi::KernelKey &key);
template bool CheckCachedKey<VariableWrapper>( template bool CheckCachedKey<VariableWrapper>(
std::shared_ptr<VariableWrapper> var, std::shared_ptr<VariableWrapper> var, const phi::KernelKey &key);
const paddle::framework::OpKernelType &key);
/* GetCachedValue */ /* GetCachedValue */
template <typename VarType> template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue( std::shared_ptr<VariableWrapper> GetCachedValue(std::shared_ptr<VarType> var,
std::shared_ptr<VarType> var, const paddle::framework::OpKernelType &key) { const phi::KernelKey &key) {
return GetVariableWrapper(var)->getCacheValue(key); return GetVariableWrapper(var)->getCacheValue(key);
} }
template <> template <>
std::shared_ptr<VariableWrapper> GetCachedValue( std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<egr::EagerVariable> var, std::shared_ptr<egr::EagerVariable> var, const phi::KernelKey &key) {
const paddle::framework::OpKernelType &key) {
// TODO(jiabin): Support this later // TODO(jiabin): Support this later
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
// reach this, support cache and remove this error check later, or this // reach this, support cache and remove this error check later, or this
...@@ -277,22 +273,21 @@ std::shared_ptr<VariableWrapper> GetCachedValue( ...@@ -277,22 +273,21 @@ std::shared_ptr<VariableWrapper> GetCachedValue(
return std::make_shared<VariableWrapper>(""); return std::make_shared<VariableWrapper>("");
} }
template std::shared_ptr<VariableWrapper> GetCachedValue<VarBase>( template std::shared_ptr<VariableWrapper> GetCachedValue<VarBase>(
std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key); std::shared_ptr<VarBase> var, const phi::KernelKey &key);
template std::shared_ptr<VariableWrapper> GetCachedValue<VariableWrapper>( template std::shared_ptr<VariableWrapper> GetCachedValue<VariableWrapper>(
std::shared_ptr<VariableWrapper> var, std::shared_ptr<VariableWrapper> var, const phi::KernelKey &key);
const paddle::framework::OpKernelType &key);
/* SetCachedValue */ /* SetCachedValue */
template <typename VarType> template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> var, void SetCachedValue(std::shared_ptr<VarType> var,
const paddle::framework::OpKernelType &key, const phi::KernelKey &key,
std::shared_ptr<VarType> res) { std::shared_ptr<VarType> res) {
GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)); GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res));
} }
template <> template <>
void SetCachedValue<egr::EagerVariable>( void SetCachedValue<egr::EagerVariable>(
std::shared_ptr<egr::EagerVariable> tensor, std::shared_ptr<egr::EagerVariable> tensor,
const paddle::framework::OpKernelType &key, const phi::KernelKey &key,
std::shared_ptr<egr::EagerVariable> res) { std::shared_ptr<egr::EagerVariable> res) {
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
// reach this, support cache and remove this error check later, or this // reach this, support cache and remove this error check later, or this
...@@ -300,13 +295,12 @@ void SetCachedValue<egr::EagerVariable>( ...@@ -300,13 +295,12 @@ void SetCachedValue<egr::EagerVariable>(
// VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key
// is equal to self: " << key == key << " and res name is:" << res->Name(). // is equal to self: " << key == key << " and res name is:" << res->Name().
} }
template void SetCachedValue<VarBase>( template void SetCachedValue<VarBase>(std::shared_ptr<VarBase> var,
std::shared_ptr<VarBase> var, const phi::KernelKey &key,
const paddle::framework::OpKernelType &key, std::shared_ptr<VarBase> res);
std::shared_ptr<VarBase> res);
template void SetCachedValue<VariableWrapper>( template void SetCachedValue<VariableWrapper>(
std::shared_ptr<VariableWrapper> var, std::shared_ptr<VariableWrapper> var,
const paddle::framework::OpKernelType &key, const phi::KernelKey &key,
std::shared_ptr<VariableWrapper> res); std::shared_ptr<VariableWrapper> res);
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -43,16 +43,14 @@ template <typename VarType> ...@@ -43,16 +43,14 @@ template <typename VarType>
const std::string& GetNameFromVar(std::shared_ptr<VarType> var); const std::string& GetNameFromVar(std::shared_ptr<VarType> var);
template <typename VarType> template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> tensor, bool CheckCachedKey(std::shared_ptr<VarType> tensor, const phi::KernelKey& key);
const paddle::framework::OpKernelType& key);
template <typename VarType> template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> tensor, void SetCachedValue(std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key, const phi::KernelKey& key,
std::shared_ptr<VarType> res); std::shared_ptr<VarType> res);
template <typename VarType> template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue( std::shared_ptr<VariableWrapper> GetCachedValue(std::shared_ptr<VarType> tensor,
std::shared_ptr<VarType> tensor, const phi::KernelKey& key);
const paddle::framework::OpKernelType& key);
template <typename VarType> template <typename VarType>
void SetType(std::shared_ptr<VarType> var, void SetType(std::shared_ptr<VarType> var,
......
...@@ -234,16 +234,15 @@ class VariableWrapper { ...@@ -234,16 +234,15 @@ class VariableWrapper {
} }
} }
bool hasCacheKey(const paddle::framework::OpKernelType& key) { bool hasCacheKey(const phi::KernelKey& key) {
return var_cache.find(key) != var_cache.end(); return var_cache.find(key) != var_cache.end();
} }
std::shared_ptr<VariableWrapper> getCacheValue( std::shared_ptr<VariableWrapper> getCacheValue(const phi::KernelKey& key) {
const paddle::framework::OpKernelType& key) {
return var_cache[key]; return var_cache[key];
} }
void setCacheValue(const paddle::framework::OpKernelType& key, void setCacheValue(const phi::KernelKey& key,
std::shared_ptr<VariableWrapper> val) { std::shared_ptr<VariableWrapper> val) {
var_cache[key] = val; var_cache[key] = val;
return; return;
...@@ -323,8 +322,7 @@ class VariableWrapper { ...@@ -323,8 +322,7 @@ class VariableWrapper {
// Used for cache the dtype promotioned variableWrapper in real and complex // Used for cache the dtype promotioned variableWrapper in real and complex
// compute of Paddle Quantum // compute of Paddle Quantum
std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>> std::map<phi::KernelKey, std::shared_ptr<VariableWrapper>> var_cache;
var_cache;
// add this property for users may set stop_gradient themselves and this // add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false // should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1}; int overrided_stop_gradient_{-1};
......
...@@ -29,11 +29,11 @@ class AbsOp : public framework::OperatorWithKernel { ...@@ -29,11 +29,11 @@ class AbsOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -70,11 +70,11 @@ class AbsGradOp : public framework::OperatorWithKernel { ...@@ -70,11 +70,11 @@ class AbsGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -124,20 +124,17 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel { ...@@ -124,20 +124,17 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
return framework::OpKernelType(dtype, ctx.GetPlace()); return phi::KernelKey(dtype, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const phi::KernelKey& expected_kernel_type) const override {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
framework::TransToProtoVarType(tensor.dtype()),
tensor.place(),
tensor.layout());
} }
}; };
......
...@@ -80,9 +80,9 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -80,9 +80,9 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper, const framework::OperatorWithKernel& oper,
const std::string& name) { const std::string& name) {
auto data_type = oper.IndicateVarDataType(ctx, name); auto data_type = oper.IndicateVarDataType(ctx, name);
// FIXME(liuwei1031) temporarily disable the code to unblock users // FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind // TODO(liuwei1031) figure out the reason behind
...@@ -94,7 +94,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -94,7 +94,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
// library = framework::LibraryType::kCUDNN; // library = framework::LibraryType::kCUDNN;
// } // }
// #endif // #endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
class ActivationOp : public framework::OperatorWithKernel { class ActivationOp : public framework::OperatorWithKernel {
...@@ -107,7 +107,7 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -107,7 +107,7 @@ class ActivationOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
...@@ -134,7 +134,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { ...@@ -134,7 +134,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out")); return GetKernelType(ctx, *this, framework::GradVarName("Out"));
} }
...@@ -341,7 +341,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -341,7 +341,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX"); return GetKernelType(ctx, *this, "DDX");
} }
...@@ -370,7 +370,7 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { ...@@ -370,7 +370,7 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX"); return GetKernelType(ctx, *this, "DDX");
} }
...@@ -411,7 +411,7 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel { ...@@ -411,7 +411,7 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX"); return GetKernelType(ctx, *this, "DDX");
} }
...@@ -487,20 +487,22 @@ class PowOp : public framework::OperatorWithKernel { ...@@ -487,20 +487,22 @@ class PowOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "FactorTensor") { if (var_name == "FactorTensor") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
}; };
...@@ -515,20 +517,22 @@ class PowOpGrad : public framework::OperatorWithKernel { ...@@ -515,20 +517,22 @@ class PowOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out")); return GetKernelType(ctx, *this, framework::GradVarName("Out"));
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "FactorTensor") { if (var_name == "FactorTensor") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
}; };
...@@ -537,7 +541,7 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -537,7 +541,7 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
...@@ -548,7 +552,7 @@ class PowOpTripleGrad : public framework::OperatorWithKernel { ...@@ -548,7 +552,7 @@ class PowOpTripleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
......
...@@ -34,11 +34,10 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { ...@@ -34,11 +34,10 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
...@@ -54,11 +53,11 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { ...@@ -54,11 +53,11 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -145,11 +145,11 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { ...@@ -145,11 +145,11 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -130,10 +130,10 @@ class AffineGridOp : public framework::OperatorWithKernel { ...@@ -130,10 +130,10 @@ class AffineGridOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
...@@ -241,11 +241,11 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -241,11 +241,11 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output")); ctx, framework::GradVarName("Output"));
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
......
...@@ -65,11 +65,10 @@ class AllcloseOp : public framework::OperatorWithKernel { ...@@ -65,11 +65,10 @@ class AllcloseOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
......
...@@ -34,10 +34,9 @@ class AllocFloatStatusOp : public framework::OperatorWithKernel { ...@@ -34,10 +34,9 @@ class AllocFloatStatusOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -29,13 +29,13 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel { ...@@ -29,13 +29,13 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto dtype = framework::proto::VarType::FP32; auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) { if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
} }
return framework::OpKernelType(dtype, ctx.GetPlace()); return phi::KernelKey(dtype, ctx.GetPlace());
} }
}; };
......
...@@ -34,10 +34,9 @@ class ClearFloatStatusOp : public framework::OperatorWithKernel { ...@@ -34,10 +34,9 @@ class ClearFloatStatusOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -34,10 +34,9 @@ class GetFloatStatusOp : public framework::OperatorWithKernel { ...@@ -34,10 +34,9 @@ class GetFloatStatusOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -29,23 +29,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { ...@@ -29,23 +29,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto dtype = framework::proto::VarType::FP32; auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) { if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
} }
return framework::OpKernelType(dtype, ctx.GetPlace()); return phi::KernelKey(dtype, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override { const phi::KernelKey& expected_kernel_type) const override {
#ifndef PADDLE_WITH_XPU #ifndef PADDLE_WITH_XPU
if (var_name == "FoundInfinite" || var_name == "StopUpdate") { if (var_name == "FoundInfinite" || var_name == "StopUpdate") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
#endif #endif
return framework::OperatorWithKernel::GetKernelTypeForVar( return framework::OperatorWithKernel::GetKernelTypeForVar(
......
...@@ -32,11 +32,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { ...@@ -32,11 +32,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -23,10 +23,10 @@ class AscendTriggerOp : public framework::OperatorWithKernel { ...@@ -23,10 +23,10 @@ class AscendTriggerOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context()); ctx.device_context().GetPlace());
} }
}; };
......
...@@ -41,16 +41,16 @@ class AssignOp : public framework::OperatorWithKernel { ...@@ -41,16 +41,16 @@ class AssignOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const phi::KernelKey &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_, return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.place_, tensor.layout(),
tensor.layout()); expected_kernel_type.dtype());
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
const framework::Variable *var = ctx.InputVar("X"); const framework::Variable *var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) { if (var->IsType<framework::LoDTensorArray>()) {
...@@ -58,14 +58,13 @@ class AssignOp : public framework::OperatorWithKernel { ...@@ -58,14 +58,13 @@ class AssignOp : public framework::OperatorWithKernel {
// NOTE(liym27): Support an empty tensor array as Input. // NOTE(liym27): Support an empty tensor array as Input.
// And set the kernel type is float. // And set the kernel type is float.
if (t_arr.size() == 0) { if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context()); ctx.device_context().GetPlace());
} }
} }
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
......
...@@ -31,7 +31,7 @@ class AssignPosOp : public framework::OperatorWithKernel { ...@@ -31,7 +31,7 @@ class AssignPosOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto cum_count_dtype = auto cum_count_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "cum_count"); OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
...@@ -46,7 +46,7 @@ class AssignPosOp : public framework::OperatorWithKernel { ...@@ -46,7 +46,7 @@ class AssignPosOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and " "The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64")); "X should be same as int64"));
return framework::OpKernelType(cum_count_dtype, ctx.device_context()); return phi::KernelKey(cum_count_dtype, ctx.device_context().GetPlace());
} }
}; };
......
...@@ -44,9 +44,9 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -44,9 +44,9 @@ class AssignValueOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return phi::KernelKey(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
......
...@@ -198,10 +198,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -198,10 +198,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Cell"); ctx->ShareLoD("X", "Cell");
} }
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( phi::KernelKey AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); ctx.device_context().GetPlace());
} }
void AttentionLSTMOpMaker::Make() { void AttentionLSTMOpMaker::Make() {
......
...@@ -25,7 +25,7 @@ class AttentionLSTMOp : public framework::OperatorWithKernel { ...@@ -25,7 +25,7 @@ class AttentionLSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
......
...@@ -26,10 +26,10 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -26,10 +26,10 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "param"),
OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -77,11 +77,10 @@ class BatchFCOp : public framework::OperatorWithKernel { ...@@ -77,11 +77,10 @@ class BatchFCOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
...@@ -106,11 +105,11 @@ class BatchFCGradOp : public framework::OperatorWithKernel { ...@@ -106,11 +105,11 @@ class BatchFCGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context().GetPlace());
} }
}; };
......
...@@ -171,7 +171,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -171,7 +171,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
} }
} }
framework::OpKernelType BatchNormOp::GetExpectedKernelType( phi::KernelKey BatchNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean, // By default, the type of the scale, bias, mean,
...@@ -202,18 +202,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -202,18 +202,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variance input should be of float type")); "Variance input should be of float type"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType BatchNormOp::GetKernelTypeForVar( phi::KernelKey BatchNormOp::GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const { const phi::KernelKey &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
// bias are having shape in NCHW order // bias are having shape in NCHW order
if ((var_name == "X") && if ((var_name == "X") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) { (tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs(); auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
...@@ -222,13 +222,12 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar( ...@@ -222,13 +222,12 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool // Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) { if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
expected_kernel_type.data_type_, tensor.place(), dl);
} }
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
void BatchNormOpMaker::Make() { void BatchNormOpMaker::Make() {
...@@ -373,7 +372,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -373,7 +372,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
} }
} }
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( phi::KernelKey BatchNormGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y")); const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) { if (var == nullptr) {
...@@ -392,18 +391,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -392,18 +391,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
} }
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( phi::KernelKey BatchNormGradOp::GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const { const phi::KernelKey &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
// bias are having shape in NCHW order // bias are having shape in NCHW order
if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) && if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) { (tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs(); auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
...@@ -412,13 +411,12 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( ...@@ -412,13 +411,12 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool // Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) { if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
expected_kernel_type.data_type_, tensor.place(), dl);
} }
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
template <typename T> template <typename T>
...@@ -515,7 +513,7 @@ void BatchNormDoubleGradOp::InferShape( ...@@ -515,7 +513,7 @@ void BatchNormDoubleGradOp::InferShape(
} }
} }
framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY"); const auto *var = ctx.InputVar("DY");
if (var == nullptr) { if (var == nullptr) {
...@@ -532,8 +530,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( ...@@ -532,8 +530,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
PADDLE_THROW( PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty")); platform::errors::InvalidArgument("gradient variable of Y is empty"));
} }
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
......
...@@ -47,13 +47,13 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -47,13 +47,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override; const phi::KernelKey& expected_kernel_type) const override;
}; };
class BatchNormGradOp : public framework::OperatorWithKernel { class BatchNormGradOp : public framework::OperatorWithKernel {
...@@ -62,13 +62,13 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -62,13 +62,13 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override; const phi::KernelKey& expected_kernel_type) const override;
}; };
class BatchNormDoubleGradOp : public framework::OperatorWithKernel { class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
...@@ -77,7 +77,7 @@ class BatchNormDoubleGradOp : public framework::OperatorWithKernel { ...@@ -77,7 +77,7 @@ class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
......
...@@ -28,11 +28,10 @@ class BCELossOp : public framework::OperatorWithKernel { ...@@ -28,11 +28,10 @@ class BCELossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
...@@ -87,11 +86,10 @@ class BCELossGradOp : public framework::OperatorWithKernel { ...@@ -87,11 +86,10 @@ class BCELossGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
......
...@@ -108,7 +108,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -108,7 +108,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto *scores = ctx.Input<phi::DenseTensor>("scores"); auto *scores = ctx.Input<phi::DenseTensor>("scores");
size_t level = ctx.Attr<int>("level"); size_t level = ctx.Attr<int>("level");
...@@ -116,11 +116,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -116,11 +116,11 @@ class BeamSearchOp : public framework::OperatorWithKernel {
// The current CUDA kernel only support cases with batch_size < 4. // The current CUDA kernel only support cases with batch_size < 4.
// Compute on CPU for cases with batch_size > 4. // Compute on CPU for cases with batch_size > 4.
if (batch_size <= 4) { if (batch_size <= 4) {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
ctx.GetPlace()); ctx.GetPlace());
} else { } else {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
platform::CPUPlace()); platform::CPUPlace());
} }
......
...@@ -85,10 +85,10 @@ class BilateralSliceOp : public framework::OperatorWithKernel { ...@@ -85,10 +85,10 @@ class BilateralSliceOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -147,11 +147,11 @@ class BilateralSliceOpGrad : public framework::OperatorWithKernel { ...@@ -147,11 +147,11 @@ class BilateralSliceOpGrad : public framework::OperatorWithKernel {
} }
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -24,19 +24,17 @@ limitations under the License. */ ...@@ -24,19 +24,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::OpKernelType;
class BincountOp : public framework::OperatorWithKernel { class BincountOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
auto data_type = auto data_type =
ctx.HasInput("Weights") ctx.HasInput("Weights")
? OperatorWithKernel::IndicateVarDataType(ctx, "Weights") ? OperatorWithKernel::IndicateVarDataType(ctx, "Weights")
: OperatorWithKernel::IndicateVarDataType(ctx, "X"); : OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.device_context().GetPlace());
} }
}; };
......
...@@ -56,11 +56,10 @@ class BprLossOp : public framework::OperatorWithKernel { ...@@ -56,11 +56,10 @@ class BprLossOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of computation kernel of Seq-bpr // Explicitly set that the data type of computation kernel of Seq-bpr
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
...@@ -119,11 +118,10 @@ class BprLossGradientOp : public framework::OperatorWithKernel { ...@@ -119,11 +118,10 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of computation kernel of cross_entropy // Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -27,14 +27,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel { ...@@ -27,14 +27,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
// Broadcast semantics enforces all input variables having the same // Broadcast semantics enforces all input variables having the same
// DataType/VarType // DataType/VarType
// This condition is also checked during VarType Inference // This condition is also checked during VarType Inference
// Here we simply copy input type to output // Here we simply copy input type to output
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -127,11 +127,11 @@ class BroadcastTensorsGradOp : public framework::OperatorWithKernel { ...@@ -127,11 +127,11 @@ class BroadcastTensorsGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context().GetPlace());
} }
}; };
......
...@@ -75,7 +75,7 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -75,7 +75,7 @@ class CastOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
// CastOp kernel's device type is decided by input tensor place // CastOp kernel's device type is decided by input tensor place
auto *tensor = ctx.Input<phi::DenseTensor>("X"); auto *tensor = ctx.Input<phi::DenseTensor>("X");
...@@ -86,9 +86,8 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -86,9 +86,8 @@ class CastOp : public framework::OperatorWithKernel {
auto &tensor_place = tensor->place(); auto &tensor_place = tensor->place();
// NOTE: cuda pinned tensor need to copy its data to target place // NOTE: cuda pinned tensor need to copy its data to target place
if (platform::is_cuda_pinned_place(tensor_place)) { if (platform::is_cuda_pinned_place(tensor_place)) {
return framework::OpKernelType( return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
framework::TransToProtoVarType(tensor->dtype()), ctx.device_context().GetPlace());
ctx.device_context());
} }
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
...@@ -108,20 +107,19 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -108,20 +107,19 @@ class CastOp : public framework::OperatorWithKernel {
auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype")); auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype"));
auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype")); auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype"));
if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) { if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) {
return framework::OpKernelType( return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
framework::TransToProtoVarType(tensor->dtype()), tensor_place); tensor_place);
} else { } else {
VLOG(3) << "MLU not support cast type: " VLOG(3) << "MLU not support cast type: "
<< framework::DataTypeToString(src_type) << framework::DataTypeToString(src_type)
<< " to type: " << framework::DataTypeToString(dst_type) << " to type: " << framework::DataTypeToString(dst_type)
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return framework::OpKernelType( return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
framework::TransToProtoVarType(tensor->dtype()), platform::CPUPlace());
platform::CPUPlace());
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
framework::TransToProtoVarType(tensor->dtype()), tensor_place); tensor_place);
} }
}; };
......
...@@ -53,11 +53,10 @@ class CenterLossOp : public framework::OperatorWithKernel { ...@@ -53,11 +53,10 @@ class CenterLossOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
...@@ -115,11 +114,11 @@ class CenterLossGradOp : public framework::OperatorWithKernel { ...@@ -115,11 +114,11 @@ class CenterLossGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"),
ctx.device_context()); ctx.device_context().GetPlace());
} }
}; };
......
...@@ -88,10 +88,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -88,10 +88,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -57,10 +57,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { ...@@ -57,10 +57,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
* specified a data type here. * specified a data type here.
* *
*/ */
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -117,10 +117,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel { ...@@ -117,10 +117,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
* Of course, the data type here is also not important. * Of course, the data type here is also not important.
*/ */
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -26,11 +26,10 @@ class ClassCenterSampleOp : public framework::OperatorWithKernel { ...@@ -26,11 +26,10 @@ class ClassCenterSampleOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Label"),
OperatorWithKernel::IndicateVarDataType(ctx, "Label"), ctx.device_context().GetPlace());
ctx.device_context());
} }
}; };
......
...@@ -26,11 +26,11 @@ namespace operators { ...@@ -26,11 +26,11 @@ namespace operators {
class ClipOp : public framework::OperatorWithKernel { class ClipOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -85,11 +85,11 @@ class ClipOpGrad : public framework::OperatorWithKernel { ...@@ -85,11 +85,11 @@ class ClipOpGrad : public framework::OperatorWithKernel {
} }
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -405,20 +405,20 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { ...@@ -405,20 +405,20 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &context) const override { const framework::ExecutionContext &context) const override {
auto dtype = static_cast<framework::proto::VarType::Type>( auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")); context.Attr<int>("dtype"));
return framework::OpKernelType(dtype, context.GetPlace()); return phi::KernelKey(dtype, context.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const phi::KernelKey &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_, return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.place_, tensor.layout(),
tensor.layout()); expected_kernel_type.dtype());
} }
}; };
......
...@@ -27,10 +27,10 @@ class AllReduceOp : public framework::OperatorWithKernel { ...@@ -27,10 +27,10 @@ class AllReduceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -36,10 +36,10 @@ class AllToAllOp : public framework::OperatorWithKernel { ...@@ -36,10 +36,10 @@ class AllToAllOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -71,21 +71,23 @@ class CAllReduceOp : public framework::OperatorWithKernel { ...@@ -71,21 +71,23 @@ class CAllReduceOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const { const phi::KernelKey& expected_kernel_type) const {
if (var_name == "Cond") { if (var_name == "Cond") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} else { } else {
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
} }
}; };
......
...@@ -26,10 +26,10 @@ class CBroadcastOp : public framework::OperatorWithKernel { ...@@ -26,10 +26,10 @@ class CBroadcastOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -58,10 +58,10 @@ class CConcatOp : public framework::OperatorWithKernel { ...@@ -58,10 +58,10 @@ class CConcatOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -65,10 +65,10 @@ class CEmbeddingOp : public framework::OperatorWithKernel { ...@@ -65,10 +65,10 @@ class CEmbeddingOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
...@@ -149,11 +149,11 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel { ...@@ -149,11 +149,11 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
}; };
......
...@@ -35,10 +35,10 @@ class CIdentityOp : public framework::OperatorWithKernel { ...@@ -35,10 +35,10 @@ class CIdentityOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -66,10 +66,10 @@ class CReduceOp : public framework::OperatorWithKernel { ...@@ -66,10 +66,10 @@ class CReduceOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -52,10 +52,10 @@ class CScatterOp : public framework::OperatorWithKernel { ...@@ -52,10 +52,10 @@ class CScatterOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -73,11 +73,10 @@ class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -73,11 +73,10 @@ class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
ctx.device_context());
} }
}; };
...@@ -150,11 +149,11 @@ class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -150,11 +149,11 @@ class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")), ctx, framework::GradVarName("Loss")),
ctx.device_context()); ctx.GetPlace());
} }
}; };
......
...@@ -66,10 +66,10 @@ class CSplitOp : public framework::OperatorWithKernel { ...@@ -66,10 +66,10 @@ class CSplitOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -25,10 +28,9 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel { ...@@ -25,10 +28,9 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -23,10 +23,9 @@ class CSyncCommStreamOp : public framework::OperatorWithKernel { ...@@ -23,10 +23,9 @@ class CSyncCommStreamOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -49,10 +49,10 @@ class GlobalGatherOp : public framework::OperatorWithKernel { ...@@ -49,10 +49,10 @@ class GlobalGatherOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -52,10 +52,10 @@ class GlobalScatterOp : public framework::OperatorWithKernel { ...@@ -52,10 +52,10 @@ class GlobalScatterOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -80,12 +80,12 @@ class PartialRecvOp : public framework::OperatorWithKernel { ...@@ -80,12 +80,12 @@ class PartialRecvOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type = framework::proto::VarType::Type type =
framework::proto::VarType::Type(dtype); framework::proto::VarType::Type(dtype);
return framework::OpKernelType(type, ctx.GetPlace()); return phi::KernelKey(type, ctx.GetPlace());
} }
}; };
......
...@@ -51,10 +51,10 @@ class PartialSendOp : public framework::OperatorWithKernel { ...@@ -51,10 +51,10 @@ class PartialSendOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -69,12 +69,12 @@ class RecvOpV2 : public framework::OperatorWithKernel { ...@@ -69,12 +69,12 @@ class RecvOpV2 : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type = framework::proto::VarType::Type type =
framework::proto::VarType::Type(dtype); framework::proto::VarType::Type(dtype);
return framework::OpKernelType(type, ctx.GetPlace()); return phi::KernelKey(type, ctx.GetPlace());
} }
}; };
......
...@@ -38,7 +38,7 @@ class SendOpV2 : public framework::OperatorWithKernel { ...@@ -38,7 +38,7 @@ class SendOpV2 : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
const framework::Variable* var = ctx.InputVar("X"); const framework::Variable* var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) { if (var->IsType<framework::LoDTensorArray>()) {
...@@ -46,12 +46,11 @@ class SendOpV2 : public framework::OperatorWithKernel { ...@@ -46,12 +46,11 @@ class SendOpV2 : public framework::OperatorWithKernel {
// NOTE(sandyhouse): Support an empty tensor array as Input. // NOTE(sandyhouse): Support an empty tensor array as Input.
// And set the kernel type is float. // And set the kernel type is float.
if (t_arr.size() == 0) { if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
ctx.device_context());
} }
} }
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -32,7 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -32,7 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto inputs = ctx.MultiInput<phi::DenseTensor>("X"); auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
auto input_data_type = framework::proto::VarType::Type(0); auto input_data_type = framework::proto::VarType::Type(0);
...@@ -48,18 +48,20 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -48,18 +48,20 @@ class ConcatOp : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!")); "All Inputs of Concat OP are Empty!"));
} }
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") { if (var_name == "AxisTensor") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
}; };
...@@ -110,22 +112,24 @@ class ConcatOpGrad : public framework::OperatorWithKernel { ...@@ -110,22 +112,24 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") { if (var_name == "AxisTensor") {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
}; };
......
...@@ -97,11 +97,12 @@ class UnaryBitwiseOp : public framework::OperatorWithKernel { ...@@ -97,11 +97,12 @@ class UnaryBitwiseOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place // BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt; return kt;
} }
}; };
...@@ -138,11 +139,12 @@ class BinaryBitwiseOp : public framework::OperatorWithKernel { ...@@ -138,11 +139,12 @@ class BinaryBitwiseOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place // BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt; return kt;
} }
}; };
......
...@@ -61,19 +61,20 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -61,19 +61,20 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place // CompareOp kernel's device type is decided by input tensor place
bool force_cpu = ctx.Attr<bool>("force_cpu"); bool force_cpu = ctx.Attr<bool>("force_cpu");
if (force_cpu) { if (force_cpu) {
kt.place_ = platform::CPUPlace(); kt.set_backend(phi::Backend::CPU);
} else { } else {
if (ctx.Input<phi::DenseTensor>("X")->place().GetType() != if (ctx.Input<phi::DenseTensor>("X")->place().GetType() !=
phi::AllocationType::GPUPINNED) { phi::AllocationType::GPUPINNED) {
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
} else { } else {
kt.place_ = ctx.GetPlace(); kt.set_backend(phi::TransToPhiBackend(ctx.GetPlace()));
} }
} }
return kt; return kt;
......
...@@ -72,48 +72,49 @@ class FetchV2Op : public framework::OperatorWithKernel { ...@@ -72,48 +72,49 @@ class FetchV2Op : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {} void InferShape(framework::InferShapeContext *ctx) const override {}
protected: protected:
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const phi::KernelKey &expected_kernel_type) const override {
if (!tensor.IsInitialized()) { if (!tensor.IsInitialized()) {
return expected_kernel_type; return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} }
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto *fetch_var = ctx.InputVar("X"); auto *fetch_var = ctx.InputVar("X");
if (fetch_var == nullptr) { if (fetch_var == nullptr) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
if (fetch_var->IsType<phi::DenseTensor>()) { if (fetch_var->IsType<phi::DenseTensor>()) {
auto &src_item = fetch_var->Get<phi::DenseTensor>(); auto &src_item = fetch_var->Get<phi::DenseTensor>();
if (!src_item.IsInitialized()) { if (!src_item.IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
} else if (fetch_var->IsType<phi::SparseCooTensor>()) { } else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>(); auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
if (!src_item.initialized()) { if (!src_item.initialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
} else { } else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
if (src_item.empty() || !src_item[0].IsInitialized()) { if (src_item.empty() || !src_item[0].IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32, return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
} }
return framework::OpKernelType( return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -69,11 +69,12 @@ class LogicalOp : public framework::OperatorWithKernel { ...@@ -69,11 +69,12 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place // LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place(); kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt; return kt;
} }
}; };
......
...@@ -186,7 +186,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape( ...@@ -186,7 +186,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(
return output_shape; return output_shape;
} }
framework::OpKernelType ConvOp::GetExpectedKernelType( phi::KernelKey ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
// todo enable data layout when it's ready // todo enable data layout when it's ready
...@@ -208,18 +208,18 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -208,18 +208,18 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
paddle::framework::DataTypeToString(filter_data_type))); paddle::framework::DataTypeToString(filter_data_type)));
} }
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType ConvOp::GetKernelTypeForVar( phi::KernelKey ConvOp::GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const { const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
// bias are having shape in NCHW order // bias are having shape in NCHW order
if ((var_name == "Input") && if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) { (tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs(); auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
...@@ -228,13 +228,12 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar( ...@@ -228,13 +228,12 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for conv // Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) { if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
expected_kernel_type.data_type_, tensor.place(), dl);
} }
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
...@@ -447,23 +446,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -447,23 +446,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
} }
} }
framework::OpKernelType ConvOpGrad::GetExpectedKernelType( phi::KernelKey ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( phi::KernelKey ConvOpGrad::GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const { const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
// bias are having shape in NCHW order // bias are having shape in NCHW order
if (((var_name == "Input") || if (((var_name == "Input") ||
(var_name == framework::GradVarName("Output"))) && (var_name == framework::GradVarName("Output"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) { (tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs(); auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
...@@ -472,13 +471,12 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( ...@@ -472,13 +471,12 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool // Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) { if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
expected_kernel_type.data_type_, tensor.place(), dl);
} }
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
template <typename T> template <typename T>
...@@ -619,10 +617,10 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -619,10 +617,10 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
} }
} }
framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
} // namespace operators } // namespace operators
......
...@@ -196,13 +196,13 @@ class ConvOp : public framework::OperatorWithKernel { ...@@ -196,13 +196,13 @@ class ConvOp : public framework::OperatorWithKernel {
std::vector<int64_t> ComputeOutputShape( std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const; framework::InferShapeContext* ctx) const;
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override; const phi::KernelKey& expected_kernel_type) const override;
}; };
class ConvOpGrad : public framework::OperatorWithKernel { class ConvOpGrad : public framework::OperatorWithKernel {
...@@ -211,13 +211,13 @@ class ConvOpGrad : public framework::OperatorWithKernel { ...@@ -211,13 +211,13 @@ class ConvOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override; const phi::KernelKey& expected_kernel_type) const override;
}; };
class ConvOpDoubleGrad : public framework::OperatorWithKernel { class ConvOpDoubleGrad : public framework::OperatorWithKernel {
...@@ -226,7 +226,7 @@ class ConvOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -226,7 +226,7 @@ class ConvOpDoubleGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
......
...@@ -33,21 +33,21 @@ namespace operators { ...@@ -33,21 +33,21 @@ namespace operators {
using DataLayout = phi::DataLayout; using DataLayout = phi::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( phi::KernelKey ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( phi::KernelKey ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const { const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
// bias are having shape in NCHW order // bias are having shape in NCHW order
if ((var_name == "Input") && if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) { (tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs(); auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
...@@ -56,13 +56,12 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( ...@@ -56,13 +56,12 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool // Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) { if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType( return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
expected_kernel_type.data_type_, tensor.place(), dl);
} }
} }
#endif #endif
return framework::OpKernelType( return phi::KernelKey(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); tensor.place(), tensor.layout(), expected_kernel_type.dtype());
} }
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
...@@ -253,10 +252,10 @@ Example: ...@@ -253,10 +252,10 @@ Example:
)DOC"); )DOC");
} }
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( phi::KernelKey ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
template <typename T> template <typename T>
...@@ -320,10 +319,10 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -320,10 +319,10 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( phi::KernelKey ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace()); return phi::KernelKey(data_type, ctx.GetPlace());
} }
} // namespace operators } // namespace operators
......
...@@ -38,13 +38,13 @@ class ConvTransposeOp : public framework::OperatorWithKernel { ...@@ -38,13 +38,13 @@ class ConvTransposeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar( phi::KernelKey GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override; const phi::KernelKey& expected_kernel_type) const override;
}; };
class ConvTransposeOpGrad : public framework::OperatorWithKernel { class ConvTransposeOpGrad : public framework::OperatorWithKernel {
...@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { ...@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
...@@ -61,7 +61,7 @@ class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -61,7 +61,7 @@ class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
......
...@@ -109,7 +109,7 @@ class CorrelationOp : public framework::OperatorWithKernel { ...@@ -109,7 +109,7 @@ class CorrelationOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input1"); OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
...@@ -118,7 +118,7 @@ class CorrelationOp : public framework::OperatorWithKernel { ...@@ -118,7 +118,7 @@ class CorrelationOp : public framework::OperatorWithKernel {
ctx.Input<phi::DenseTensor>("Input2")->dtype()), ctx.Input<phi::DenseTensor>("Input2")->dtype()),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"X and Y shoule have the same datatype")); "X and Y shoule have the same datatype"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return phi::KernelKey(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -158,9 +158,9 @@ class CorrelationOpGrad : public framework::OperatorWithKernel { ...@@ -158,9 +158,9 @@ class CorrelationOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace());
} }
}; };
......
...@@ -202,9 +202,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { ...@@ -202,9 +202,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
platform::CPUPlace()); platform::CPUPlace());
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册