// switch_op.cc (committed by Zhuoyuan)
#include "paddle/operators/switch_op.h"

namespace paddle {
namespace operators {

// namespace if_else{


// Initialization hook for the operator; intentionally empty in this draft.
//
// The `override` specifier is only permitted on the declaration inside the
// class body, so it must not be repeated on this out-of-class definition.
//
// NOTE(review): Run() and REGISTER_OP in this file use the class name
// IfElseOp — confirm against switch_op.h whether CondOp is a stale name.
void CondOp::Init() {}

void InferShape(const std::shared_ptr<Scope>& scope) const override {
  subnet_t = GetAttr<std::string>("subnet_t");
  subnet_f = GetAttr<std::string>("subnet_f");

  // Create two Nets
  // I use the same style as Recurrent_op, but does it create the net?
  // can be called like
  Variable* net_t = scope.FindVar(subnet_t);
  Variable* net_f = scope.FindVar(subnet_f);

  net_op_t = scope.FindVar(net_t)->GetMutable<NetOp>();
  net_op_f = scope.FindVar(net_f)->GetMutable<NetOp>();

  // Create two scopes
  scope_t = scope.NewScope();
  scope_f = scope.NewScope();

  // check cond of size (batch_size), type bool
  net_op_t->InferShape(scope_t);
  net_op_f->InferShape(scope_f);

  // check net_op_t and net_op_f of exactly same shape?
}

// Runs the conditional operator on one batch:
//   1. reads the condition tensor and splits the batch indices into a "true"
//      list and a "false" list;
//   2. for each branch, creates the branch's inputs/outputs in the branch
//      scope, gathers non-parameter inputs by the branch's index list, and
//      runs the branch's sub-network;
//   3. scatters both branches' outputs back into the outputs of `scope`.
//
// scope:   enclosing scope holding the full-batch inputs and outputs.
// dev_ctx: device context forwarded to the sub-networks.
//
// The sub-networks (net_op_t / net_op_f) and branch scopes (scope_t /
// scope_f) are expected to have been set up by InferShape() beforehand.
//
// NOTE(review): this const member mutates cond, cond_cpu, true_index and
// false_index; presumably they are `mutable` members — confirm in
// switch_op.h.
void IfElseOp::Run(const std::shared_ptr<Scope>& scope,
                   const platform::DeviceContext& dev_ctx) const {
  /* Step 1: sub-networks and scopes are supposed to be created in
     InferShape(). */

  /* Step 2: split batch indices by condition value. */
  cond = Input(name.cond);
  // `scope` is a shared_ptr, so its members are reached through `->`.
  auto cond_tensor = scope->get<Tensor>(cond);
  // Copy the condition to the CPU so it can be read here whatever device it
  // originally lived on.
  cond_cpu.CopyFrom(cond_tensor, platform::CPUPlace());

  // The leading dimension of the condition tensor is the loop bound below.
  const int batch_size = static_cast<int>(cond_cpu.dims()[0]);

  // The index lists are reused across batches; reset them first.
  true_index.clear();
  false_index.clear();

  // NOTE(review): assumes the condition tensor stores bool values, as the
  // TODO in InferShape() suggests — confirm the element type.
  const bool* cond_data = cond_cpu.data<bool>();
  for (int i = 0; i < batch_size; ++i) {
    if (cond_data[i]) {
      true_index.push_back(i);
    } else {
      false_index.push_back(i);
    }
  }

  // Index tensors are plain locals so they are released on every exit path
  // (the previous raw `new` was never paired with a `delete`).
  Tensor true_index_tensor(true_index);
  Tensor false_index_tensor(false_index);

  /* Step 3: gather inputs and run each branch. */
  {  // True branch
    for (auto& input : net_op_t->inputs_) {
      scope_t->NewVar(input);
      if (input.type() != PARAMETER) {  // gather and slice required
        Tensor* input_gather_ = scope_t->FindVar(input)->GetMutable<Tensor>();
        Tensor* input_full_ = scope->FindVar(input)->GetMutable<Tensor>();
        // NOTE(review): this assignment only rebinds the local pointer; the
        // tensor stored in scope_t is not written here. Confirm how Gather
        // is meant to deliver its result.
        input_gather_ = Gather(input_full_, &true_index_tensor);
      }
    }

    for (auto& output : net_op_t->outputs_) {
      scope_t->NewVar(output);
    }

    // Run the sub-network inside its own branch scope.
    net_op_t->Run(scope_t, dev_ctx);
  }

  {  // False branch
    for (auto& input : net_op_f->inputs_) {
      scope_f->NewVar(input);
      if (input.type() != PARAMETER) {  // gather and slice required
        Tensor* input_gather_ = scope_f->FindVar(input)->GetMutable<Tensor>();
        Tensor* input_full_ = scope->FindVar(input)->GetMutable<Tensor>();
        // NOTE(review): same concern as in the true branch — only the local
        // pointer is reassigned.
        input_gather_ = Gather(input_full_, &false_index_tensor);
      }
    }

    // Outputs of the false sub-network belong in the false scope; the merge
    // step below looks them up in scope_f.
    for (auto& output : net_op_f->outputs_) {
      scope_f->NewVar(output);
    }

    net_op_f->Run(scope_f, dev_ctx);
  }

  /* Step 4: merge both branches' outputs by scatter update. */
  for (auto& output : outputs_) {
    Tensor* output_t = scope_t->FindVar(output)->GetMutable<Tensor>();
    Tensor* output_f = scope_f->FindVar(output)->GetMutable<Tensor>();
    Tensor* output_tensor = scope->FindVar(output)->GetMutable<Tensor>();
    Scatter(output_t, output_tensor, &true_index_tensor);
    Scatter(output_f, output_tensor, &false_index_tensor);
  }
}

}  // namespace operators
}  // namespace paddle

// REGISTER_OP invocation associating the name `ifelse_op` with the
// IfElseOp class.
// NOTE(review): the third argument is the proto/checker maker named for the
// recurrent op (RecurrentAlgorithmProtoAndCheckerMaker) — presumably a
// placeholder; confirm whether an if-else-specific maker should be passed.
REGISTER_OP(ifelse_op,
            paddle::operators::IfElseOp,
            paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);