assign_op.cc 4.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yu Yang 已提交
14

15 16 17
#include "paddle/fluid/operators/assign_op.h"

#include <string>
Y
Yu Yang 已提交
18

19 20 21
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
W
wanghuancoder 已提交
22 23 24 25 26 27 28 29 30 31
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

Y
Yu Yang 已提交
32 33 34
namespace paddle {
namespace operators {

35
class AssignOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
36
 public:
37 38
  AssignOp(const std::string &type,
           const framework::VariableNameMap &inputs,
Y
Yu Yang 已提交
39 40
           const framework::VariableNameMap &outputs,
           const framework::AttributeMap &attrs)
41
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
42

43
 protected:
44
  phi::KernelKey GetKernelTypeForVar(
45
      const std::string &var_name,
46
      const phi::DenseTensor &tensor,
47 48 49 50
      const phi::KernelKey &expected_kernel_type) const override {
    return phi::KernelKey(phi::Backend::ALL_BACKEND,
                          tensor.layout(),
                          expected_kernel_type.dtype());
51 52
  }

53
  phi::KernelKey GetExpectedKernelType(
54
      const framework::ExecutionContext &ctx) const override {
L
liym27 已提交
55 56 57 58 59 60
    const framework::Variable *var = ctx.InputVar("X");
    if (var->IsType<framework::LoDTensorArray>()) {
      auto t_arr = var->Get<framework::LoDTensorArray>();
      // NOTE(liym27): Support an empty tensor array as Input.
      // And set the kernel type is float.
      if (t_arr.size() == 0) {
61 62
        return phi::KernelKey(framework::proto::VarType::FP32,
                              ctx.device_context().GetPlace());
L
liym27 已提交
63 64 65
      }
    }

66 67
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.device_context().GetPlace());
68 69 70
  }
};

71 72 73
// Variable-type inference for "assign": output "Out" is synced to input "X"
// through SyncTypeAndDataType (both the variable type and the data type).
class AssignInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    const std::string source_slot = "X";
    const std::string target_slot = "Out";
    ctx->SyncTypeAndDataType(source_slot, target_slot);
  }
};

Y
Yu Yang 已提交
78 79
// Describes the op's interface: a dispensable input "X", an output "Out",
// and the operator's documentation string.
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Both slots accept the same three variable kinds; share the prefix.
    const std::string accepted_kinds =
        "(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) ";

    AddInput("X",
             accepted_kinds +
                 "The input variable could be phi::DenseTensor, "
                 "SelectedRows or phi::DenseTensorArray.")
        .AsDispensable();

    AddOutput("Out",
              accepted_kinds + "The type of output is the same as input X.");

    AddComment(R"DOC(Assign Operator

Out = X,  when type in [phi::DenseTensor/SelectedRows/phi::DenseTensorArray]
raise error if the type is not listed above.
)DOC");
  }
};

H
hong 已提交
99 100
// Builds the backward op of "assign". Since Out = X, the gradient w.r.t. X is
// the incoming gradient of Out, so the grad op is another "assign" that reads
// Out's gradient into X's gradient.
template <typename T>
class AssignGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("assign");

    const auto out_grad = this->OutputGrad("Out");
    op->SetInput("X", out_grad);

    const auto x_grad = this->InputGrad("X");
    op->SetOutput("Out", x_grad);
  }
};

112 113
// Declares {"X", "Out"} as the op's in-place pair, i.e. Out may share X's
// storage when the framework applies in-place execution.
DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});

Y
Yu Yang 已提交
114 115 116 117
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
118
namespace plat = paddle::platform;
119

120 121
DECLARE_INFER_SHAPE_FUNCTOR(assign,
                            AssignInferShapeFunctor,
122
                            PD_INFER_META(phi::UnchangedInferMeta));
123 124
REGISTER_OPERATOR(assign,
                  ops::AssignOp,
H
hong 已提交
125 126
                  ops::AssignGradMaker<paddle::framework::OpDesc>,
                  ops::AssignGradMaker<paddle::imperative::OpBase>,
127 128 129 130
                  ops::AssignOpProtoMaker,
                  ops::AssignOpInplaceInferer,
                  ops::AssignInferVarType,
                  AssignInferShapeFunctor);