reshape_mkldnn_op.cc 15.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/framework/op_registry.h"
16
#include "paddle/fluid/operators/flatten_op.h"
17
#include "paddle/fluid/operators/squeeze_op.h"
18
#include "paddle/phi/backends/onednn/onednn_reuse.h"
19

20 21 22 23 24 25 26 27 28 29
namespace {
// Compile-time tag identifying which operator a ReshapeMKLDNNKernel
// instantiation serves. The kernel template dispatches on this value to pick
// the correct shape-inference routine (reshape attrs vs. squeeze axes vs.
// flatten axis) while sharing one reorder-based implementation.
enum class ReshapeKernelOpName {
  reshape,
  reshape2,
  squeeze,
  flatten,
  flatten2,
};
}  // anonymous namespace

30 31 32
namespace paddle {
namespace operators {

J
jakpiase 已提交
33
static std::vector<int> extract_shape(
34
    const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
J
jakpiase 已提交
35 36 37 38 39
  std::vector<int> vec_new_shape;
  vec_new_shape.reserve(list_new_shape_tensor.size());

  for (const auto& tensor : list_new_shape_tensor) {
    PADDLE_ENFORCE_EQ(
40 41
        tensor->dims(),
        phi::make_ddim({1}),
J
jakpiase 已提交
42
        platform::errors::InvalidArgument(
43
            "If the element type of 'shape' in ReshapeOp is phi::DenseTensor, "
J
jakpiase 已提交
44 45 46 47 48 49 50 51 52
            "the element's shape must be [1]. But received the element's shape "
            "is [%s]",
            tensor->dims()));
    vec_new_shape.emplace_back(*tensor->data<int32_t>());
  }

  return vec_new_shape;
}

53
template <typename T, ReshapeKernelOpName op_name>
54 55 56 57 58 59 60 61
class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    RunKernel(ctx);
  }

 private:
  void RunKernel(const framework::ExecutionContext& ctx) const {
62
    const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
63 64
    const auto& onednn_engine = dev_ctx.GetEngine();

65 66
    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* out = ctx.Output<phi::DenseTensor>("Out");
67

68 69
    framework::DDim x_dims, out_dims;
    InferInOutShape(ctx, x_dims, out_dims);
70

71
    auto x_vec_dims = phi::vectorize(x_dims);
72

73 74 75
    auto x_type = phi::funcs ::ToOneDNNDataType(x->dtype());
    phi::funcs::ReorderOneDNNHandler reorder_handler(
        x_vec_dims, x->dtype(), x_type, onednn_engine);
76 77

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
78
        x->mem_desc(), phi::funcs::to_void_cast(x->data<T>()));
79 80 81
    out->Resize(x_dims);  // to match x numel, format is changed later
    // reorder is done into a plain tag to allow usage with blocked formats
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
82
        out, phi::funcs::GetPlainOneDNNFormat(x_dims.size()), ctx.GetPlace());
83 84
    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                    reorder_src_memory_p);
85

86
    auto& astream = phi::OneDNNContext::tls().get_stream();
87 88 89 90 91
    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);

    astream.wait();

    out->Resize(out_dims);
92 93
    out->set_mem_desc(
        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims)));
94 95
  }

96
  void InferInOutShape(const framework::ExecutionContext& ctx,
97 98
                       framework::DDim& x_dims,            // NOLINT
                       framework::DDim& out_dims) const {  // NOLINT
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
    switch (op_name) {
      case ReshapeKernelOpName::reshape:
        InferShapeReshapeOp(ctx, x_dims, out_dims);
        break;
      case ReshapeKernelOpName::squeeze:
        InferShapeSqueezeOp(ctx, x_dims, out_dims);
        break;
      case ReshapeKernelOpName::flatten:
        InferShapeFlattenOp(ctx, x_dims, out_dims);
        break;
      case ReshapeKernelOpName::flatten2:
        InferShapeFlattenOp(ctx, x_dims, out_dims);
        break;
      default:
        PADDLE_THROW(paddle::platform::errors::OutOfRange(
            "Reshape kernel doesn not support that operator name"));
    }
  }

  void InferShapeReshapeOp(const framework::ExecutionContext& ctx,
119 120
                           framework::DDim& x_dims,            // NOLINT
                           framework::DDim& out_dims) const {  // NOLINT
121 122
    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* out = ctx.Output<phi::DenseTensor>("Out");
123 124 125 126 127 128 129
    x_dims = x->dims();
    out_dims = out->dims();
    ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
  }

  // in reshape1/2 ops  "ShapeTensor" has highest priority and "Shape" has
  // second highest priority
130 131 132 133
  void ChangeReshapeOutDimsIfNeeded(
      const framework::ExecutionContext& ctx,
      framework::DDim& x_dims,            // NOLINT
      framework::DDim& out_dims) const {  // NOLINT
134 135
    auto list_new_shape_tensor =
        ctx.MultiInput<phi::DenseTensor>("ShapeTensor");
136 137 138 139
    if (list_new_shape_tensor.size() > 0) {
      auto new_shape = extract_shape(list_new_shape_tensor);
      out_dims = ValidateShape(new_shape, x_dims);
    } else if (ctx.HasInput("Shape")) {
140
      auto* shape_tensor = ctx.Input<phi::DenseTensor>("Shape");
141 142 143 144 145 146 147 148 149
      auto* shape_data = shape_tensor->data<int>();

      auto shape =
          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
      out_dims = ValidateShape(shape, x_dims);
    }
  }

  void InferShapeSqueezeOp(const framework::ExecutionContext& ctx,
150 151
                           framework::DDim& x_dims,            // NOLINT
                           framework::DDim& out_dims) const {  // NOLINT
152
    auto* x = ctx.Input<phi::DenseTensor>("X");
153 154 155 156 157 158
    x_dims = x->dims();
    const auto& axes = ctx.Attr<std::vector<int>>("axes");
    out_dims = GetOutputShape(axes, x_dims, true);
  }

  void InferShapeFlattenOp(const framework::ExecutionContext& ctx,
159 160
                           framework::DDim& x_dims,            // NOLINT
                           framework::DDim& out_dims) const {  // NOLINT
161
    auto x = ctx.Input<phi::DenseTensor>("X");
162 163
    x_dims = x->dims();
    auto axes = ctx.Attr<int>("axis");
164
    out_dims = phi::make_ddim(
L
Leo Chen 已提交
165
        FlattenKernel<phi::CPUContext, float>::GetOutputShape(axes, x_dims));
166 167
  }

168 169 170
 protected:
  static framework::DDim ValidateShape(const std::vector<int>& shape,
                                       const framework::DDim& in_dims) {
171 172
    const int64_t in_size = phi::product(in_dims);
    auto in_dims_vec = phi::vectorize(in_dims);
173 174
    bool all_positive = std::all_of(in_dims_vec.cbegin(),
                                    in_dims_vec.cend(),
175 176 177 178 179 180 181 182 183 184 185 186
                                    [](int64_t i) { return i > 0; });
    // only one dimension can be set to -1, whose size will be automatically
    // infered
    const int64_t unk_dim_val = -1;
    const int64_t copy_dim_val = 0;

    std::vector<int64_t> output_shape(shape.size(), 0);
    int64_t capacity = 1;
    int unk_dim_idx = -1;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == unk_dim_val) {
        PADDLE_ENFORCE_EQ(
187 188
            unk_dim_idx,
            -1,
189 190 191
            platform::errors::InvalidArgument(
                "Only one dimension value of 'shape' in ReshapeOp can "
                "be -1. But received shape = [%s], shape[%d] is also -1.",
192 193
                phi::make_ddim(shape),
                i));
194 195 196
        unk_dim_idx = i;
      } else if (shape[i] == copy_dim_val) {
        PADDLE_ENFORCE_LT(
197 198
            static_cast<int>(i),
            in_dims.size(),
199 200 201 202 203
            platform::errors::InvalidArgument(
                "The index of 0 in `shape` must be less than "
                "the input tensor X's dimensions. "
                "But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
                "X's dimensions = %d.",
204 205 206 207
                phi::make_ddim(shape),
                i,
                in_dims,
                in_dims.size()));
208 209
      } else {
        PADDLE_ENFORCE_GT(
210 211
            shape[i],
            0,
212 213 214 215
            platform::errors::InvalidArgument(
                "Each dimension value of 'shape' in ReshapeOp must not "
                "be negative except one unknown dimension. "
                "But received  shape = [%s], shape[%d] = %d.",
216 217 218
                phi::make_ddim(shape),
                i,
                shape[i]));
219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
      }

      capacity *= (shape[i] ? shape[i] : in_dims[i]);
      output_shape[i] =
          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
    }

    if (unk_dim_idx != -1) {
      if (all_positive) {
        // in_size < 0 and is un-determinate in compile time, skip the check,
        // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
        // capacity = -24, in_size = -8, output_shape[0] = 0
        // the following check will fail.
        output_shape[unk_dim_idx] = -in_size / capacity;
        PADDLE_ENFORCE_EQ(
234 235
            output_shape[unk_dim_idx] * capacity,
            -in_size,
236 237 238 239 240 241
            platform::errors::InvalidArgument(
                "The 'shape' attribute in ReshapeOp is invalid. "
                "The input tensor X'size must be divisible by known "
                "capacity of 'shape'. "
                "But received X's shape = [%s], X's size = %d, "
                "'shape' is [%s], known capacity of 'shape' is %d.",
242 243 244 245
                in_dims,
                in_size,
                phi::make_ddim(shape),
                capacity));
246 247 248 249 250 251
      } else {
        output_shape[unk_dim_idx] = -1;
      }
    } else {
      if (all_positive) {
        PADDLE_ENFORCE_EQ(
252 253
            capacity,
            in_size,
254 255 256 257 258 259
            platform::errors::InvalidArgument(
                "The 'shape' in ReshapeOp is invalid. "
                "The input tensor X'size must be equal to the capacity of "
                "'shape'. "
                "But received X's shape = [%s], X's size = %d, 'shape' is "
                "[%s], the capacity of 'shape' is %d.",
260 261 262 263
                in_dims,
                in_size,
                phi::make_ddim(shape),
                capacity));
264 265
      }
    }
266
    return phi::make_ddim(output_shape);
267 268 269
  }
};

270 271
// Backward counterpart of ReshapeMKLDNNKernel: reorders dOut into a plain
// format and reinterprets it with the original input dims, producing dX.
template <typename T, ReshapeKernelOpName op_name>
class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    RunKernel(ctx);
  }

 private:
  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* dout = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));

    framework::DDim dx_dims;
    InferOutputShapeInGrad(ctx, dx_dims);

    auto dout_vec_dims = phi::vectorize(dout->dims());

    auto dout_type = phi::funcs::ToOneDNNDataType(dout->dtype());
    phi::funcs::ReorderOneDNNHandler reorder_handler(
        dout_vec_dims, dout->dtype(), dout_type, onednn_engine);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        dout->mem_desc(), phi::funcs::to_void_cast(dout->data<T>()));
    // reorder into a plain tag so the descriptor can be reshaped afterwards
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        dx,
        phi::funcs::GetPlainOneDNNFormat(dout_vec_dims.size()),
        ctx.GetPlace());
    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                    reorder_src_memory_p);

    auto& astream = phi::OneDNNContext::tls().get_stream();
    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
    astream.wait();

    dx->Resize(dx_dims);
    // BUGFIX: the reshaped descriptor was previously computed and discarded;
    // it must be stored on dx (mirrors out->set_mem_desc in the forward pass).
    dx->set_mem_desc(
        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(dx_dims)));
  }

  // Dispatches to the per-operator dX shape inference based on op_name.
  void InferOutputShapeInGrad(const framework::ExecutionContext& ctx,
                              framework::DDim& x_dims) const {  // NOLINT
    switch (op_name) {
      case ReshapeKernelOpName::reshape:
        InferShapeReshapeSqueezeGradOp(ctx, x_dims);
        break;
      case ReshapeKernelOpName::reshape2:
        InferShapeReshape2Flatten2GradOp(ctx, x_dims);
        break;
      case ReshapeKernelOpName::squeeze:
        InferShapeReshapeSqueezeGradOp(ctx, x_dims);
        break;
      case ReshapeKernelOpName::flatten:
        InferShapeFlattenGradOp(ctx, x_dims);
        break;
      case ReshapeKernelOpName::flatten2:
        InferShapeReshape2Flatten2GradOp(ctx, x_dims);
        break;
      default:
        PADDLE_THROW(paddle::platform::errors::OutOfRange(
            "Reshape grad kernel does not support that operator name"));
    }
  }

  // reshape/squeeze grads: dX keeps the dims already set on the output var.
  void InferShapeReshapeSqueezeGradOp(
      const framework::ExecutionContext& ctx,
      framework::DDim& dx_dims) const {  // NOLINT
    auto* dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    dx_dims = dx->dims();
  }

  // reshape2/flatten2 grads: recover dX dims from the saved "XShape" tensor
  // (its leading dim is a dummy 0, hence the slice from index 1).
  void InferShapeReshape2Flatten2GradOp(
      const framework::ExecutionContext& ctx,
      framework::DDim& dx_dims) const {  // NOLINT
    auto xshape_dims = ctx.Input<phi::DenseTensor>("XShape")->dims();
    dx_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
  }

  // flatten grad: dX has the same dims as the forward input X.
  void InferShapeFlattenGradOp(const framework::ExecutionContext& ctx,
                               framework::DDim& dx_dims) const {  // NOLINT
    dx_dims = ctx.Input<phi::DenseTensor>("X")->dims();
  }
};
}  // namespace operators
}  // namespace paddle
356

357 358
namespace ops = paddle::operators;

// oneDNN kernel registrations. Each operator is registered for float and
// bfloat16 — the data types handled by the reorder-based implementation.

// squeeze
REGISTER_OP_KERNEL(
    squeeze, MKLDNN, phi::CPUPlace,
    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
                             ReshapeKernelOpName::squeeze>);

REGISTER_OP_KERNEL(
    squeeze_grad, MKLDNN, phi::CPUPlace,
    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                 ReshapeKernelOpName::squeeze>);

// reshape / reshape2
REGISTER_OP_KERNEL(
    reshape, MKLDNN, phi::CPUPlace,
    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
                             ReshapeKernelOpName::reshape>);

REGISTER_OP_KERNEL(
    reshape_grad, MKLDNN, phi::CPUPlace,
    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                 ReshapeKernelOpName::reshape>);

REGISTER_OP_KERNEL(
    reshape2_grad, MKLDNN, phi::CPUPlace,
    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                 ReshapeKernelOpName::reshape2>);

// flatten / flatten2
REGISTER_OP_KERNEL(
    flatten, MKLDNN, phi::CPUPlace,
    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
                             ReshapeKernelOpName::flatten>);

REGISTER_OP_KERNEL(
    flatten_grad, MKLDNN, phi::CPUPlace,
    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                 ReshapeKernelOpName::flatten>);

REGISTER_OP_KERNEL(
    flatten2, MKLDNN, phi::CPUPlace,
    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
                             ReshapeKernelOpName::flatten2>);

REGISTER_OP_KERNEL(
    flatten2_grad, MKLDNN, phi::CPUPlace,
    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
                                 ReshapeKernelOpName::flatten2>);