/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/controlflow/conditional_block_op.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);
// Forward declarations: the registration below only names these framework
// types (it never dereferences them), so the full headers need not be
// included here.
namespace paddle {
namespace framework {
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {

/* We will implement the op with block separately in the future.
 * The main reason is that some of the training requirements
 * in these OPS can lead to problems(such as memory leaks) during inference.
 */
class ConditionalBlockInferOp : public ConditionalOp {
 public:
  ConditionalBlockInferOp(const std::string &type,
                          const framework::VariableNameMap &inputs,
                          const framework::VariableNameMap &outputs,
                          const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}

 private:
  // Decides whether the sub-block must run, then (lazily) builds an Executor
  // for `dev_place` and executes the sub-block inside a fresh child scope
  // that is destroyed again before returning.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    bool should_run = false;
    if (Attr<bool>("is_scalar_condition")) {
      // Scalar mode: a single boolean condition tensor (Cond) gates the run.
      should_run = ScalarCondition(InputTensors(scope, "Cond"));
    } else {
      // Tensor mode: run only when every input tensor is non-empty.
      auto input_tensors = InputTensors(scope, "Input");
      should_run = !std::any_of(
          input_tensors.begin(),
          input_tensors.end(),
          [](const phi::DenseTensor *t) { return t->numel() == 0; });
    }

    if (!should_run) {
      return;
    }

    auto *scope_var = scope.FindVar(Output("Scope"));
    PADDLE_ENFORCE_NOT_NULL(
        scope_var,
        platform::errors::PreconditionNotMet(
            "Scope must be set in ConditionalBlockInferOp."));

    // Publish exactly one child scope through the "Scope" output variable.
    auto *scope_list = scope_var->GetMutable<std::vector<framework::Scope *>>();
    scope_list->resize(1);
    scope_list->front() = &scope.NewScope();
    framework::Scope &exec_scope = *scope_list->front();

    auto *sub_block = Attr<framework::BlockDesc *>("sub_block");
    VLOG(3) << "Conditional block.idx = " << sub_block->ID()
            << ", scope = " << &exec_scope;

    // Create the executor/prepared-context pair on first use (or when the
    // placement changes) and reuse it across subsequent runs.
    if (!exec_ || !platform::is_same_place(exec_->GetPlace(), dev_place)) {
      auto &program = *sub_block->Program();
      exec_.reset(new framework::Executor(dev_place));
#ifdef PADDLE_WITH_MKLDNN
      if (FLAGS_use_mkldnn) exec_->EnableMKLDNN(program);
#endif
      ctx_ = exec_->Prepare(
          program, sub_block->ID(), std::vector<std::string>(), false);
#ifdef PADDLE_WITH_MKLDNN
      if (FLAGS_use_mkldnn) {
        platform::AttachPointerHashToMKLDNNKey(exec_.get(), dev_place);
        platform::RegisterModelLayout(ctx_->ops_, dev_place);
      }
#endif
    }
    exec_->RunPreparedContext(ctx_.get(), &exec_scope, false, true, false);
    // The child scope only lives for the duration of this single run.
    scope.DeleteScope(scope_list->front());
  }

  // Cached executor and prepared context, created lazily in RunImpl;
  // mutable because RunImpl is const.
  mutable std::shared_ptr<framework::Executor> exec_{nullptr};
  mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the inference-only conditional block op. EmptyGradOpMaker is used
// for both static-graph (OpDesc) and imperative (OpBase) modes because this
// op never participates in training, so no gradient op is generated.
REGISTER_OPERATOR(
    conditional_block_infer,
    ops::ConditionalBlockInferOp,
    ops::ConditionalBlockOpProtoMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);