stack_op.cc 4.7 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include <memory>
#include <vector>
C
csy0225 已提交
17 18 19 20
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
X
Xin Pan 已提交
21 22 23

namespace plat = paddle::platform;
namespace ops = paddle::operators;
24 25 26 27 28 29 30 31

namespace paddle {
namespace operators {

class StackOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

32 33 34 35 36 37 38 39 40 41 42 43 44 45
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
46 47 48 49 50 51 52 53 54 55
};

class StackOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the inputs, outputs and attributes of the forward `stack` op.
  void Make() override {
    // X is duplicable: the op receives a list of tensors to be stacked.
    this->AddInput("X", "The input of stack op.").AsDuplicable();
    // Y is the single stacked result.
    this->AddOutput("Y", "The output of stack op.");

    // Position of the stacking dimension in Y. The grad op's InferShape
    // accepts values in [-rank(Y), rank(Y)), so negative axes are allowed.
    this->AddAttr<int>(
            "axis",
            "The axis along which all of the Inputs(X) should be stacked.")
        .SetDefault(0);

    // Extra attribute that only toggles MKLDNN kernel selection.
    this->AddAttr<bool>(
            "use_mkldnn",
            "(bool, default false) Indicates if MKL-DNN kernel will be used")
        .SetDefault(false)
        .AsExtra();

    this->AddComment(R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
  }
};

class StackOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers the shapes of Outputs(X@GRAD) from Input(Y@GRAD).
  //
  // Y@GRAD carries one extra dimension at Attr(axis) compared to each X@GRAD.
  // The extent of that dimension must equal the number of X@GRAD outputs, and
  // every X@GRAD receives Y@GRAD's shape with that dimension removed.
  //
  // Raises InvalidArgument when Y@GRAD is missing, when Attr(axis) is outside
  // [-rank, rank), or when the number of X@GRAD outputs does not match the
  // extent of Y@GRAD at axis.
  void InferShape(framework::InferShapeContext *ctx) const override {
    const auto dy_name = framework::GradVarName("Y");
    const auto dx_name = framework::GradVarName("X");

    PADDLE_ENFORCE_EQ(
        ctx->HasInput(dy_name), true,
        platform::errors::InvalidArgument("Input(Y@Grad) not exist."));

    int axis = ctx->Attrs().Get<int>("axis");
    auto dy_dim = ctx->GetInputDim(dy_name);
    int rank = dy_dim.size();
    PADDLE_ENFORCE_GE(
        axis, -rank,
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));
    PADDLE_ENFORCE_LT(
        axis, rank,
        platform::errors::InvalidArgument(
            "Attr(axis) must be inside [-rank, rank), where rank = %d, "
            "but received axis is:%d.",
            rank, axis));

    // Normalize a negative axis to its non-negative equivalent.
    if (axis < 0) axis += rank;

    // Query the output list once instead of once per use.
    const size_t num_dx = ctx->Outputs(dx_name).size();
    PADDLE_ENFORCE_EQ(
        num_dx, static_cast<size_t>(dy_dim[axis]),
        platform::errors::InvalidArgument(
            "Number of Outputs(X@Grad) must be equal to dy dim at axis, but"
            " received outputs size is:%d, dy dims is:%d.",
            num_dx, static_cast<size_t>(dy_dim[axis])));

    // Keep extents as int64_t (phi::vectorize's default); vectorize<int>
    // would silently truncate dimensions larger than INT_MAX.
    auto dx_shape = phi::vectorize(dy_dim);
    dx_shape.erase(dx_shape.begin() + axis);
    ctx->SetOutputsDim(
        dx_name,
        std::vector<framework::DDim>(num_dx, phi::make_ddim(dx_shape)));
  }
};

// Describes how to build the `stack_grad` op for a forward `stack` op.
//
// The grad op consumes Y@GRAD and produces one X@GRAD per forward input.
// InputGrad("X", false) is requested with drop_empty_grad = false, so the
// output list keeps one entry per forward input. The forward attributes
// (including "axis") are copied verbatim.
template <typename T>
class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    const auto y_grad_slot = framework::GradVarName("Y");
    const auto x_grad_slot = framework::GradVarName("X");

    op->SetType("stack_grad");
    op->SetInput(y_grad_slot, this->OutputGrad("Y"));
    op->SetOutput(x_grad_slot, this->InputGrad("X", false));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

C
csy0225 已提交
128 129
// Forward shape inference for `stack` is delegated to phi::StackInferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(stack,
                            StackInferMetaFunctor,
                            PD_INFER_META(phi::StackInferMeta));

// Forward op, with grad makers for both framework::OpDesc and
// imperative::OpBase, plus the infer-shape functor declared above.
REGISTER_OPERATOR(stack,
                  ops::StackOp,
                  ops::StackOpMaker,
                  ops::StackGradOpMaker<paddle::framework::OpDesc>,
                  ops::StackGradOpMaker<paddle::imperative::OpBase>,
                  StackInferMetaFunctor);

// Backward op; its shape inference is StackOpGrad::InferShape.
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);