From 7683e35816f448351e4a4037b5b4c6f55e34835d Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Sun, 3 Sep 2017 23:17:43 +0000 Subject: [PATCH] cond op --- paddle/operators/cond_op.cc | 56 +++++++++++++++ paddle/operators/cond_op.h | 131 ++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 paddle/operators/cond_op.cc create mode 100644 paddle/operators/cond_op.h diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc new file mode 100644 index 00000000000..be5e0e6a5b1 --- /dev/null +++ b/paddle/operators/cond_op.cc @@ -0,0 +1,56 @@ +#include "paddle/operators/switch_op.h" + +namespace paddle { +namespace operators { + +void CondOp::InferShape(const std::shared_ptr& scope) const { + // Create two Nets + // Create two scopes + for (int i = 0; i < 2; ++i) + sub_scope.push_back(scope.NewScope()); + + for (int i = 0; i < 2; ++i) + sub_net_op_[i].InferShape(sub_scope[i]); + + for (int i = 0; i < 2; ++i) + tensor_index = new Tensor(); + + for (int i = 0; i < 2; ++i) + _index.push_back(vector()); + + for (int i = 0; i < 2; ++i) + { + // for (auto& input : net_op_[i]->Inputs()) { + for (auto& input : GetAttr>("True_inputs")) { + auto var_name = input.second; + // Create a new tensor in sub-scope for input-type tensor + sub_scope[i]->NewVar(var_name)->GetMutable(); + } + } +} + +class CondOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { +public: + CondOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Cond", "The condition, which is a bool vector"); + AddInput("Xs", "Inputs of Subnets"); + AddAttr>("sub_inputs", "Inputs of the Whole Op, net op and so forth"); + AddAttr>("sub_outputs", "True Outputs needs merge"); + AddOutput("Outs", "The output of cond op"); + + AddComment(R"DOC( +Sample dependent Cond Operator: +The equation is: Out[i] = subnet_t[i], if Cond[i] == true +Out[i] = subnet_t[i], if Cond[i] == false +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT(cond_op, + paddle::operators::CondOp, + paddle::operators::CondOpProtoAndCheckerMaker); + diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h new file mode 100644 index 00000000000..e9ae41b1919 --- /dev/null +++ b/paddle/operators/cond_op.h @@ -0,0 +1,131 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/ddim.h" +#include "paddle/operators/gather.h" +#include + +namespace paddle { +namespace operators { + +using namespace paddle::framework; + +template +class CondOp final : public OperatorBase { +public: + /** + * InferShape must be called before Run. + */ + void InferShape(const std::shared_ptr& scope) const override; + + // Set True Block + void set_truenet(std::unique_ptr net) { + sub_net_op_[0] = std::move(net); + } + + // Set False Block + void set_falsenet(std::unique_ptr net) { + sub_net_op_[1] = std::move(net); + } + + virtual void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + auto* cond = context.Input("Cond"); + // Step 1: get the true/false index at runtime + // _index[0]: vector, contains all index for cond[i] == true + // _index[1]: vector, contains all index for cond[i] == false + for(int i = 0; i < 2; ++i) + _index[i].clear(); + for(int i = 0; i < cond->dims()[0]; ++i) { + if (cond->data()[i]) + _index[0].push_back(i); + else + _index[1].push_back(i); + } + // put _index[0] and _index[1] into two tensors + // tensor_index[0] and tensor_index[1] + framework::DDim dim_ = paddle::framework::make_ddim({0}); + for(int i = 0; i < 2; ++i) { + dim_[0] = _index[i].size(); + int* tmp_ = _index[i]->mutable_data(dim_, CPUPlace()); + tensor_index[i]->Resize(dim_); + memcpy(tmp_, index_[i], dim_[0] * sizeof(int)); + } + + + // Step 2: collect data by calling gather + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& input : GetAttr>("sub_inputs")) { + auto var_name = input.second; + // find Tensor + Tensor* Tensor_parent = scope.FindVar(var_name)->GetMutable(); + Tensor* Tensor_child = sub_scope_[i].FindVar(var_name)->GetMutable(); + Gather(dev_ctx.GetPlace(), tensor_parent, tensor_index[i], tensor_child); + } + } + + // Step 3: run + for (int i = 0; i < 2; ++i) + sub_net_op_[i]->Run(sub_scope_[i], dev_ctx); + + // Step 4: merge output results + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& output : GetAttr>("sub_outputs")) { + auto var_name = output.second; + // find Tensor + Tensor* Tensor_parent = scope.FindVar(var_name)->GetMutable(); + Tensor* Tensor_child = sub_scope_[i].FindVar(var_name)->GetMutable(); + ScatterUpdate(dev_ctx.GetPlace(), tensor_child, tensor_index[i], tensor_parent); + } + } + } + +private: + // sub_scope_[0]: true scope + // sub_scope_[1]: false scope + std::vector sub_scope_; + + // sub_net_op_[0]: subnet_t + // sub_net_op_[1]: subnet_f + std::vector> sub_net_op_; + + // tensor_index[0]: True_index tensor + // tensor_index[1]: False_index; + std::vector tensor_index; + + // _index[0]: True_index; + // _index[1]: False_index; + vector > _index; +}; + +/* +class CondGradientOp final : public OperatorBase { +public: + void Init() override; + + virtual void InferShape(const std::shared_ptr& scope) const override; + + virtual void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override; +};*/ + +} // namespace operators +} // namespace paddle + -- GitLab