/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"

namespace paddle {
namespace operators {

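// Base class shared by the forward and backward conditional block operators.
// It provides helpers to gather the operator's input tensors from a scope and
// to evaluate a scalar boolean condition.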
class ConditionalOp : public framework::OperatorBase {
 public:
  ConditionalOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 protected:
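  // Collects the LoDTensors bound to the input slot `in_name` from `scope`.
  // Every referenced variable must already exist in the scope.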
  std::vector<const framework::LoDTensor *> InputTensors(
      const framework::Scope &scope, const std::string &in_name) const {
    std::vector<const framework::LoDTensor *> retv;
    auto xs = Inputs(in_name);
    retv.resize(xs.size(), nullptr);
    std::transform(
        xs.begin(), xs.end(), retv.begin(),
        [&scope](const std::string &var_name) -> const framework::LoDTensor * {
          auto *var = scope.FindVar(var_name);
          PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name);
          return &var->Get<framework::LoDTensor>();
        });
    return retv;
  }

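  // Interprets the single input tensor as a scalar boolean condition. The
  // tensor must be initialized and hold exactly one bool element; a tensor
  // residing on the GPU is copied to the CPU before its value is read.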
  bool ScalarCondition(
      const std::vector<const framework::LoDTensor *> &ips) const {
    if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
      PADDLE_THROW(
          "There should be exactly one initialized input as the condition");
    }
    if (!(framework::IsType<bool>(ips[0]->type()) &&  // NOLINT
          ips[0]->numel() == 1)) {
      PADDLE_THROW(
          "The condition input's data type must be bool and its numel must "
          "be 1; actual numel is %d",
          ips[0]->numel());
    }
    bool res = false;
    if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
      framework::LoDTensor cpu_tensor;
      framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
      platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
      res = cpu_tensor.data<bool>()[0];
#endif
    } else {
      res = ips[0]->data<bool>()[0];
    }
    return res;
  }
};

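// Forward operator: evaluates the condition and, if it holds, executes the
// sub-block in a newly created child scope.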
class ConditionalBlockOp : public ConditionalOp {
 public:
  ConditionalBlockOp(const std::string &type,
                     const framework::VariableNameMap &inputs,
                     const framework::VariableNameMap &outputs,
                     const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    bool need_run;
    if (Attr<bool>("is_scalar_condition")) {
      // When is_scalar_condition is True, the conditional variable is a
      // scalar; whether the operators in the sub-block need to be executed
      // depends on the conditional variable (Cond).
      auto xs = InputTensors(scope, "Cond");
      need_run = ScalarCondition(xs);
    } else {
      // When is_scalar_condition is False, the conditional variable may be a
      // vector or tensor; whether the operators in the sub-block need to be
      // executed depends on the input variables (Input).
      auto xs = InputTensors(scope, "Input");
      need_run = std::all_of(
          xs.begin(), xs.end(),
          [](const framework::LoDTensor *t) { return t->numel() != 0; });
    }

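    // Run the sub-block in a fresh child scope so that variables created
    // inside it do not leak into the parent scope; the scope is recorded in
    // the "Scope" output so that the backward pass can reuse it.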
    if (need_run) {
      auto *scope_var = scope.FindVar(Output("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
      scopes->resize(1);
      scopes->front() = &scope.NewScope();
      auto &cur_scope = *scopes->front();

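      // The trailing `false` disables the executor's own local-scope
      // creation, since cur_scope already isolates the sub-block.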
      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
    }
  }
};

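// Declares the inputs ("Cond", "Input"), outputs ("Out", "Scope"), and
// attributes of the conditional_block operator.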
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Cond",
             "The conditional variable of this operator. If Cond is empty, the "
             "whole sub-block will not be executed.")
        .AsDuplicable();
    AddInput("Input", "The input variables of the sub-block.").AsDuplicable();
    AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
    AddOutput("Scope",
              "(std::vector<Scope*>) The step scope of the conditional block. "
              "To unify the conditional block, RNN, and while ops, the scope "
              "is typed as std::vector<Scope*>.");
    AddAttr<framework::BlockDesc *>(
        "sub_block", "The step block of conditional block operator");
    AddAttr<bool>("is_scalar_condition",
                  "The conditional variable (Cond) is used as scalar "
                  "condition.")
        .SetDefault(false);
    AddComment(R"DOC(Conditional block operator

If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar;
the operators in the sub-block are executed if Cond is True.

If `is_scalar_condition` is False, the conditional variable (Cond) is a vector
or tensor; the operators in the sub-block are executed if all of the input
variables are non-empty.

)DOC");
  }
};

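// Backward operator: re-runs the sub-block inside the scope saved by the
// forward pass, then copies the gradients computed locally back to the
// corresponding gradient variables in the parent scope.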
class ConditionalBlockGradOp : public ConditionalOp {
 public:
  ConditionalBlockGradOp(const std::string &type,
                         const framework::VariableNameMap &inputs,
                         const framework::VariableNameMap &outputs,
                         const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    bool need_run;
    if (Attr<bool>("is_scalar_condition")) {
      auto xs = this->InputTensors(scope, "Cond");
      need_run = ScalarCondition(xs);
    } else {
      auto xs = this->InputTensors(scope, "Input");
      need_run = std::all_of(
          xs.begin(), xs.end(),
          [](const framework::LoDTensor *t) { return t->numel() != 0; });
    }

    if (need_run) {
      auto *scope_var = scope.FindVar(Input("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
      framework::Scope &cur_scope = *scopes[0];

      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);

      AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"),
                                  Outputs(framework::GradVarName("Input")));

      AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"),
                                  Outputs(framework::GradVarName("Cond")));
    }
  }

 private:
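  // For every parameter that received a gradient inside the sub-block's
  // scope, copies that gradient to the matching gradient variable in the
  // parent scope by running an "assign" op. The temporary rename avoids a
  // name clash between the local gradient and the parent-scope variable it
  // is assigned to.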
  void AssignLocalGradientToGlobal(
      const platform::Place &place, const framework::Scope &cur_scope,
      const std::vector<std::string> &p_names,
      const std::vector<std::string> &pg_names) const {
    for (size_t i = 0; i < p_names.size(); ++i) {
      auto out_grad_name = pg_names[i];
      auto in_grad_name = framework::GradVarName(p_names[i]);
      auto *in_var = cur_scope.FindVar(in_grad_name);
      if (in_var == nullptr) {
        continue;
      }
      auto new_in_grad_name = cur_scope.Rename(in_grad_name);
      auto assign = framework::OpRegistry::CreateOp(
          "assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
          framework::AttributeMap{});
      assign->Run(cur_scope, place);
      cur_scope.Rename(new_in_grad_name, in_grad_name);
    }
  }
};

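// Shape inference for conditional_block_grad: the gradients of Input and
// Cond take the shapes of Input and Cond themselves.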
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInputs("Cond"));
    if (context->HasInputs("Input")) {
      PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input")));
      context->SetOutputsDim(framework::GradVarName("Input"),
                             context->GetInputsDim("Input"));
    }
    if (context->HasOutputs(framework::GradVarName("Cond"))) {
      context->SetOutputsDim(framework::GradVarName("Cond"),
                             context->GetInputsDim("Cond"));
    }
  }
};

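// Constructs the OpDesc of the conditional_block_grad op from the forward
// op, wiring through the condition, inputs, outputs, saved scopes, and the
// gradient sub-block.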
class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto grad_op = new framework::OpDesc();
    grad_op->SetType("conditional_block_grad");
    grad_op->SetInput("Cond", Input("Cond"));
    grad_op->SetInput("Input", Input("Input"));
    grad_op->SetInput("Out", Output("Out"));
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetInput("Scope", Output("Scope"));
    grad_op->SetOutput(framework::GradVarName("Cond"),
                       InputGrad("Cond", false));
    grad_op->SetOutput(framework::GradVarName("Input"),
                       InputGrad("Input", false));
    grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
    grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
                  ops::ConditionalBlockOpProtoMaker,
                  ops::ConditionalBlockGradMaker);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
                  ops::ConditionalBlockGradInferShape);