conditional_block_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"

namespace paddle {
namespace operators {

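// Common base class of the forward and backward conditional block
// operators; it provides helpers to fetch the tensors bound to an input
// name and to evaluate a scalar boolean condition.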
class ConditionalOp : public framework::OperatorBase {
 public:
  ConditionalOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 protected:
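  // Fetches the LoDTensors bound to the input name `in_name`, enforcing
  // that every referenced variable exists in `scope`.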
  std::vector<const framework::LoDTensor *> InputTensors(
      const framework::Scope &scope, const std::string &in_name) const {
    std::vector<const framework::LoDTensor *> retv;
    auto xs = Inputs(in_name);
    retv.resize(xs.size(), nullptr);
    std::transform(
        xs.begin(), xs.end(), retv.begin(),
        [&scope](const std::string &var_name) -> const framework::LoDTensor * {
          auto *var = scope.FindVar(var_name);
          PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name);
          return &var->Get<framework::LoDTensor>();
        });
    return retv;
  }

  bool ScalarCondition(
      const std::vector<const framework::LoDTensor *> &ips) const {
    if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
      PADDLE_THROW("should have one initialized input as condition");
    }

    PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
                       ips[0]->numel() == 1,
                   "The condition input's data type should be bool and its "
                   "numel should be 1; actual numel is %d",
                   ips[0]->numel());
    bool res = false;
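    // A condition living on the GPU must be copied to the host before its
    // single bool element can be read. If Paddle was built without CUDA,
    // the GPU branch is compiled out and `res` keeps its default of false.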
    if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
      framework::LoDTensor cpu_tensor;
      framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
      platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
      res = cpu_tensor.data<bool>()[0];
#endif
    } else {
      res = ips[0]->data<bool>()[0];
    }
    return res;
  }
};

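// Forward operator: decides from its inputs whether the sub-block should
// run, creates a child scope for it, and executes the sub-block's program
// inside that scope.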
class ConditionalBlockOp : public ConditionalOp {
 public:
  ConditionalBlockOp(const std::string &type,
                     const framework::VariableNameMap &inputs,
                     const framework::VariableNameMap &outputs,
                     const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    bool need_run;
    if (Attr<bool>("is_scalar_condition")) {
      // When is_scalar_condition is True, the conditional variable is a
      // scalar; whether the operators in the sub-block are executed
      // depends on that scalar condition (Cond).
      auto xs = InputTensors(scope, "Cond");
      need_run = ScalarCondition(xs);
    } else {
      // When is_scalar_condition is False, the conditional variable may be
      // a vector or tensor; whether the operators in the sub-block are
      // executed depends on whether all input variables (Input) are
      // non-empty.
      auto xs = InputTensors(scope, "Input");
      need_run = std::all_of(
          xs.begin(), xs.end(),
          [](const framework::LoDTensor *t) { return t->numel() != 0; });
    }

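    // If the condition holds, create a child scope, publish it through the
    // `Scope` output (the backward op will rerun the gradient block in the
    // same scope), and execute the sub-block in it.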
    if (need_run) {
      auto *scope_var = scope.FindVar(Output("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
      scopes->resize(1);
      scopes->front() = &scope.NewScope();
      auto &cur_scope = *scopes->front();

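      // Run the sub-block's program in the child scope; the trailing
      // `false` asks the executor not to create an extra local scope.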
      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
    }
  }
};

class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Cond",
             "The conditional variable of this operator. If Cond is empty, the "
             "whole sub-block will not be executed.")
        .AsDuplicable();
    AddInput("Input", "The input variables of the sub-block.").AsDuplicable();
    AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
    AddOutput("Scope",
              "(std::vector<Scope*>) The step scope of the conditional "
              "block. To unify the conditional block, RNN and while "
              "operators, the scope is typed as std::vector<Scope*>.");
    AddAttr<framework::BlockDesc *>(
        "sub_block", "The step block of the conditional block operator.");
    AddAttr<bool>("is_scalar_condition",
                  "Whether the conditional variable (Cond) is interpreted "
                  "as a scalar condition.")
        .SetDefault(false);
    AddComment(R"DOC(Conditional block operator

If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar
and the operators in the sub-block run only if Cond is True.

If `is_scalar_condition` is False, the conditional variable (Cond) may be a
vector or tensor, and the operators in the sub-block run only if all of the
input variables are non-empty.
)DOC");
  }
};

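// Backward operator: re-evaluates the same condition as the forward pass
// and, if the sub-block ran, executes its gradient block inside the scope
// saved by the forward op, then copies the local gradients out.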
class ConditionalBlockGradOp : public ConditionalOp {
 public:
  ConditionalBlockGradOp(const std::string &type,
                         const framework::VariableNameMap &inputs,
                         const framework::VariableNameMap &outputs,
                         const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    bool need_run;
    if (Attr<bool>("is_scalar_condition")) {
      auto xs = this->InputTensors(scope, "Cond");
      need_run = ScalarCondition(xs);
    } else {
      auto xs = this->InputTensors(scope, "Input");
      need_run = std::all_of(
          xs.begin(), xs.end(),
          [](const framework::LoDTensor *t) { return t->numel() != 0; });
    }

    if (need_run) {
      auto *scope_var = scope.FindVar(Input("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
      framework::Scope &cur_scope = *scopes[0];

      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");

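      // Gather the gradient variable names for both `Input` and `Cond`;
      // passing them to the executor keeps these variables alive until
      // they are copied out to the enclosing scope below.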
      const auto &ins = Inputs("Input");
      const auto &d_ins = Outputs(framework::GradVarName("Input"));
      const auto &conds = Inputs("Cond");
      const auto &d_conds = Outputs(framework::GradVarName("Cond"));

      std::vector<std::string> ins_conds_grads;
      ins_conds_grads.reserve(ins.size() + conds.size());
      for (auto &in : ins) {
        ins_conds_grads.emplace_back(framework::GradVarName(in));
      }
      for (auto &cond : conds) {
        ins_conds_grads.emplace_back(framework::GradVarName(cond));
      }

      exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
               ins_conds_grads);

      AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(),
                                  ins.size(), d_ins);

      AssignLocalGradientToGlobal(dev_place, cur_scope,
                                  ins_conds_grads.data() + ins.size(),
                                  conds.size(), d_conds);
    }
  }

 private:
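  // Copies each gradient computed in the sub-block's scope to the matching
  // variable in the enclosing scope. The local gradient may share its name
  // with the global one, so it is temporarily renamed to guarantee that
  // the assign op's output resolves to the outer-scope variable rather
  // than the local copy; the rename is undone afterwards.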
  void AssignLocalGradientToGlobal(
      const platform::Place &place, const framework::Scope &cur_scope,
      const std::string *p_grad_names, size_t p_grad_names_num,
      const std::vector<std::string> &pg_names) const {
    for (size_t i = 0; i < p_grad_names_num; ++i) {
      auto out_grad_name = pg_names[i];
      const auto &in_grad_name = p_grad_names[i];
      auto *in_var = cur_scope.FindVar(in_grad_name);
      if (in_var == nullptr) {
        continue;
      }
      auto new_in_grad_name = cur_scope.Rename(in_grad_name);
      auto assign = framework::OpRegistry::CreateOp(
          "assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
          framework::AttributeMap{});
      assign->Run(cur_scope, place);
      cur_scope.Rename(new_in_grad_name, in_grad_name);
    }
  }
};

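// Shape inference for the backward op: each gradient output takes the
// shape of the forward variable it corresponds to.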
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInputs("Cond"));
    if (context->HasInputs("Input")) {
      PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input")));
      context->SetOutputsDim(framework::GradVarName("Input"),
                             context->GetInputsDim("Input"));
    }
    if (context->HasOutputs(framework::GradVarName("Cond"))) {
      context->SetOutputsDim(framework::GradVarName("Cond"),
                             context->GetInputsDim("Cond"));
    }
  }
};

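// Builds the conditional_block_grad op description from the forward op:
// it forwards Cond, Input, Out and Scope, wires up the output gradients,
// and attaches the gradient sub-block.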
class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto grad_op = new framework::OpDesc();
    grad_op->SetType("conditional_block_grad");
    grad_op->SetInput("Cond", Input("Cond"));
    grad_op->SetInput("Input", Input("Input"));
    grad_op->SetInput("Out", Output("Out"));
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetInput("Scope", Output("Scope"));
    grad_op->SetOutput(framework::GradVarName("Cond"),
                       InputGrad("Cond", false));
    grad_op->SetOutput(framework::GradVarName("Input"),
                       InputGrad("Input", false));
    grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
    grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
                  ops::ConditionalBlockOpProtoMaker,
                  ops::ConditionalBlockGradMaker);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
                  ops::ConditionalBlockGradInferShape);