Unverified commit 5a39365a, authored by YuanRisheng, committed by GitHub

[PHI]Remove OneDNN code in Transpose infershape (#50836)

* remove transpose infershape

* fix ci bugs

* fix ci bugs

* delete transpose infershape

* fix ci bugs

* fix ci bugs
Parent ed19d37f
......@@ -31,7 +31,7 @@ PD_DECLARE_KERNEL(pool2d, OneDNN, ONEDNN);
USE_OP_ITSELF(relu);
PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
PD_DECLARE_KERNEL(transpose, OneDNN, ONEDNN);
USE_OP_ITSELF(shape);
PD_DECLARE_KERNEL(shape, OneDNN, ONEDNN);
USE_OP_ITSELF(crop);
......
......@@ -34,81 +34,6 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Transpose");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
int x_rank = x_dims.size();
int axis_size = axis.size();
// Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
// axis_size
PADDLE_ENFORCE_GE(x_rank,
axis_size,
platform::errors::InvalidArgument(
"The input tensor's dimension "
"should be equal to or greater than the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank,
axis_size));
std::vector<int> formated_axis = axis;
std::vector<int> count(axis_size, 0);
for (int i = 0; i < axis_size; i++) {
PADDLE_ENFORCE_LT(axis[i],
axis_size,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
axis_size,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-axis_size,
platform::errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
axis_size,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] = axis[i] + axis_size;
}
PADDLE_ENFORCE_EQ(++count[formated_axis[i]],
1,
platform::errors::InvalidArgument(
"Each element of axis should be unique. but "
"axis[%d] is %d appear not only once",
i,
axis[i]));
}
framework::DDim out_dims(x_dims);
#ifdef PADDLE_WITH_MKLDNN
// Here we need to match dims to paddle layout
// as we are producing non-oneDNN result
if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC)) {
auto dims = phi::vectorize<int>(x_dims);
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
x_dims = x_dims.reshape(dims);
VLOG(3)
<< "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
}
#endif
for (int i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[formated_axis[i]];
}
ctx->SetOutputDim("Out", out_dims);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
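
The validation and permutation logic deleted above is not lost: it now lives solely in phi::TransposeInferMeta (see the hunk further down in this diff). For reference, a minimal self-contained sketch of the same shape rule, assuming x_rank == axis_size (TransposeOutDims is a hypothetical helper, not a Paddle API):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of transpose shape inference: normalize negative axes,
// require each axis value to appear exactly once, then permute dims.
// The fused squeeze2 + transpose2 case (x_rank > axis_size) would
// additionally keep the trailing dims of x unchanged.
std::vector<int64_t> TransposeOutDims(const std::vector<int64_t>& x_dims,
                                      std::vector<int> axis) {
  const int axis_size = static_cast<int>(axis.size());
  std::vector<int> count(axis_size, 0);
  for (int i = 0; i < axis_size; ++i) {
    if (axis[i] < -axis_size || axis[i] >= axis_size)
      throw std::invalid_argument("axis value out of range");
    if (axis[i] < 0) axis[i] += axis_size;  // normalize negative axis
    if (++count[axis[i]] != 1)
      throw std::invalid_argument("axis values must be unique");
  }
  std::vector<int64_t> out_dims(axis_size);
  for (int i = 0; i < axis_size; ++i) out_dims[i] = x_dims[axis[i]];
  return out_dims;
}

// Example: TransposeOutDims({2, 3, 4}, {0, 2, 1}) yields {2, 4, 3}.
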
......@@ -218,7 +143,12 @@ class Transpose2Op : public TransposeOp {
: TransposeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
TransposeOp::InferShape(ctx);
using CompatMetaTensor = framework::CompatMetaTensor;
CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
phi::TransposeInferMeta(x, axis, &out);
if (!ctx->HasOutput("XShape")) return;
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
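
The truncated lines that follow fill in XShape. By the transpose2 convention (stated here as an assumption, matching the surrounding code), XShape holds a leading 0 followed by X's dims so that transpose2_grad can recover the input shape; schematically (BuildXShapeDims is a hypothetical helper):

#include <cstdint>
#include <vector>

// Illustration of the XShape layout assumed above: element 0 is a
// placeholder, elements 1..N mirror the input dims.
std::vector<int64_t> BuildXShapeDims(const std::vector<int64_t>& in_dims) {
  std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
  x_shape_dim[0] = 0;
  for (size_t i = 0; i < in_dims.size(); ++i) {
    x_shape_dim[i + 1] = in_dims[i];
  }
  return x_shape_dim;
}

This extra output is also why transpose2 keeps a hand-written InferShape that calls TransposeInferMeta directly, rather than the plain functor used for transpose below: TransposeInferMeta only knows about X and Out.
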
......@@ -361,6 +291,9 @@ class TransposeGradInferVarType : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(transpose,
TransposeInferShapeFunctor,
PD_INFER_META(phi::TransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad,
TransposeGradInferShapeFunctor,
......@@ -369,13 +302,16 @@ DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad,
DECLARE_INFER_SHAPE_FUNCTOR(transpose2_grad,
Transpose2GradInferShapeFunctor,
PD_INFER_META(phi::TransposeGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
transpose,
ops::TransposeOp,
ops::TransposeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
TransposeInferShapeFunctor);
REGISTER_OPERATOR(transpose_grad,
ops::TransposeOpGrad,
ops::TransposeGradInferVarType,
......
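
The registrations above follow the standard PHI wiring that this PR extends to transpose: declare a functor that forwards InferShape to a PHI InferMeta, then pass that functor to REGISTER_OPERATOR so the operator needs no hand-written InferShape. A schematic of the pattern, with placeholder names (my_op, MyOpInferShapeFunctor, and phi::MyOpInferMeta are hypothetical; the real instances for transpose are in this diff):

// Declare a functor that adapts a PHI InferMeta to fluid's InferShape.
DECLARE_INFER_SHAPE_FUNCTOR(my_op,
                            MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));

// Pass the functor as the trailing argument of the op registration.
REGISTER_OPERATOR(my_op,
                  ops::MyOp,
                  ops::MyOpMaker,
                  MyOpInferShapeFunctor);
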
......@@ -106,7 +106,7 @@ inline std::string DataLayoutToString(const DataLayout& layout) {
case DataLayout::kAnyLayout:
return "Undefined(AnyLayout)";
case DataLayout::kMKLDNN:
return "MKLDNN";
return "ONEDNN";
case DataLayout::SPARSE_COO:
return "SPARSE_COO";
case DataLayout::SPARSE_CSR:
......
......@@ -4234,7 +4234,9 @@ void TransposeInferMeta(const MetaTensor& x,
int x_rank = x_dims.size();
int axis_size = axis.size();
PADDLE_ENFORCE_EQ(
// Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
// axis_size
PADDLE_ENFORCE_GE(
x_rank,
axis_size,
errors::InvalidArgument("The input tensor's dimension "
......
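
The EQ-to-GE relaxation carries over the note that previously lived only in the fluid InferShape: when the squeeze2 + transpose2 fusion runs, the axis list is written against the squeezed view, so the unsqueezed input's rank can exceed axis_size. A small illustration (RankCompatible is a hypothetical helper):

#include <cassert>
#include <vector>

// Only x_rank >= axis_size can be enforced once the fusion exists.
bool RankCompatible(int x_rank, const std::vector<int>& axis) {
  return x_rank >= static_cast<int>(axis.size());
}

int main() {
  assert(RankCompatible(3, {0, 2, 1}));  // plain transpose2: ranks match
  assert(RankCompatible(4, {0, 2, 1}));  // fused squeeze2 + transpose2
  return 0;
}
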
......@@ -104,6 +104,34 @@ void TransposeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
// Here we need to match dims to paddle layout
// as we are producing non-oneDNN result
auto x_dims = x.dims();
if ((x_dims.size() >= 3) &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC)) {
int axis_size = axis.size();
std::vector<int> formated_axis = axis;
std::vector<int> count(axis_size, 0);
for (int i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
formated_axis[i] = axis[i] + axis_size;
}
}
auto dims = phi::vectorize<int>(x_dims);
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
x_dims = x_dims.reshape(dims);
VLOG(3)
<< "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
phi::DDim out_dims(x_dims);
for (size_t i = 0; i < axis.size(); i++) {
out_dims[i] = x_dims[formated_axis[i]];
}
out->Resize(out_dims);
}
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
......
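
Moving this block out of the shared InferShape and into the OneDNN TransposeKernel keeps the layout fix-up where it is actually needed. The core of it is the std::rotate call; a standalone sketch of what that rotation does to NCHW-ordered dims:

#include <algorithm>
#include <iostream>
#include <vector>

// When paddle's current layout is NHWC but the dims arrive NCHW-ordered,
// rotating everything after the batch dim by one position produces the
// NHWC ordering, i.e. [N, C, H, W] -> [N, H, W, C].
int main() {
  std::vector<int> dims = {8, 3, 224, 224};  // N, C, H, W
  std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
  for (int d : dims) std::cout << d << ' ';  // prints: 8 224 224 3
  return 0;
}
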
......@@ -38,7 +38,7 @@ TEST(DataLayout, OStream) {
EXPECT_EQ(oss.str(), "NCHW");
oss.str("");
oss << phi::DataLayout::ONEDNN;
EXPECT_EQ(oss.str(), "MKLDNN");
EXPECT_EQ(oss.str(), "ONEDNN");
oss.str("");
try {
oss << phi::DataLayout::NUM_DATA_LAYOUTS;
......