diff --git a/doc/design/if_else_op.md b/doc/design/if_else_op.md new file mode 100644 index 0000000000000000000000000000000000000000..79bb543de1ceb58673c29c7e5d12556d1ddff160 --- /dev/null +++ b/doc/design/if_else_op.md @@ -0,0 +1,54 @@ +In an if_op, only inputs with condition satisfied will be run. The op could have multiple inputs and multiple outputs. +We should have the following design: + +```python +# A 1-d bool vector +cond = Var() +# create an op +if = pd.if_op() + +with if.true_block() as block: + x1 = if.input(x1) + x2 = if.input(x2) + y = pd.add(x1, x2) + y2 = pd.fc(x1) # contains (w,b) + if.output(y) + if.output(y2) + +o1, o2 = if(cond) +``` + +In an if_op, only inputs with condition satisfied will be run. +We should have the following design: +```python +# A 1-d bool vector +cond = Var() +# create an op +if = pd.if_op() + +with if.true_block() as block: + x1 = if.input(x1) + x2 = if.input(x2) + y = pd.add(x1, x2) + y2 = pd.fc(x1) # contains (w,b) + if.output(y, name="y") + if.output(y2, name="y2") + +with if.false_block() as block: + x1 = if.input(x1) + x2 = if.input(x2) + y = pd.fc(x2) + y2 = pd.softmax(x1) + if.output(y, name="y") + if.output(y2, name="y2") + +o1, o2 = if(cond) +``` + +Some questions: + 1. how to know which inputs will be selected by condition? + e.g. True_block(): + y = pd.fc(x) + # we will have x, w, b all as inputs + # but only x will be selected by cond, how can the block know? + diff --git a/paddle/operators/switch_op.cc b/paddle/operators/switch_op.cc deleted file mode 100644 index 09574a89a3572e55f7fa55bf3211c87bf039eeae..0000000000000000000000000000000000000000 --- a/paddle/operators/switch_op.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include "paddle/operators/switch_op.h" - -namespace paddle { -namespace operators { - -// namespace if_else{ - - -void CondOp::Init() override { -} - -void InferShape(const std::shared_ptr& scope) const override { - subnet_t = GetAttr("subnet_t"); - subnet_f = GetAttr("subnet_f"); - - // Create two Nets - // I use the same style as Recurrent_op, but does it create the net? - // can be called like - Variable* net_t = scope.FindVar(subnet_t); - Variable* net_f = scope.FindVar(subnet_f); - - net_op_t = scope.FindVar(net_t)->GetMutable(); - net_op_f = scope.FindVar(net_f)->GetMutable(); - - // Create two scopes - scope_t = scope.NewScope(); - scope_f = scope.NewScope(); - - // check cond of size (batch_size), type bool - net_op_t->InferShape(scope_t); - net_op_f->InferShape(scope_f); - - // check net_op_t and net_op_f of exactly same shape? -} - -void IfElseOp::Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const { - /* step 1: create two subnets and scopes, supposed done in Infershape() */ - - /* step 2: get true and false index */ - cond = Input(name.cond); - // get condition tensor - auto cond_tensor = scope.get(cond); - // tensor to cpu, whatever device it used to be in - cond_cpu.CopyFrom(cond_tensor, platform::CPUPlace()); - - size_t batch_size = cond_cpu.dims()[0]; - - // keep index of true and false to slice, clear them first before each batch - true_index.clear(); - false_index.clear(); - - // get a DDim type variable dims, check dimension - auto dims = input0.dims(); - for(int i=0; idata[i]) - true_index.push_back(i); - else - false_index.push_back(i); - } - - // turn true_index and false_index to tensors - Tensor* true_index_tensor = new Tensor(true_index); - Tensor* false_index_tensor = new Tensor(false_index); - - /* Step 3: Gather */ - { // True Scope - // Create new stuff - for (auto& input : net_op_t->inputs_) { - scope_t.NewVar(input); - if (input.type() != PARAMETER) { // gather and slice required - // Get Tensor and gather - Tensor* input_gather_ = scope_t.FindVar(input)->GetMutable(); - Tensor* input_full_ = scope.FindVar(input)->GetMutable(); - input_gather_ = Gather(input_full_, true_index_tensor); - } - } - - for (auto& output : net_op->outputs_) { - scope_t.NewVar(output); - } - - net_op_t.Run(); - } - - { // False Scope - // Create new stuff - for (auto& input : net_op_f->inputs_) { - scope_f.NewVar(input); - if (input.type() != PARAMETER) { // gather and slice required - // Get Tensor and gather - Tensor* input_gather_ = scope_f.FindVar(input)->GetMutable(); - Tensor* input_full_ = scope.FindVar(input)->GetMutable(); - input_gather_ = Gather(input_full_, false_index_tensor); - } - } - - for (auto& output : net_op->outputs_) { - scope_t.NewVar(output); - } - - net_op_f.Run(); - } - - /* Merge Output Together by scatter update */ - for (auto& ouput : outputs_) { - Tensor* output_t = scope_t->FindVar(output)->GetMutable(); - Tensor* output_f = scope_f->FindVar(output)->GetMutable(); - Tensor* output_tensor = scope->FindVar(output)->GetMutable(); - Scatter(output_t, output_tensor, true_index_tensor); - Scatter(output_f, output_tensor, false_index_tensor); - } -} - -} // namespace operators -} // namespace paddle - -REGISTER_OP(ifelse_op, - paddle::operators::IfElseOp, - paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); diff --git a/paddle/operators/switch_op.h b/paddle/operators/switch_op.h deleted file mode 100644 index f72726bce13c913ce2cabdc67e8b9bc1c23eeeca..0000000000000000000000000000000000000000 --- a/paddle/operators/switch_op.h +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "glog/logging.h" -#include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/ddim.h" -#include "paddle/operators/gather.h" - -namespace paddle { -namespace operators { - -using namespace paddle::framework; - -template -class CondOp final : public OperatorBase { -public: - void Init() override; - - /** - * InferShape must be called before Run. - */ - virtual void InferShape(const std::shared_ptr& scope) const override { - scope_t = scope.NewScope(); - scope_f = scope.NewScope(); - net_op_t->InferShape(scope_t); - net_op_f->InferShape(scope_f); - tensor_t = new Tensor(); - tensor_f = new Tensor(); - { // True branch - for (auto& input : net_op_t->Inputs()) { - auto var_name = input.second; - if (!scope_t.FindVar(var_name) { - scope_t.NewVar(var_name)->GetMutable(); - } - } - } - { // False branch - for (auto& input : net_op_f->Inputs()) { - auto var_name = input.second; - if (!scope_f.FindVar(var_name) { - scope_f.NewVar(var_name)->GetMutable(); - } - } - } - } - - virtual void Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const override { - auto* cond = context.Input("Cond"); - // Step 1: get the index - true_index.clear(); - false_index.clear(); - for(int i = 0; i < cond->dims()[0]; ++i) { - if (cond->data()[i]) - true_index.push_back(i); - else: - false_index.push_back(i); - } - framework::DDim dim_ = paddle::framework::make_ddim({0}); - dim_[0] = true_index.size(); - tensor_t->Resize(dim_); - // set value - for (int i = 0; i < dim_[0]; ++i) - tensor_t->mutable_data()[i] = true_index[i]; - dim_[0] = false_index.size(); - tensor_f->Resize(dim_); - // set value - for (int i = 0; i < dim_[0]; ++i) - tensor_f->mutable_data()[i] = false_index[i]; - - // Step 2: collect data by calling gather - { // True branch - for (auto& input : net_op_t->Inputs()) { - auto var_name = input.second; - // find Tensor - Tensor* Tensor_parent = scope.FindVar(var_name)->GetMutable(); - Tensor* Tensor_child = scope_t.FindVar(var_name)->GetMutable(); - Gather(dev_ctx.GetPlace(), tensor_parent, tensor_t, tensor_child); - } - - } - } - -private: - Scope* scope_t; - Scope* scope_f; - - // subnet_t - std::unique_ptr net_op_t; - // NetOp* net_op_t; - // subnet_f - std::unique_ptr net_op_f; - // NetOp* net_op_f; - - // T_index - vector true_index; - Tensor* tensor_t; - // F_index - vector false_index; - Tensor* tensor_f; -}; - -class CondOpMaker : public OpProtoAndCheckerMaker { -public: - IfElseOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Cond", "The condition, which is a bool vector"); - AddAttr("subnet_t", "The subnet network to be called when Cond[i] == true"); - AddAttr("subnet_f", "The subnet network to be called when Cond[i] == false"); - AddOutput("Out", "The output of if-else op"); - AddComment(R"DOC( -Sample dependent Cond Operator: -The equation is: Out[i] = subnet_t[i], if Cond[i] == true -Out[i] = subnet_t[i], if Cond[i] == false -)DOC"); - } -}; - -class CondGradientOp final : public OperatorBase { -public: - void Init() override; - - virtual void InferShape(const std::shared_ptr& scope) const override; - - virtual void Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const override; -}; - -} // namespace operators -} // namespace paddle