Unverified commit 4383494f authored by HongyuJia, committed by GitHub

[Unify KernelKey] change OpKernelType->KernelKey (#49138)

* execute use kernel_key first

* change OpKernelType->KernelKey

* fix py3 compile error, remove redundant header files

* fix build_strategy_test

* fix DataType::RAW

* fix custom_type test: operator_test.cc

* fix transform place

* fix backends_are_same_class

* try fix place TransDataDevice

* support all KernelKey

* fix TransformData

* fix place_are_same_class

* fix merge

* fix test_params_no_grad

* fix specific place of GetExpectedKernelType

* fix specific place of GetExpectedKernelType

* fix GetKernelTypeForVar

* fix dtype error

* fix fetch_v2

* change GetKernelTypeForVar

* fix interpreter

* fix typo error

* polish codes

* polish codes

* polish codes

* fix conflict
Parent 723ceed9
......@@ -422,9 +422,9 @@ class CustomOperator : public OperatorWithKernel {
* The RAW type is used here as the data type, indicating that
* it can only be determined at runtime.
*/
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace());
return phi::KernelKey(ctx.GetPlace());
}
/**
......@@ -432,13 +432,13 @@ class CustomOperator : public OperatorWithKernel {
* Because the kernel data type is RAW, we should skip the cast for
* data type differences during PrepareData.
*/
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const override {
return OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
const phi::KernelKey& expected_kernel_type) const override {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
tensor.layout(),
expected_kernel_type.dtype());
}
};
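Under the old API this operator pinned the kernel data type to proto::VarType::RAW; with phi::KernelKey the same "decide at runtime" intent is expressed through the place-only constructor, while the per-variable key uses phi::Backend::ALL_BACKEND so that PrepareData skips both the dtype cast and the device check. A minimal sketch of the two constructions used above, assuming the constructors exactly as they appear in this diff (the helper names are illustrative, not part of the patch):

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"  // phi::KernelKey

// Dtype left unspecified: decided at runtime, as RAW used to signal.
phi::KernelKey RuntimeTypedKey(const phi::Place& place) {
  return phi::KernelKey(place);
}

// Backend left unspecified: any backend matches, so no device transform.
phi::KernelKey PassThroughVarKey(const phi::DenseTensor& tensor,
                                 const phi::KernelKey& expected) {
  return phi::KernelKey(
      phi::Backend::ALL_BACKEND, tensor.layout(), expected.dtype());
}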
......
......@@ -47,15 +47,17 @@ class TestOpWithKernel : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
return phi::KernelKey(phi::Backend::GPU,
phi::DataLayout::ALL_LAYOUT,
phi::DataType::FLOAT32);
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("input")->place());
return phi::KernelKey(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("input")->place());
}
}
};
......
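This test exercises both KernelKey construction styles the patch supports: the phi-native form taking a Backend, and the fluid-compatibility form taking a proto data type plus a concrete place. A hedged sketch of the two, assuming the constructors exactly as used above (function names are illustrative):

#include "paddle/fluid/framework/framework.pb.h"  // proto::VarType (generated)
#include "paddle/phi/core/kernel_factory.h"

// Phi-native form: backend + layout + dtype.
phi::KernelKey GpuFp32Key() {
  return phi::KernelKey(phi::Backend::GPU,
                        phi::DataLayout::ALL_LAYOUT,
                        phi::DataType::FLOAT32);
}

// Fluid-compatibility form: proto data type + place of the input tensor.
phi::KernelKey PlaceFp32Key(const phi::Place& place) {
  return phi::KernelKey(paddle::framework::proto::VarType::FP32, place);
}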
......@@ -50,13 +50,14 @@ void CastDataLayout::apply() {
}
}
void TransDataLayout(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out) {
phi::DenseTensor* out,
const phi::Place& place) {
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_type.place_),
backends_are_same_class(kernel_type_for_var.backend(),
expected_kernel_type.backend()),
platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));
......@@ -72,21 +73,20 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
auto src_dim = in.dims();
std::vector<int64_t> dst_dim;
auto axis = GetAxis(kernel_type_for_var.data_layout_,
expected_kernel_type.data_layout_);
auto axis =
GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(expected_kernel_type.place_, in.dtype());
out->mutable_data(place, in.dtype());
framework::VisitDataType(
framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(place), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_);
out->set_layout(expected_kernel_type.layout());
}
} // namespace framework
......
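The permutation above is easiest to see on the concrete dims used by the layout-transform test later in this diff. A self-contained illustration (plain C++, not Paddle code) of dst_dim[i] = src_dim[axis[i]], assuming GetAxis(kNHWC, kNCHW) returns {0, 3, 1, 2}:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> src_dim = {2, 3, 1, 2};  // NHWC: N, H, W, C
  const std::vector<int> axis = {0, 3, 1, 2};         // NHWC -> NCHW
  std::vector<int64_t> dst_dim(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    dst_dim[i] = src_dim[axis[i]];
  }
  // Matches the expectation in the data_layout_transform test below.
  assert((dst_dim == std::vector<int64_t>{2, 2, 3, 1}));
  return 0;
}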
......@@ -54,10 +54,11 @@ struct CastDataLayout {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out);
phi::DenseTensor* out,
const phi::Place& place);
} // namespace framework
} // namespace paddle
......@@ -24,22 +24,16 @@ TEST(DataTransform, DataLayoutFunction) {
in.set_layout(phi::DataLayout::kNHWC);
auto kernel_nhwc =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
place,
phi::DataLayout::kNHWC,
paddle::framework::LibraryType::kPlain);
phi::KernelKey(place, phi::DataLayout::kNHWC, phi::DataType::FLOAT32);
auto kernel_ncwh =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
place,
phi::DataLayout::kNCHW,
paddle::framework::LibraryType::kPlain);
phi::KernelKey(place, phi::DataLayout::kNCHW, phi::DataType::FLOAT32);
paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
paddle::framework::TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out, place);
EXPECT_TRUE(out.layout() == phi::DataLayout::kNCHW);
EXPECT_TRUE(out.dims() == phi::make_ddim({2, 2, 3, 1}));
TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
paddle::framework::TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out, place);
EXPECT_TRUE(in.layout() == phi::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == phi::make_ddim({2, 3, 1, 2}));
......
......@@ -36,16 +36,17 @@ static void PassTensorData(phi::DenseTensor *from, phi::DenseTensor *to) {
*from = phi::DenseTensor();
}
void TransformData(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var,
void TransformData(const phi::KernelKey &expected_kernel_type,
const phi::KernelKey &kernel_type_for_var,
const phi::DenseTensor &input_tensor,
phi::DenseTensor *output_tensor) {
phi::DenseTensor *output_tensor,
const phi::Place &place) {
bool transformed = false;
phi::DenseTensor in;
in.ShareDataWith(input_tensor);
phi::DenseTensor out;
const DataLayout lin = kernel_type_for_var.data_layout_;
const DataLayout lout = expected_kernel_type.data_layout_;
const DataLayout lin = kernel_type_for_var.layout();
const DataLayout lout = expected_kernel_type.layout();
// do layout transform
if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN
......@@ -79,43 +80,42 @@ void TransformData(const OpKernelType &expected_kernel_type,
} else {
// Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel
// Do transform via ONEDNN lib
PADDLE_ENFORCE(
kernel_type_for_var.data_layout_ == DataLayout::ONEDNN &&
expected_kernel_type.data_layout_ != DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"TransDataLayoutFromOneDNN only supports "
"transform from ONEDNN to non-ONEDNN"));
PADDLE_ENFORCE(lin == DataLayout::ONEDNN && lout != DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"TransDataLayoutFromOneDNN only supports "
"transform from ONEDNN to non-ONEDNN"));
phi::funcs::TransDataLayoutFromOneDNN(
kernel_type_for_var.data_layout_,
lin,
phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
in,
&out,
expected_kernel_type.place_);
place);
}
} else {
// Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
TransDataLayout(
kernel_type_for_var, expected_kernel_type, in, &out, place);
}
#else
// Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out, place);
#endif
transformed = true;
PassTensorData(&out, &in);
}
// do data type transform
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
if (NeedTransformDataType(expected_kernel_type, kernel_type_for_var)) {
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
transformed = true;
PassTensorData(&out, &in);
}
// do device transform
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
TransDataDevice(in, expected_kernel_type.place_, &out);
if (kernel_type_for_var.backend() != phi::Backend::ALL_BACKEND &&
!platform::is_same_place(in.place(), place)) {
TransDataDevice(in, place, &out);
transformed = true;
PassTensorData(&out, &in);
}
......
......@@ -33,10 +33,11 @@ namespace framework {
class OpKernelType;
class Variable;
void TransformData(const OpKernelType &expected_kernel_type,
const OpKernelType &kernel_type_for_var,
void TransformData(const phi::KernelKey &expected_kernel_type,
const phi::KernelKey &kernel_type_for_var,
const phi::DenseTensor &input_tensor,
phi::DenseTensor *out);
phi::DenseTensor *out,
const phi::Place &place);
/**
* Set OutVar from InVar, except the tensor is shared with `tensor`
......
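TransformData now receives the destination phi::Place explicitly, because a phi::KernelKey carries a Backend (possibly ALL_BACKEND) rather than a concrete place, and it applies transforms in the fixed order layout -> dtype -> device. A hedged call sketch against the declaration above, assuming an initialized NHWC float32 CPU input:

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"

inline void TransformDataExample(const phi::DenseTensor& input) {
  phi::DenseTensor out;
  phi::KernelKey expected(phi::Backend::CPU,
                          phi::DataLayout::kNCHW,
                          phi::DataType::FLOAT32);
  phi::KernelKey for_var(phi::Backend::CPU,
                         phi::DataLayout::kNHWC,
                         phi::DataType::FLOAT32);
  // Layouts differ, dtypes and backends match: only stage 1 (layout) runs.
  paddle::framework::TransformData(expected, for_var, input, &out,
                                   phi::CPUPlace());
}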
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
......@@ -226,6 +227,11 @@ extern inline bool IsComplexType(const proto::VarType::Type& type) {
type == proto::VarType::COMPLEX128);
}
extern inline bool IsComplexType(const phi::DataType& type) {
return (type == phi::DataType::COMPLEX64 ||
type == phi::DataType::COMPLEX128);
}
extern proto::VarType::Type PromoteTypesIfComplexExists(
const proto::VarType::Type type_a, const proto::VarType::Type type_b);
......
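A short usage sketch of the new overload (same semantics as the proto-based one above; the include path is assumed to be the header this hunk belongs to):

#include "paddle/fluid/framework/data_type.h"

inline void IsComplexTypeExample() {
  using paddle::framework::IsComplexType;
  bool c = IsComplexType(phi::DataType::COMPLEX64);  // true
  bool f = IsComplexType(phi::DataType::FLOAT32);    // false
  (void)c;
  (void)f;
}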
......@@ -129,19 +129,18 @@ struct CastDataType {
}
};
void TransDataType(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
void TransDataType(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out) {
PADDLE_ENFORCE_EQ(
framework::TransToProtoVarType(in.dtype()),
kernel_type_for_var.data_type_,
platform::errors::InvalidArgument(
"The src dtype(%s) of input tensor and kernel_type(%s) "
"are not conststent.",
DataTypeToString(framework::TransToProtoVarType(in.dtype())),
DataTypeToString(kernel_type_for_var.data_type_)));
auto dst_type = expected_kernel_type.data_type_;
PADDLE_ENFORCE_EQ(in.dtype(),
kernel_type_for_var.dtype(),
platform::errors::InvalidArgument(
"The src dtype(%s) of input tensor and kernel_type(%s) "
"are not conststent.",
DataTypeToString(in.dtype()),
DataTypeToString(kernel_type_for_var.dtype())));
auto dst_type = framework::TransToProtoVarType(expected_kernel_type.dtype());
TransDataType(in, dst_type, out);
}
......
......@@ -28,8 +28,8 @@ class OpKernelType;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
void TransDataType(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
void TransDataType(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out);
void TransDataType(const phi::DenseTensor& in,
......
......@@ -19,47 +19,26 @@ limitations under the License. */
TEST(DataTypeTransform, CPUTransform) {
auto place = paddle::platform::CPUPlace();
auto kernel_fp16 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bf16 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BF16,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp16 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
auto kernel_bf16 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BFLOAT16);
auto kernel_fp32 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
auto kernel_fp64 = phi::KernelKey(
place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
auto kernel_int32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
auto kernel_int64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
auto kernel_bool =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL,
place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
// data type transform from float32
{
......
......@@ -24,41 +24,24 @@ TEST(DataTypeTransform, GPUTransform) {
.GetAllocator(gpu_place, context.stream())
.get());
context.PartialInitWithAllocator();
auto kernel_fp16 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP16,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP32,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::FP64,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_int32 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT32,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_int64 =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::INT64,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bool =
paddle::framework::OpKernelType(paddle::framework::proto::VarType::BOOL,
gpu_place,
phi::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp16 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT16);
auto kernel_fp32 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
auto kernel_fp64 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT64);
auto kernel_int32 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT32);
auto kernel_int64 = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::INT64);
auto kernel_bool = phi::KernelKey(
gpu_place, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
// data type transform from float32
{
......
......@@ -50,10 +50,10 @@ class SumOpWithKernel : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext &ctx) const override {
return OpKernelType(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("X")->place());
return phi::KernelKey(proto::VarType::FP32,
ctx.Input<phi::DenseTensor>("X")->place());
}
};
......
......@@ -84,9 +84,9 @@ class InferShapeUtilsTestOp : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
return phi::KernelKey(proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -27,22 +27,26 @@ namespace paddle {
namespace framework {
namespace interpreter {
bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key,
const std::string& var_name,
std::string* new_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope,
bool is_fetch_v2,
bool skip_run) {
bool DataTranferHelper::apply(
const phi::KernelKey& kernel_type_for_var,
const framework::OpKernelType& expected_kernel_key,
const phi::DenseTensor* tensor,
const std::string& var_name,
std::string* new_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope,
bool is_fetch_v2,
bool skip_run) {
bool is_transferred = false;
auto* src_var_name = &var_name;
// 1. layout transform
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
if (need_layout_transform(
kernel_type_for_var,
TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
auto op = TransferLayout(*src_var_name,
new_var_name,
kernel_type_for_var.data_layout_,
kernel_type_for_var.layout(),
expected_kernel_key.data_layout_,
var_scope_,
scope_,
......@@ -56,13 +60,16 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
is_transferred = true;
}
// 2. dtype transform
if (need_dtype_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferDtype(*src_var_name,
new_var_name,
kernel_type_for_var.data_type_,
expected_kernel_key.data_type_,
var_scope_,
scope_);
if (need_dtype_transform(
kernel_type_for_var,
TransOpKernelTypeToPhiKernelKey(expected_kernel_key))) {
auto op = TransferDtype(
*src_var_name,
new_var_name,
framework::TransToProtoVarType(kernel_type_for_var.dtype()),
expected_kernel_key.data_type_,
var_scope_,
scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
......@@ -72,8 +79,9 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
is_transferred = true;
}
// 3. device transform
if (need_device_transform(kernel_type_for_var, expected_kernel_key)) {
auto src_place = kernel_type_for_var.place_;
if (need_device_transform(
kernel_type_for_var, tensor, expected_kernel_key.place_)) {
auto src_place = tensor->place();
auto dst_place = expected_kernel_key.place_;
auto op = TransferDevice(
......@@ -526,11 +534,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
auto kernel_type_for_var =
static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar(
var_name_item.first, *tensor_in, expected_kernel_key);
var_name_item.first,
*tensor_in,
framework::TransOpKernelTypeToPhiKernelKey(
expected_kernel_key));
// apply data transform
is_transferred =
data_transfer_helper.apply(kernel_type_for_var,
expected_kernel_key,
tensor_in,
var_name,
&new_var_name,
new_op_func_nodes,
......
......@@ -34,8 +34,9 @@ class DataTranferHelper {
Scope* local_scope)
: place_(place), var_scope_(var_scope), scope_(local_scope) {}
bool apply(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key,
bool apply(const phi::KernelKey& kernel_type_for_var,
const framework::OpKernelType& expected_kernel_key,
const phi::DenseTensor* tensor,
const std::string& var_name,
std::string* new_var_name,
std::vector<OpFuncNode>* new_op_func_nodes,
......@@ -79,28 +80,28 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
framework::Scope* local_scope,
bool skip_run = false);
inline bool need_device_transform(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
auto& src_place = kernel_type_for_var.place_;
auto& dst_place = expected_kernel_key.place_;
if (platform::is_same_place(src_place, dst_place) ||
(platform::is_cuda_pinned_place(src_place) &&
platform::is_cpu_place(dst_place))) {
inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
const phi::DenseTensor* tensor,
const phi::Place& expected_place) {
if (kernel_type_for_var.backend() == phi::Backend::ALL_BACKEND ||
platform::is_same_place(tensor->place(), expected_place) ||
(platform::is_cuda_pinned_place(tensor->place()) &&
platform::is_cpu_place(expected_place))) {
return false;
}
return true;
}
inline bool need_dtype_transform(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
inline bool need_dtype_transform(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_key) {
return framework::NeedTransformDataType(kernel_type_for_var,
expected_kernel_key);
}
inline bool need_layout_transform(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
return framework::NeedTransformLayout(kernel_type_for_var.data_layout_,
expected_kernel_key.data_layout_);
inline bool need_layout_transform(const phi::KernelKey& kernel_type_for_var,
const phi::KernelKey& expected_kernel_key) {
return framework::NeedTransformLayout(kernel_type_for_var.layout(),
expected_kernel_key.layout());
}
std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
......
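These three predicates are exactly what DataTranferHelper::apply checks, in the order layout, then dtype, then device; note the device check now consults the tensor's actual place and a bare phi::Place instead of a second kernel key. A hedged sketch composing them, assuming the interpreter namespace as in the .cc file above:

inline bool AnyTransferNeeded(const phi::KernelKey& kernel_type_for_var,
                              const phi::KernelKey& expected_kernel_key,
                              const phi::DenseTensor* tensor,
                              const phi::Place& expected_place) {
  using namespace paddle::framework::interpreter;  // NOLINT
  return need_layout_transform(kernel_type_for_var, expected_kernel_key) ||
         need_dtype_transform(kernel_type_for_var, expected_kernel_key) ||
         need_device_transform(kernel_type_for_var, tensor, expected_place);
}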
......@@ -730,8 +730,8 @@ bool BuildOpFuncList(const platform::Place& place,
auto* dev_ctx = pool.Get(place);
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
auto expected_kernel_key = framework::TransPhiKernelKeyToOpKernelType(
op_with_kernel->GetExpectedKernelType(exec_ctx));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (op_with_kernel->CanCUDNNBeUsed(exec_ctx,
expected_kernel_key.data_type_)) {
......@@ -741,6 +741,10 @@ bool BuildOpFuncList(const platform::Place& place,
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// change device by the device_guard()
ApplyDeviceGuard(op, place, &expected_kernel_key);
if (platform::places_are_same_class(exec_ctx.GetPlace(),
expected_kernel_key.place_)) {
expected_kernel_key.place_ = exec_ctx.GetPlace();
}
// step 2. select op kernel
auto run_phi_kernel = false;
......
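During the migration the patch bridges the two type systems with TransOpKernelTypeToPhiKernelKey and TransPhiKernelKeyToOpKernelType, both used above. A hedged round-trip sketch; the header location is an assumption, not confirmed by this diff:

#include "paddle/fluid/framework/phi_utils.h"  // assumed location of the helpers

inline phi::KernelKey ToPhi(const paddle::framework::OpKernelType& type) {
  return paddle::framework::TransOpKernelTypeToPhiKernelKey(type);
}

inline paddle::framework::OpKernelType ToFluid(const phi::KernelKey& key) {
  return paddle::framework::TransPhiKernelKeyToOpKernelType(key);
}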
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace framework {
......@@ -108,15 +109,32 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
return ret;
}
inline bool NeedTransformDataType(const OpKernelType& l,
const OpKernelType& r) {
return (l.data_type_ != r.data_type_);
inline bool NeedTransformDataType(const phi::KernelKey& l,
const phi::KernelKey& r) {
return l.dtype() != phi::DataType::ALL_DTYPE &&
r.dtype() != phi::DataType::ALL_DTYPE && l.dtype() != r.dtype();
}
inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) {
return (!platform::places_are_same_class(l.place_, r.place_)) ||
(l.data_type_ != r.data_type_) ||
NeedTransformLayout(l.data_layout_, r.data_layout_);
inline bool backends_are_same_class(const phi::Backend& l,
const phi::Backend& r) {
if (l == phi::Backend::ALL_BACKEND || r == phi::Backend::ALL_BACKEND) {
return true;
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t num_backends = static_cast<size_t>(phi::Backend::NUM_BACKENDS);
if (static_cast<size_t>(l) > num_backends &&
static_cast<size_t>(r) > num_backends) {
return phi::TransToPhiPlace(l).GetDeviceType() ==
phi::TransToPhiPlace(r).GetDeviceType();
}
#endif
return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r);
}
inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) {
return !backends_are_same_class(l.backend(), r.backend()) ||
NeedTransformDataType(l, r) ||
NeedTransformLayout(l.layout(), r.layout());
}
} // namespace framework
......
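The wildcard semantics are the crux of these helpers: ALL_BACKEND matches any backend and ALL_DTYPE matches any dtype, so a fully wildcarded key never forces a transform. A hedged usage sketch against the functions declared above:

#include "paddle/phi/core/kernel_factory.h"

inline void WildcardExample() {
  using paddle::framework::backends_are_same_class;
  using paddle::framework::NeedTransform;
  using paddle::framework::NeedTransformDataType;

  phi::KernelKey any(phi::Backend::ALL_BACKEND,
                     phi::DataLayout::ALL_LAYOUT,
                     phi::DataType::ALL_DTYPE);
  phi::KernelKey gpu_fp32(phi::Backend::GPU,
                          phi::DataLayout::ALL_LAYOUT,
                          phi::DataType::FLOAT32);

  bool same_class =
      backends_are_same_class(any.backend(), gpu_fp32.backend());  // true
  bool dtype_diff = NeedTransformDataType(any, gpu_fp32);          // false
  bool transform = NeedTransform(any, gpu_fp32);                   // false
  (void)same_class;
  (void)dtype_diff;
  (void)transform;
}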
......@@ -214,9 +214,10 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
return phi::KernelKey(proto::VarType::FP32,
ctx.device_context().GetPlace());
}
};
......@@ -275,12 +276,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::VarType::FP32,
platform::CUDAPlace(0),
DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
return phi::KernelKey(phi::Backend::GPUDNN,
phi::DataLayout::ALL_LAYOUT,
phi::DataType::FLOAT32);
}
};
......
......@@ -1380,8 +1380,7 @@ bool OperatorWithKernel::SupportXPU() const {
#endif
}
bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
bool OperatorWithKernel::SupportsMKLDNN(const phi::DataType data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
auto has_phi_kernel =
......@@ -1389,8 +1388,7 @@ bool OperatorWithKernel::SupportsMKLDNN(
phi_kernels.end(),
[data_type](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::ONEDNN &&
kern_pair.first.dtype() ==
framework::TransToPhiDataType(data_type);
kern_pair.first.dtype() == data_type;
});
if (has_phi_kernel) {
return true;
......@@ -1406,25 +1404,22 @@ bool OperatorWithKernel::SupportsMKLDNN(
[data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ == LibraryType::kMKLDNN &&
kern_pair.first.data_type_ == data_type;
kern_pair.first.data_type_ == TransToProtoVarType(data_type);
});
}
}
}
bool OperatorWithKernel::SupportsCUDNN(
const proto::VarType::Type data_type) const {
bool OperatorWithKernel::SupportsCUDNN(const phi::DataType data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
paddle::experimental::DataType phi_data_type =
framework::TransToPhiDataType(data_type);
auto has_phi_kernel = std::any_of(
phi_kernels.begin(),
phi_kernels.end(),
[phi_data_type](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::GPUDNN &&
kern_pair.first.dtype() == phi_data_type;
});
auto has_phi_kernel =
std::any_of(phi_kernels.begin(),
phi_kernels.end(),
[data_type](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::GPUDNN &&
kern_pair.first.dtype() == data_type;
});
if (has_phi_kernel) {
return true;
} else {
......@@ -1433,13 +1428,15 @@ bool OperatorWithKernel::SupportsCUDNN(
return false;
} else {
auto& op_kernels = op_kernel_iter->second;
proto::VarType::Type fluid_data_type =
framework::TransToProtoVarType(data_type);
return std::any_of(
op_kernels.begin(),
op_kernels.end(),
[data_type](OpKernelMap::const_reference kern_pair) {
[fluid_data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ == LibraryType::kCUDNN &&
kern_pair.first.data_type_ == data_type;
kern_pair.first.data_type_ == fluid_data_type;
});
}
}
......@@ -1509,14 +1506,19 @@ bool OperatorWithKernel::SupportsKernelType(
}
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
phi::DataType data_type) const {
return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
platform::is_cpu_place(ctx.GetPlace()) &&
this->SupportsMKLDNN(data_type);
}
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
return this->CanMKLDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
}
bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
phi::DataType data_type) const {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
paddle::platform::is_gpu_place(ctx.GetPlace());
......@@ -1528,7 +1530,7 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
#if defined(PADDLE_WITH_CUDA)
if (use_cudnn && data_type == framework::proto::VarType::BF16) {
if (use_cudnn && data_type == phi::DataType::BFLOAT16) {
PADDLE_ENFORCE_GE(
platform::DnnVersion(),
8100,
......@@ -1540,6 +1542,11 @@ bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
return use_cudnn && this->SupportsCUDNN(data_type);
}
bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
return this->CanCUDNNBeUsed(ctx, phi::TransToPhiDataType(data_type));
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to "
......@@ -1839,8 +1846,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
1,
platform::EventRole::kInnerOp);
if (need_prepare_data_) {
transfer_scope = PrepareData(
scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
transfer_scope =
PrepareData(scope,
framework::TransOpKernelTypeToPhiKernelKey(*kernel_type_),
&transfered_inplace_vars,
runtime_ctx,
dev_ctx->GetPlace());
}
}
// exec scope is the scope that kernel actually executed on.
......@@ -1960,7 +1971,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const {
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
phi::KernelKey phi_kernel_key = this->GetExpectedKernelType(ctx);
auto expected_kernel_key =
framework::TransPhiKernelKeyToOpKernelType(phi_kernel_key);
// NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function
// GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and
......@@ -2063,6 +2076,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
}
}
}
if (platform::places_are_same_class(expected_kernel_key.place_,
ctx.GetPlace())) {
expected_kernel_key.place_ = ctx.GetPlace();
}
VLOG(3) << "op type:" << type_
<< ", expected_kernel_key:" << expected_kernel_key;
return expected_kernel_key;
......@@ -2333,9 +2352,10 @@ void OperatorWithKernel::HandleComplexGradToRealGrad(
Scope* OperatorWithKernel::PrepareData(
const Scope& scope,
const OpKernelType& expected_kernel_key,
const phi::KernelKey& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const {
RuntimeContext* ctx,
const phi::Place& place) const {
Scope* new_scope = nullptr;
const std::unordered_set<std::string>* no_buffer_ins = nullptr;
......@@ -2378,7 +2398,7 @@ Scope* OperatorWithKernel::PrepareData(
// has to be created and registered
if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
(expected_kernel_key.layout() != DataLayout::ONEDNN) &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
DataLayout::kNHWC) &&
(tensor_in->dims().size() >= 3)) {
......@@ -2411,35 +2431,33 @@ Scope* OperatorWithKernel::PrepareData(
auto kernel_type_for_var =
GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key);
bool need_trans_dtype =
kernel_type_for_var.data_type_ != expected_kernel_key.data_type_;
NeedTransformDataType(expected_kernel_key, kernel_type_for_var);
bool need_trans_layout = NeedTransformLayout(
kernel_type_for_var.data_layout_, expected_kernel_key.data_layout_);
kernel_type_for_var.layout(), expected_kernel_key.layout());
if (!need_trans_dtype && !need_trans_layout) {
if (!run_phi_kernel_ &&
platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_key.place_)) {
backends_are_same_class(kernel_type_for_var.backend(),
expected_kernel_key.backend())) {
continue;
}
}
std::unique_ptr<OpKernelType> new_expected_kernel_key = nullptr;
std::unique_ptr<phi::KernelKey> new_expected_kernel_key = nullptr;
if (run_phi_kernel_ && in_def != nullptr &&
in_def->backend != phi::Backend::ALL_BACKEND) {
auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if ((in_def->backend != tensor_backend &&
(in_def->backend != phi::Backend::GPUDNN ||
tensor_backend != phi::Backend::GPU) &&
(in_def->backend != phi::Backend::KPS ||
tensor_backend != phi::Backend::XPU) &&
(in_def->backend != phi::Backend::ONEDNN ||
tensor_backend != phi::Backend::CPU)) ||
!(in_def->backend == phi::Backend::GPUDNN &&
tensor_backend == phi::Backend::GPU) &&
!(in_def->backend == phi::Backend::KPS &&
tensor_backend == phi::Backend::XPU) &&
!(in_def->backend == phi::Backend::ONEDNN &&
tensor_backend == phi::Backend::CPU)) ||
tensor_in->place().GetType() == AllocationType::GPUPINNED) {
new_expected_kernel_key = std::make_unique<OpKernelType>(
expected_kernel_key.data_type_,
phi::TransToPhiPlace(in_def->backend),
expected_kernel_key.data_layout_,
expected_kernel_key.library_type_,
expected_kernel_key.customized_type_value_);
new_expected_kernel_key =
std::make_unique<phi::KernelKey>(in_def->backend,
expected_kernel_key.layout(),
expected_kernel_key.dtype());
}
}
......@@ -2474,14 +2492,18 @@ Scope* OperatorWithKernel::PrepareData(
enable_cache_transfer_scope_ = false;
if (!run_by_executor_) {
if (new_expected_kernel_key) {
if ((platform::is_gpu_place(kernel_type_for_var.place_) ||
platform::is_gpu_place(new_expected_kernel_key->place_))) {
if (kernel_type_for_var.backend() == phi::Backend::GPU ||
kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
new_expected_kernel_key->backend() == phi::Backend::GPU ||
new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
new_scope = TryCreateTransferScope(
kernel_type_for_var, *new_expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true;
}
} else if ((platform::is_gpu_place(kernel_type_for_var.place_) ||
platform::is_gpu_place(expected_kernel_key.place_))) {
} else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
expected_kernel_key.backend() == phi::Backend::GPU ||
expected_kernel_key.backend() == phi::Backend::GPUDNN) {
new_scope = TryCreateTransferScope(
kernel_type_for_var, expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true;
......@@ -2523,11 +2545,15 @@ Scope* OperatorWithKernel::PrepareData(
// Do transfer
phi::DenseTensor out;
TransformData(new_expected_kernel_key ? *new_expected_kernel_key
: expected_kernel_key,
kernel_type_for_var,
*tensor_in,
&out);
TransformData(
new_expected_kernel_key ? *new_expected_kernel_key
: expected_kernel_key,
kernel_type_for_var,
*tensor_in,
&out,
new_expected_kernel_key
? phi::TransToPhiPlace(new_expected_kernel_key->backend())
: place);
SetTensorToVariable(*var, out, trans_var);
}
};
......@@ -2818,30 +2844,29 @@ proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes(
return target_type;
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
phi::KernelKey OperatorWithKernel::GetExpectedKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
return phi::KernelKey(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
phi::KernelKey OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const {
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// When the op is the first oneDNN op (there was some non-oneDNN op
// previously), we also need to rotate the shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
phi::DataLayout::kNHWC);
return phi::KernelKey(
tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
}
#endif
return OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
......
......@@ -638,16 +638,22 @@ class OperatorWithKernel : public OperatorBase {
bool SupportXPU() const override;
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsMKLDNN(phi::DataType data_type) const;
bool SupportsCUDNN(proto::VarType::Type data_type) const;
bool SupportsCUDNN(phi::DataType data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type,
const ExecutionContext& exe_ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
phi::DataType data_type) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
phi::DataType data_type) const;
bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
......@@ -665,14 +671,15 @@ class OperatorWithKernel : public OperatorBase {
const std::string& name1,
const std::string& name2) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
virtual phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const;
// change this to public so that in dygraph mode we can call it to check if we
// need to transform data
virtual OpKernelType GetKernelTypeForVar(
virtual phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const;
const phi::KernelKey& expected_kernel_type) const;
platform::Place GetExecutionPlace(
const platform::Place& platform) const override {
......@@ -734,9 +741,10 @@ class OperatorWithKernel : public OperatorBase {
* transfered_inplace_vars is an output vector.
*/
Scope* PrepareData(const Scope& scope,
const OpKernelType& expected_kernel_key,
const phi::KernelKey& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const;
RuntimeContext* ctx,
const phi::Place& place) const;
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
......
......@@ -127,14 +127,10 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
int sub_type = ctx.Attr<int>("kernel_sub_type");
return OpKernelType(proto::VarType::FP32,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kPlain,
sub_type);
return phi::KernelKey(
ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
}
};
......@@ -256,16 +252,6 @@ TEST(OpKernel, all) {
// kernel_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
attr = op_desc.mutable_attrs()->Add();
attr->set_name("kernel_sub_type");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
op2->Run(scope, cpu_place);
// kernel_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
}
REGISTER_OP_WITHOUT_GRADIENT(
......@@ -339,11 +325,11 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "phi::DenseTensor");
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
......@@ -361,11 +347,11 @@ class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
class IndicateSelectedRowsDataTypeTestProtoMaker
......@@ -383,10 +369,10 @@ class IndicateOtherDataTypeTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
......@@ -597,10 +583,10 @@ class OpUnusedVarTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(
proto::VarType::FP32, ctx.GetPlace(), phi::DataLayout::kAnyLayout);
return phi::KernelKey(
ctx.GetPlace(), phi::DataLayout::ALL_LAYOUT, phi::DataType::FLOAT32);
}
};
......
......@@ -34,12 +34,13 @@ global_transfer_scope_key() {
return *x;
}
Scope* TryCreateTransferScope(OpKernelType type0,
OpKernelType type1,
Scope* TryCreateTransferScope(const phi::KernelKey& type0,
const phi::KernelKey& type1,
const Scope* scope) {
Scope* new_scope{nullptr};
size_t infer_cache_key =
CombineHash(OpKernelType::Hash()(type0), OpKernelType::Hash()(type1));
CombineHash(static_cast<size_t>(phi::KernelKey::Hash()(type0)),
static_cast<size_t>(phi::KernelKey::Hash()(type1)));
infer_cache_key =
CombineHash(infer_cache_key, std::hash<const Scope*>()(scope));
......
......@@ -39,8 +39,8 @@ static size_t CombineHash(size_t seed, size_t a) {
return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
Scope* TryCreateTransferScope(OpKernelType type0,
OpKernelType type1,
Scope* TryCreateTransferScope(const phi::KernelKey& type0,
const phi::KernelKey& type1,
const Scope* scope);
} // namespace framework
......
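A self-contained illustration of the cache-key combination used by TryCreateTransferScope (the boost-style hash_combine shown above); the two input hashes are stand-in constants, not real phi::KernelKey::Hash values:

#include <cstddef>
#include <cstdio>

static size_t CombineHash(size_t seed, size_t a) {
  return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main() {
  size_t h0 = 0x1234;  // stand-in for phi::KernelKey::Hash()(type0)
  size_t h1 = 0x5678;  // stand-in for phi::KernelKey::Hash()(type1)
  size_t key = CombineHash(h0, h1);
  // Order matters: swapping type0 and type1 yields a different scope key.
  std::printf("key=%zu swapped=%zu\n", key, CombineHash(h1, h0));
  return 0;
}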
......@@ -23,7 +23,8 @@ if(WITH_XPU)
scalar
int_array
var_helper
profiler)
profiler
place)
else()
cc_library(
prepared_operator
......@@ -40,7 +41,8 @@ else()
scalar
int_array
var_helper
profiler)
profiler
place)
endif()
cc_library(
layer
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace imperative {
......@@ -39,7 +40,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr,
const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr,
const phi::KernelKey* op_kernel_key = nullptr,
const phi::ArgumentMappingFn* arg_map_fn = nullptr,
const phi::KernelSignature* default_kernel_signature = nullptr)
: var_map_in_(in),
......@@ -47,7 +48,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type),
op_kernel_type_(op_kernel_type),
op_kernel_key_(op_kernel_key),
arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature) {}
......@@ -250,8 +251,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
bool IsRuntime() const override { return true; }
bool IsRunMKLDNNKernel() const override {
return (op_kernel_type_ &&
(op_kernel_type_->data_layout_ == phi::DataLayout::ONEDNN));
return (op_kernel_key_ &&
(op_kernel_key_->layout() == phi::DataLayout::ONEDNN));
}
paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
......@@ -497,7 +498,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
const framework::OpKernelType* op_kernel_type_;
const phi::KernelKey* op_kernel_key_;
// arg_map_fn_ and default_kernel_signature_ may be nullptr
const phi::ArgumentMappingFn* arg_map_fn_;
const phi::KernelSignature* default_kernel_signature_;
......
......@@ -519,8 +519,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
*/
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
auto tmp_ins_ptr = PrepareData<VarType>(
*op_kernel, ins, prepared_op.kernel_key(), prepared_op.place());
if (tmp_ins_ptr == nullptr) {
prepared_op.Run(ins, outs, attrs, default_attrs);
} else {
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"
......@@ -75,7 +76,8 @@ template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
const framework::OperatorWithKernel& op,
const NameVarMap<VarType>& ins,
const framework::OpKernelType& expected_kernel_key) {
const phi::KernelKey& expected_kernel_key,
const phi::Place& place) {
std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr;
for (const auto& name_pair : ins) {
for (size_t i = 0; i < name_pair.second.size(); ++i) {
......@@ -85,7 +87,8 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) {
auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
if (!framework::NeedTransform(kernel_type_for_var,
expected_kernel_key)) {
continue;
} else {
VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
......@@ -111,10 +114,10 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
} else {
phi::DenseTensor out;
TransformData(
expected_kernel_key, kernel_type_for_var, *tensor, &out);
if (NeedTransformDataType(kernel_type_for_var,
expected_kernel_key)) {
framework::TransformData(
expected_kernel_key, kernel_type_for_var, *tensor, &out, place);
if (framework::NeedTransformDataType(kernel_type_for_var,
expected_kernel_key)) {
// To avoid NameVarMap copy construction overhead in general
// scenarios, if inplace transformed, return original input
// directly
......@@ -149,7 +152,7 @@ class PreparedOp {
public:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const phi::KernelKey& kernel_key,
const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
......@@ -157,7 +160,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const phi::KernelKey& kernel_key,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature,
......@@ -200,12 +203,14 @@ class PreparedOp {
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
const framework::OpKernelType& kernel_type() const { return kernel_type_; }
const phi::KernelKey& kernel_key() const { return kernel_key_; }
const phi::Place& place() const { return dev_ctx_->GetPlace(); }
private:
const framework::OperatorBase& op_;
const framework::RuntimeContext& ctx_;
framework::OpKernelType kernel_type_;
phi::KernelKey kernel_key_;
framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_;
// NOTE(chenweihang): Similar op members are used to adapt to
......
......@@ -92,15 +92,15 @@ TEST(test_var_helper, eager_var_helper) {
ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerVariable>(egr_tensor)));
ASSERT_TRUE(GetDataType<egr::EagerVariable>(egr_tensor) ==
framework::proto::VarType::FP32);
GetCachedValue<egr::EagerVariable>(
egr_tensor,
framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()));
SetCachedValue<egr::EagerVariable>(
egr_tensor,
framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()),
egr_tensor2);
GetCachedValue<egr::EagerVariable>(egr_tensor,
phi::KernelKey(phi::Backend::CPU,
phi::DataLayout::ALL_LAYOUT,
phi::DataType::FLOAT32));
SetCachedValue<egr::EagerVariable>(egr_tensor,
phi::KernelKey(phi::Backend::CPU,
phi::DataLayout::ALL_LAYOUT,
phi::DataType::FLOAT32),
egr_tensor2);
ASSERT_ANY_THROW(GetPlace<egr::EagerVariable>(egr_tensor2));
ASSERT_ANY_THROW(SetType<egr::EagerVariable>(
egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY));
......
......@@ -172,7 +172,8 @@ TEST(test_prepare_op, test_prepare_data) {
PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op),
ins,
prepared_op.kernel_type());
prepared_op.kernel_key(),
gpu_place);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......@@ -229,7 +230,8 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
PrepareData<imperative::VarBase>(
dynamic_cast<framework::OperatorWithKernel&>(*op),
ins,
prepared_op.kernel_type());
prepared_op.kernel_key(),
cpu_place);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......
......@@ -239,35 +239,31 @@ template void SetDataLayout<VariableWrapper>(
/* CheckCachedKey */
template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> var,
const paddle::framework::OpKernelType &key) {
bool CheckCachedKey(std::shared_ptr<VarType> var, const phi::KernelKey &key) {
return GetVariableWrapper(var)->hasCacheKey(key);
}
template <>
bool CheckCachedKey<egr::EagerVariable>(
std::shared_ptr<egr::EagerVariable> tensor,
const paddle::framework::OpKernelType &key) {
std::shared_ptr<egr::EagerVariable> tensor, const phi::KernelKey &key) {
// TODO(jiabin): Support this later
// VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is
// equal to self: " << key == key.
return false;
}
template bool CheckCachedKey<VarBase>(
std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key);
template bool CheckCachedKey<VarBase>(std::shared_ptr<VarBase> var,
const phi::KernelKey &key);
template bool CheckCachedKey<VariableWrapper>(
std::shared_ptr<VariableWrapper> var,
const paddle::framework::OpKernelType &key);
std::shared_ptr<VariableWrapper> var, const phi::KernelKey &key);
/* GetCachedValue */
template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<VarType> var, const paddle::framework::OpKernelType &key) {
std::shared_ptr<VariableWrapper> GetCachedValue(std::shared_ptr<VarType> var,
const phi::KernelKey &key) {
return GetVariableWrapper(var)->getCacheValue(key);
}
template <>
std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<egr::EagerVariable> var,
const paddle::framework::OpKernelType &key) {
std::shared_ptr<egr::EagerVariable> var, const phi::KernelKey &key) {
// TODO(jiabin): Support this later
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
// reach this, support cache and remove this error check later, or this
......@@ -277,22 +273,21 @@ std::shared_ptr<VariableWrapper> GetCachedValue(
return std::make_shared<VariableWrapper>("");
}
template std::shared_ptr<VariableWrapper> GetCachedValue<VarBase>(
std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key);
std::shared_ptr<VarBase> var, const phi::KernelKey &key);
template std::shared_ptr<VariableWrapper> GetCachedValue<VariableWrapper>(
std::shared_ptr<VariableWrapper> var,
const paddle::framework::OpKernelType &key);
std::shared_ptr<VariableWrapper> var, const phi::KernelKey &key);
/* SetCachedValue */
template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> var,
const paddle::framework::OpKernelType &key,
const phi::KernelKey &key,
std::shared_ptr<VarType> res) {
GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res));
}
template <>
void SetCachedValue<egr::EagerVariable>(
std::shared_ptr<egr::EagerVariable> tensor,
const paddle::framework::OpKernelType &key,
const phi::KernelKey &key,
std::shared_ptr<egr::EagerVariable> res) {
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
// reach this, support cache and remove this error check later, or this
......@@ -300,13 +295,12 @@ void SetCachedValue<egr::EagerVariable>(
// VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key
// is equal to self: " << key == key << " and res name is:" << res->Name().
}
template void SetCachedValue<VarBase>(
std::shared_ptr<VarBase> var,
const paddle::framework::OpKernelType &key,
std::shared_ptr<VarBase> res);
template void SetCachedValue<VarBase>(std::shared_ptr<VarBase> var,
const phi::KernelKey &key,
std::shared_ptr<VarBase> res);
template void SetCachedValue<VariableWrapper>(
std::shared_ptr<VariableWrapper> var,
const paddle::framework::OpKernelType &key,
const phi::KernelKey &key,
std::shared_ptr<VariableWrapper> res);
} // namespace imperative
} // namespace paddle
......@@ -43,16 +43,14 @@ template <typename VarType>
const std::string& GetNameFromVar(std::shared_ptr<VarType> var);
template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key);
bool CheckCachedKey(std::shared_ptr<VarType> tensor, const phi::KernelKey& key);
template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key,
const phi::KernelKey& key,
std::shared_ptr<VarType> res);
template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key);
std::shared_ptr<VariableWrapper> GetCachedValue(std::shared_ptr<VarType> tensor,
const phi::KernelKey& key);
template <typename VarType>
void SetType(std::shared_ptr<VarType> var,
......
......@@ -234,16 +234,15 @@ class VariableWrapper {
}
}
bool hasCacheKey(const paddle::framework::OpKernelType& key) {
bool hasCacheKey(const phi::KernelKey& key) {
return var_cache.find(key) != var_cache.end();
}
std::shared_ptr<VariableWrapper> getCacheValue(
const paddle::framework::OpKernelType& key) {
std::shared_ptr<VariableWrapper> getCacheValue(const phi::KernelKey& key) {
return var_cache[key];
}
void setCacheValue(const paddle::framework::OpKernelType& key,
void setCacheValue(const phi::KernelKey& key,
std::shared_ptr<VariableWrapper> val) {
var_cache[key] = val;
return;
......@@ -323,8 +322,7 @@ class VariableWrapper {
// Used to cache the dtype-promoted VariableWrapper in the real and complex
// computation of Paddle Quantum
std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>>
var_cache;
std::map<phi::KernelKey, std::shared_ptr<VariableWrapper>> var_cache;
// add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
......
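A hedged usage sketch of the KernelKey-keyed cache above; `wrapper` and `promoted` are illustrative values, and since var_cache is a std::map this relies on phi::KernelKey providing the strict weak ordering std::map requires:

#include <memory>

inline std::shared_ptr<paddle::imperative::VariableWrapper> CacheOrGet(
    const std::shared_ptr<paddle::imperative::VariableWrapper>& wrapper,
    const std::shared_ptr<paddle::imperative::VariableWrapper>& promoted) {
  phi::KernelKey key(phi::Backend::CPU,
                     phi::DataLayout::ALL_LAYOUT,
                     phi::DataType::COMPLEX64);
  if (!wrapper->hasCacheKey(key)) {
    wrapper->setCacheValue(key, promoted);  // cache the dtype-promoted copy
  }
  return wrapper->getCacheValue(key);
}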
......@@ -29,11 +29,11 @@ class AbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......@@ -70,11 +70,11 @@ class AbsGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......@@ -124,20 +124,17 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
return framework::OpKernelType(dtype, ctx.GetPlace());
return phi::KernelKey(dtype, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
framework::TransToProtoVarType(tensor.dtype()),
tensor.place(),
tensor.layout());
const phi::KernelKey& expected_kernel_type) const override {
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
}
};
......
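Judging from the call sites in this change, phi::KernelKey is built in three ways, and the argument order differs from OpKernelType. A comment-only reading aid (not an exhaustive list of overloads):

// old: framework::OpKernelType(data_type, place, layout);
// new, compatibility form (proto dtype + place):
//   phi::KernelKey(input_data_type, ctx.GetPlace());
// new, place-based form (dtype moves to the last position):
//   phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
// new, explicit-backend form:
//   phi::KernelKey(phi::Backend::ALL_BACKEND, layout, dtype);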
......@@ -80,9 +80,9 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
auto data_type = oper.IndicateVarDataType(ctx, name);
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
......@@ -94,7 +94,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
// library = framework::LibraryType::kCUDNN;
// }
// #endif
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
class ActivationOp : public framework::OperatorWithKernel {
......@@ -107,7 +107,7 @@ class ActivationOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
......@@ -134,7 +134,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
......@@ -341,7 +341,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
......@@ -370,7 +370,7 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
......@@ -411,7 +411,7 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
......@@ -487,20 +487,22 @@ class PowOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "FactorTensor") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
......@@ -515,20 +517,22 @@ class PowOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "FactorTensor") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
......@@ -537,7 +541,7 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
......@@ -548,7 +552,7 @@ class PowOpTripleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
......
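The FactorTensor branches above show the new idiom for scalar-like auxiliary inputs: instead of echoing expected_kernel_type, the override returns a key with phi::Backend::ALL_BACKEND, meaning the input may stay on whatever backend it already occupies. A generic sketch of the pattern; the input name is hypothetical:

phi::KernelKey GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const override {
  if (var_name == "ScaleTensor") {  // hypothetical scalar-like input
    // ALL_BACKEND: skip the device transform for this input entirely.
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          expected_kernel_type.layout(),
                          expected_kernel_type.dtype());
  }
  // Default: key off the tensor itself, keeping the expected dtype.
  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}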
......@@ -34,11 +34,10 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -54,11 +53,11 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
platform::CPUPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
platform::CPUPlace());
}
};
......
......@@ -145,11 +145,11 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -130,10 +130,10 @@ class AffineGridOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
......@@ -241,11 +241,11 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output"));
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
......
......@@ -65,11 +65,10 @@ class AllcloseOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context().GetPlace());
}
};
......
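A recurring mechanical change starts with this hunk: OpKernelType had a constructor taking a DeviceContext and extracting its place internally, whereas phi::KernelKey takes the Place itself. In shorthand:

// old: framework::OpKernelType(dtype, ctx.device_context());
// new: phi::KernelKey(dtype, ctx.device_context().GetPlace());
//      (or simply ctx.GetPlace(), as several hunks below prefer)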
......@@ -34,10 +34,9 @@ class AllocFloatStatusOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -29,13 +29,13 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
return phi::KernelKey(dtype, ctx.GetPlace());
}
};
......
......@@ -34,10 +34,9 @@ class ClearFloatStatusOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -34,10 +34,9 @@ class GetFloatStatusOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -29,23 +29,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
return phi::KernelKey(dtype, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
const phi::KernelKey& expected_kernel_type) const override {
#ifndef PADDLE_WITH_XPU
if (var_name == "FoundInfinite" || var_name == "StopUpdate") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
#endif
return framework::OperatorWithKernel::GetKernelTypeForVar(
......
......@@ -32,11 +32,11 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......
......@@ -23,10 +23,10 @@ class AscendTriggerOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context().GetPlace());
}
};
......
......@@ -41,16 +41,16 @@ class AssignOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
const phi::KernelKey &expected_kernel_type) const override {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
tensor.layout(),
expected_kernel_type.dtype());
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const framework::Variable *var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) {
......@@ -58,14 +58,13 @@ class AssignOp : public framework::OperatorWithKernel {
// NOTE(liym27): Support an empty tensor array as Input.
// In that case the kernel data type is set to float.
if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context().GetPlace());
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
......
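On the empty-LoDTensorArray branch above, any concrete dtype would do, since the kernel never reads an element; FP32 is just the conventional placeholder. The same fallback reappears in send_v2 below:

// Placeholder key for an empty tensor array: the dtype is arbitrary
// (FP32 by convention) because no element data will be touched.
return phi::KernelKey(framework::proto::VarType::FP32,
                      ctx.device_context().GetPlace());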
......@@ -31,7 +31,7 @@ class AssignPosOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto cum_count_dtype =
OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
......@@ -46,7 +46,7 @@ class AssignPosOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"The dtype of the cum_count_dtype, eff_num_len and "
"X should be same as int64"));
return framework::OpKernelType(cum_count_dtype, ctx.device_context());
return phi::KernelKey(cum_count_dtype, ctx.device_context().GetPlace());
}
};
......
......@@ -44,9 +44,9 @@ class AssignValueOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
return phi::KernelKey(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
......
......@@ -198,10 +198,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Cell");
}
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
phi::KernelKey AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
void AttentionLSTMOpMaker::Make() {
......
......@@ -25,7 +25,7 @@ class AttentionLSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
......
......@@ -26,10 +26,10 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "param"),
ctx.GetPlace());
}
};
......
......@@ -77,11 +77,10 @@ class BatchFCOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context().GetPlace());
}
};
......@@ -106,11 +105,11 @@ class BatchFCGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context().GetPlace());
}
};
......
......@@ -171,7 +171,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
}
}
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
phi::KernelKey BatchNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
......@@ -202,18 +202,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
platform::errors::InvalidArgument(
"Variance input should be of float type"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
phi::KernelKey BatchNormOp::GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
const phi::KernelKey &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "X") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
......@@ -222,13 +222,12 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), dl);
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void BatchNormOpMaker::Make() {
......@@ -373,7 +372,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
}
}
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
phi::KernelKey BatchNormGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
......@@ -392,18 +391,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
}
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
phi::KernelKey BatchNormGradOp::GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
const phi::KernelKey &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
......@@ -412,13 +411,12 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), dl);
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
template <typename T>
......@@ -515,7 +513,7 @@ void BatchNormDoubleGradOp::InferShape(
}
}
framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar("DY");
if (var == nullptr) {
......@@ -532,8 +530,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
......
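The ONEDNN special case in batch_norm survives the migration unchanged apart from accessor spelling. Collected in one place, the member-to-accessor renames visible in these hunks (a reading aid, not a complete API listing):

// old public members       new accessors
// key.data_type_      ->   key.dtype()    (phi::DataType)
// key.data_layout_    ->   key.layout()   (phi::DataLayout)
// key.place_          ->   key.backend()  (phi::Backend, set via set_backend)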
......@@ -47,13 +47,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
const phi::KernelKey& expected_kernel_type) const override;
};
class BatchNormGradOp : public framework::OperatorWithKernel {
......@@ -62,13 +62,13 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
const phi::KernelKey& expected_kernel_type) const override;
};
class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
......@@ -77,7 +77,7 @@ class BatchNormDoubleGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
......
......@@ -28,11 +28,10 @@ class BCELossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
......@@ -87,11 +86,10 @@ class BCELossGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
......
......@@ -108,7 +108,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *scores = ctx.Input<phi::DenseTensor>("scores");
size_t level = ctx.Attr<int>("level");
......@@ -116,11 +116,11 @@ class BeamSearchOp : public framework::OperatorWithKernel {
// The current CUDA kernel only supports cases with batch_size <= 4.
// Compute on CPU for cases with batch_size > 4.
if (batch_size <= 4) {
return framework::OpKernelType(
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
ctx.GetPlace());
} else {
return framework::OpKernelType(
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
platform::CPUPlace());
}
......
......@@ -85,10 +85,10 @@ class BilateralSliceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......@@ -147,11 +147,11 @@ class BilateralSliceOpGrad : public framework::OperatorWithKernel {
}
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -24,19 +24,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::OpKernelType;
class BincountOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type =
ctx.HasInput("Weights")
? OperatorWithKernel::IndicateVarDataType(ctx, "Weights")
: OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.device_context().GetPlace());
}
};
......
......@@ -56,11 +56,10 @@ class BprLossOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of the computation kernel of Seq-bpr
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -119,11 +118,10 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of the computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -27,14 +27,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Broadcast semantics require all input variables to have the same
// DataType/VarType.
// This condition is also checked during VarType inference;
// here we simply copy the input type to the output.
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......@@ -127,11 +127,11 @@ class BroadcastTensorsGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context().GetPlace());
}
};
......
......@@ -75,7 +75,7 @@ class CastOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// CastOp kernel's device type is decided by input tensor place
auto *tensor = ctx.Input<phi::DenseTensor>("X");
......@@ -86,9 +86,8 @@ class CastOp : public framework::OperatorWithKernel {
auto &tensor_place = tensor->place();
// NOTE: a cuda pinned tensor needs to copy its data to the target place
if (platform::is_cuda_pinned_place(tensor_place)) {
return framework::OpKernelType(
framework::TransToProtoVarType(tensor->dtype()),
ctx.device_context());
return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
ctx.device_context().GetPlace());
}
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
......@@ -108,20 +107,19 @@ class CastOp : public framework::OperatorWithKernel {
auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype"));
auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype"));
if (src_type == dst_type || MLUSupportsCast(src_type, dst_type)) {
return framework::OpKernelType(
framework::TransToProtoVarType(tensor->dtype()), tensor_place);
return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
tensor_place);
} else {
VLOG(3) << "MLU not support cast type: "
<< framework::DataTypeToString(src_type)
<< " to type: " << framework::DataTypeToString(dst_type)
<< ", fallbacking to CPU one!";
return framework::OpKernelType(
framework::TransToProtoVarType(tensor->dtype()),
platform::CPUPlace());
return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
platform::CPUPlace());
}
#endif
return framework::OpKernelType(
framework::TransToProtoVarType(tensor->dtype()), tensor_place);
return phi::KernelKey(framework::TransToProtoVarType(tensor->dtype()),
tensor_place);
}
};
......
......@@ -53,11 +53,10 @@ class CenterLossOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
......@@ -115,11 +114,11 @@ class CenterLossGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"),
ctx.device_context());
ctx.device_context().GetPlace());
}
};
......
......@@ -88,10 +88,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace());
}
};
......
......@@ -57,10 +57,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
* specified a data type here.
*
*/
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -117,10 +117,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
* Of course, the data type here is also not important.
*/
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -26,11 +26,10 @@ class ClassCenterSampleOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Label"),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Label"),
ctx.device_context().GetPlace());
}
};
......
......@@ -26,11 +26,11 @@ namespace operators {
class ClipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......@@ -85,11 +85,11 @@ class ClipOpGrad : public framework::OperatorWithKernel {
}
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......
......@@ -405,20 +405,20 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &context) const override {
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
return framework::OpKernelType(dtype, context.GetPlace());
return phi::KernelKey(dtype, context.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
const phi::KernelKey &expected_kernel_type) const override {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
tensor.layout(),
expected_kernel_type.dtype());
}
};
......
......@@ -27,10 +27,10 @@ class AllReduceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -36,10 +36,10 @@ class AllToAllOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -71,21 +71,23 @@ class CAllReduceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
const phi::KernelKey& expected_kernel_type) const {
if (var_name == "Cond") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
} else {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}
};
......
......@@ -26,10 +26,10 @@ class CBroadcastOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -58,10 +58,10 @@ class CConcatOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -65,10 +65,10 @@ class CEmbeddingOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
......@@ -149,11 +149,11 @@ class CEmbeddingOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
......
......@@ -35,10 +35,10 @@ class CIdentityOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -66,10 +66,10 @@ class CReduceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -52,10 +52,10 @@ class CScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -73,11 +73,10 @@ class CSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
}
};
......@@ -150,11 +149,11 @@ class CSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.device_context());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.GetPlace());
}
};
......
......@@ -66,10 +66,10 @@ class CSplitOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
......@@ -25,10 +28,9 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -23,10 +23,9 @@ class CSyncCommStreamOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -49,10 +49,10 @@ class GlobalGatherOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -52,10 +52,10 @@ class GlobalScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -80,12 +80,12 @@ class PartialRecvOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(dtype);
return framework::OpKernelType(type, ctx.GetPlace());
return phi::KernelKey(type, ctx.GetPlace());
}
};
......
......@@ -51,10 +51,10 @@ class PartialSendOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -69,12 +69,12 @@ class RecvOpV2 : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(dtype);
return framework::OpKernelType(type, ctx.GetPlace());
return phi::KernelKey(type, ctx.GetPlace());
}
};
......
......@@ -38,7 +38,7 @@ class SendOpV2 : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::Variable* var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) {
......@@ -46,12 +46,11 @@ class SendOpV2 : public framework::OperatorWithKernel {
// NOTE(sandyhouse): Support an empty tensor array as Input.
// In that case the kernel data type is set to float.
if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -32,7 +32,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
......@@ -48,18 +48,20 @@ class ConcatOp : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!"));
}
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
......@@ -110,22 +112,24 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
const phi::KernelKey &expected_kernel_type) const override {
if (var_name == "AxisTensor") {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
......
......@@ -97,11 +97,12 @@ class UnaryBitwiseOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;
}
};
......@@ -138,11 +139,12 @@ class BinaryBitwiseOp : public framework::OperatorWithKernel {
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;
}
};
......
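Where an op only needs to re-target the device decision, the bitwise hunks above mutate the base key in place instead of rebuilding it. A minimal sketch of that idiom, assuming a DenseTensor input named "X" as in those ops:

phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// Keep the base key's layout and dtype; override only the backend so the
// kernel runs where the input tensor already lives.
kt.set_backend(
    phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;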
......@@ -61,19 +61,20 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
bool force_cpu = ctx.Attr<bool>("force_cpu");
if (force_cpu) {
kt.place_ = platform::CPUPlace();
kt.set_backend(phi::Backend::CPU);
} else {
if (ctx.Input<phi::DenseTensor>("X")->place().GetType() !=
phi::AllocationType::GPUPINNED) {
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
} else {
kt.place_ = ctx.GetPlace();
kt.set_backend(phi::TransToPhiBackend(ctx.GetPlace()));
}
}
return kt;
......
......@@ -72,48 +72,49 @@ class FetchV2Op : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
const phi::KernelKey &expected_kernel_type) const override {
if (!tensor.IsInitialized()) {
return expected_kernel_type;
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *fetch_var = ctx.InputVar("X");
if (fetch_var == nullptr) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace());
}
if (fetch_var->IsType<phi::DenseTensor>()) {
auto &src_item = fetch_var->Get<phi::DenseTensor>();
if (!src_item.IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace());
}
} else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
if (!src_item.initialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace());
}
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
if (src_item.empty() || !src_item[0].IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
return phi::KernelKey(framework::proto::VarType::FP32,
platform::CPUPlace());
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -69,11 +69,12 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<phi::DenseTensor>("X")->place();
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;
}
};
......
......@@ -186,7 +186,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(
return output_shape;
}
framework::OpKernelType ConvOp::GetExpectedKernelType(
phi::KernelKey ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
// todo enable data layout when it's ready
......@@ -208,18 +208,18 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
paddle::framework::DataTypeToString(filter_data_type)));
}
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
framework::OpKernelType ConvOp::GetKernelTypeForVar(
phi::KernelKey ConvOp::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
......@@ -228,13 +228,12 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), dl);
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void Conv2DOpMaker::Make() {
......@@ -447,23 +446,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
}
}
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
phi::KernelKey ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
phi::KernelKey ConvOpGrad::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only the input requires reshaping; the weights and
// bias already have their shapes in NCHW order
if (((var_name == "Input") ||
(var_name == framework::GradVarName("Output"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
......@@ -472,13 +471,12 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), dl);
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
template <typename T>
......@@ -619,10 +617,10 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
}
}
framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
} // namespace operators
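The grad and double-grad hunks above also stop reaching into public members; KernelKey exposes accessors instead. A self-contained check of that surface, assuming the accessor names used throughout this diff:

#include <cassert>

#include "paddle/phi/core/kernel_factory.h"

// OpKernelType's public fields (data_type_, place_, data_layout_) are
// replaced by KernelKey accessors: dtype(), backend(), layout().
int main() {
  phi::KernelKey key(phi::Backend::CPU,
                     phi::DataLayout::ALL_LAYOUT,
                     phi::DataType::FLOAT32);
  assert(key.backend() == phi::Backend::CPU);
  assert(key.layout() == phi::DataLayout::ALL_LAYOUT);
  assert(key.dtype() == phi::DataType::FLOAT32);
  return 0;
}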
......
......@@ -196,13 +196,13 @@ class ConvOp : public framework::OperatorWithKernel {
std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const;
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvOpGrad : public framework::OperatorWithKernel {
......@@ -211,13 +211,13 @@ class ConvOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvOpDoubleGrad : public framework::OperatorWithKernel {
......@@ -226,7 +226,7 @@ class ConvOpDoubleGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
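Header-side, the migration is mechanical; a minimal sketch of an op declaration after the change (MyOp is a hypothetical name, mirroring the ConvOp declarations above):

#include <string>

#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"

// Hypothetical operator showing the post-migration override signatures:
// both hooks now return phi::KernelKey instead of framework::OpKernelType.
class MyOp : public paddle::framework::OperatorWithKernel {
 public:
  using paddle::framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const paddle::framework::ExecutionContext& ctx) const override;
  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override;
};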
......
......@@ -33,21 +33,21 @@ namespace operators {
using DataLayout = phi::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
phi::KernelKey ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
phi::KernelKey ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only the input requires reshaping; the weights and
// bias already have their shapes in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
......@@ -56,13 +56,12 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for conv_transpose
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), dl);
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void Conv2DTransposeOpMaker::Make() {
......@@ -253,10 +252,10 @@ Example:
)DOC");
}
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
phi::KernelKey ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
template <typename T>
......@@ -320,10 +319,10 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
phi::KernelKey ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(data_type, ctx.GetPlace());
return phi::KernelKey(data_type, ctx.GetPlace());
}
} // namespace operators
......
......@@ -38,13 +38,13 @@ class ConvTransposeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
......@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
......@@ -61,7 +61,7 @@ class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
......
......@@ -109,7 +109,7 @@ class CorrelationOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
......@@ -118,7 +118,7 @@ class CorrelationOp : public framework::OperatorWithKernel {
ctx.Input<phi::DenseTensor>("Input2")->dtype()),
platform::errors::InvalidArgument(
"X and Y shoule have the same datatype"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
......@@ -158,9 +158,9 @@ class CorrelationOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace());
}
};
......
......@@ -202,9 +202,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetExpectedKernelType(
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
platform::CPUPlace());
}
......
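CRFDecodingOp above pins its kernel to CPU regardless of the execution place; with KernelKey that intent is a fixed backend plus a runtime dtype. A one-line sketch (the helper name is illustrative, and the dtype would come from the op's input, e.g. via IndicateVarDataType):

#include "paddle/phi/core/kernel_factory.h"

// CPU-only key: backend fixed, layout unconstrained, dtype decided at runtime.
phi::KernelKey MakeCpuOnlyKey(phi::DataType dtype) {
  return phi::KernelKey(phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, dtype);
}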
(The remaining file diffs are collapsed.)