/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

namespace paddle {
namespace operators {

// Copies one fetched LoDTensor `src_item` into `dst_item` on the CPU place.
//
// - If `src_item` is uninitialized or has zero elements, nothing is copied:
//   `dst_item` is cleared and resized to shape {0}.
// - When built with PADDLE_WITH_MKLDNN and `src_item` is in kMKLDNN layout,
//   the data is first converted to a plain Paddle layout: kNCHW when
//   `fetch_var_name` is GradVarName("Filter"), otherwise the layout reported
//   by MKLDNNDeviceContext::tls().get_cur_paddle_data_layout().
// - In every case the LoD of `src_item` is copied onto `dst_item`.
//
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
static void DataCopy(const framework::LoDTensor &src_item,
                     const std::string &fetch_var_name,
                     framework::LoDTensor *dst_item) {
  const bool has_data = src_item.IsInitialized() && src_item.numel() > 0;
  if (!has_data) {
    // Not copy, if the src tensor is empty.
    dst_item->clear();
    dst_item->Resize({0});
  } else {
#ifdef PADDLE_WITH_MKLDNN
    if (src_item.layout() != framework::DataLayout::kMKLDNN) {
      paddle::framework::TensorCopySync(
          src_item, platform::CPUPlace(), dst_item);
    } else {
      // Conversion from MKL-DNN to Paddle. Filter gradients are kept in NCHW
      // because params are not a subject to paddle's data_format.
      VLOG(4) << "innerTransDataLayoutFromMKLDNN";
      const bool is_filter_grad =
          (fetch_var_name == framework::GradVarName("Filter"));
      const auto paddle_layout =
          is_filter_grad ? framework::DataLayout::kNCHW
                         : paddle::platform::MKLDNNDeviceContext::tls()
                               .get_cur_paddle_data_layout();
      framework::Tensor converted;
      framework::innerTransDataLayoutFromMKLDNN(src_item.layout(),
                                                paddle_layout,
                                                src_item,
                                                &converted,
                                                platform::CPUPlace());
      paddle::framework::TensorCopySync(
          converted, platform::CPUPlace(), dst_item);
    }
#else
    paddle::framework::TensorCopySync(src_item, platform::CPUPlace(), dst_item);
#endif
  }
  dst_item->set_lod(src_item.lod());
}

class FetchOp : public framework::OperatorBase {
Q
qijun 已提交
63
 public:
64 65
  FetchOp(const std::string &type,
          const framework::VariableNameMap &inputs,
Y
Yu Yang 已提交
66 67 68
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
Q
qijun 已提交
69

70 71 72
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
73 74 75
    OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Fetch");
    OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Fetch");

Y
Yu Yang 已提交
76
    auto fetch_var_name = Input("X");
Y
Yu Yang 已提交
77
    auto *fetch_var = scope.FindVar(fetch_var_name);
78 79 80
    PADDLE_ENFORCE_NOT_NULL(
        fetch_var,
        platform::errors::NotFound(
81 82 83 84 85 86 87
            "Input variable(%s) cannot be found in scope for operator 'Fetch'."
            "Confirm that you have used the fetch `Variable` format "
            "instead of the string literal('%s') in `fetch_list` "
            "parameter when using `executor.run` method. In other "
            "words, the format of "
            "`executor.run(fetch_list=[fetch_var])`(fetch_var is a "
            "Variable) is recommended.",
88 89
            fetch_var_name,
            fetch_var_name));
Q
qijun 已提交
90

91
    auto out_name = Output("Out");
Y
Yu Yang 已提交
92
    auto *out_var = scope.FindVar(out_name);
93 94 95 96 97
    PADDLE_ENFORCE_NOT_NULL(
        out_var,
        platform::errors::NotFound("Output variable(%s) cannot be found "
                                   "in scope for operator 'Fetch'.",
                                   out_name));
98 99 100

    int col = Attr<int>("col");
    PADDLE_ENFORCE_GE(
101 102
        col,
        0,
103 104 105 106 107
        platform::errors::InvalidArgument(
            "Expected the column index (the attribute 'col' of "
            "operator 'Fetch') of current fetching variable to be "
            "no less than 0. But received column index = %d.",
            col));
108 109 110

    VLOG(3) << "Fetch variable " << fetch_var_name << " to variable "
            << out_name << "'s " << col << " column.";
Y
Yu Yang 已提交
111

112
    auto *fetch_list = out_var->GetMutable<framework::FetchList>();
Y
Yu Yang 已提交
113

114
    if (static_cast<size_t>(col) >= fetch_list->size()) {
Y
Yu Yang 已提交
115 116 117
      fetch_list->resize(col + 1);
    }

118 119
    if (fetch_var->IsType<framework::LoDTensor>()) {
      auto &src_item = fetch_var->Get<framework::LoDTensor>();
R
Ruibiao Chen 已提交
120
      auto *dst_item = &(PADDLE_GET(framework::LoDTensor, fetch_list->at(col)));
121
      DataCopy(src_item, fetch_var_name, dst_item);
S
Steffy-zxf 已提交
122 123
    } else if (fetch_var->IsType<framework::Vocab>()) {
      auto &src_item = fetch_var->Get<framework::Vocab>();
R
Ruibiao Chen 已提交
124
      auto *dst_item = &(PADDLE_GET(framework::Vocab, fetch_list->at(col)));
S
Steffy-zxf 已提交
125
      *dst_item = src_item;
126 127 128
    } else if (fetch_var->IsType<phi::SparseCooTensor>()) {
      auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
      fetch_list->at(col) = src_item;
Q
qingqing01 已提交
129
    } else {
130 131 132 133
      auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
      framework::LoDTensorArray tmp(src_item.size());
      fetch_list->at(col) = tmp;
      auto &dst_item =
R
Ruibiao Chen 已提交
134
          PADDLE_GET(framework::LoDTensorArray, fetch_list->at(col));
135 136 137
      for (size_t i = 0; i < src_item.size(); ++i) {
        DataCopy(src_item[i], fetch_var_name, &dst_item[i]);
      }
Q
qingqing01 已提交
138
    }
Q
qijun 已提交
139 140 141
  }
};

// Declares the proto of the `fetch` operator: input "X" (the variable to be
// fetched), output "Out" (the fetch list) and attribute "col" (the index in
// "Out" that receives the copy of "X").
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // FetchOp::RunImpl accepts more than LoDTensor, so list every handled type.
    AddInput("X",
             "(LoDTensor|LoDTensorArray|Vocab|SparseCooTensor) The variable "
             "which is expected to return to users.");
    AddOutput(
        "Out",
        "(vector<LoDTensor|LoDTensorArray|unordered_map<string, int32_t>|"
        "SparseCooTensor>) A fetching list whose elements may have "
        "different dimension, shape and data type. The `col`-th element "
        "receives a copy of X.");
    AddAttr<int>("col",
                 "(int) The column index of fetching object. It must be no "
                 "less than 0.");
    AddComment(R"DOC(
Fetch Operator.

It should not be configured by users directly.

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

// Registers the `fetch` operator with its proto maker and empty gradient
// makers for both framework::OpDesc and imperative::OpBase, i.e. no gradient
// operator is generated for it.
REGISTER_OPERATOR(fetch,
                  paddle::operators::FetchOp,
                  paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
                  paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
                  paddle::operators::FetchOpInfoMaker);