From b8e75c1f1a0b56993b3b1a528784e9e86d5a7277 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Tue, 12 Sep 2017 15:10:31 -0700
Subject: [PATCH] cond op

---
 paddle/operators/CMakeLists.txt                    |   2 +
 paddle/operators/cond_op.cc                        |  45 ++++
 paddle/operators/cond_op.h                         | 232 ++++++++++++++++++
 paddle/pybind/pybind.cc                            |  23 ++
 python/paddle/v2/framework/op.py                   |  22 ++
 python/paddle/v2/framework/tests/test_cond_op.py   | 114 +++++++++
 6 files changed, 438 insertions(+)
 create mode 100644 paddle/operators/cond_op.cc
 create mode 100644 paddle/operators/cond_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_cond_op.py

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f9ea25ab04..639ccd4052 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -55,12 +55,14 @@ set(DEPS_OPS
     minus_op
     mul_op
     recurrent_op
+    cond_op
     scale_op)

 op_library(identity_op DEPS scale_op)
 op_library(minus_op DEPS scale_op)
 op_library(mul_op DEPS math_function)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor operator net_op)
+op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
 op_library(scale_op DEPS net_op)

 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
new file mode 100644
index 0000000000..cb7fed7ebd
--- /dev/null
+++ b/paddle/operators/cond_op.cc
@@ -0,0 +1,45 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/cond_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+class CondOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
+ public:
+  CondOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Cond", "The condition, which is a bool vector");
+    AddInput("Xs", "Inputs of the sub-nets").AsDuplicable();
+    AddOutput("Outs", "Outputs of CondOp after merge").AsDuplicable();
+
+    AddOutput("SubScopes", "sub scopes for the true and false branches");
+    AddOutput("IndexTensors", "index tensors holding the true/false indices");
+
+    AddComment(R"DOC(
+Sample-dependent Cond Operator.
+The equation is:
+Out[i] = subnet_t[i], if Cond[i] == true
+Out[i] = subnet_f[i], if Cond[i] == false
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(cond_op, paddle::operators::CondOp,
+                             paddle::operators::CondOpProtoAndCheckerMaker);
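For intuition, the equation in the DOC string corresponds to the following numpy sketch (illustrative only: subnet_t and subnet_f here are stand-in lambdas for the true/false sub-nets, which in the real operator are networks run over sub-scopes). The unit test at the end of this patch (PySimpleCond) implements the same reference computation.

    import numpy as np

    cond = np.array([True, False, True, False])
    x = np.array([[1.], [2.], [3.], [4.]])
    subnet_t = lambda v: v * 2.   # stand-in for the true-branch sub-net
    subnet_f = lambda v: v * -2.  # stand-in for the false-branch sub-net

    out = np.empty_like(x)
    out[cond] = subnet_t(x[cond])    # Out[i] = subnet_t[i], if Cond[i] == true
    out[~cond] = subnet_f(x[~cond])  # Out[i] = subnet_f[i], if Cond[i] == false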
diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h
new file mode 100644
index 0000000000..b776f8ccd9
--- /dev/null
+++ b/paddle/operators/cond_op.h
@@ -0,0 +1,232 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "glog/logging.h"
+#include "paddle/framework/ddim.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/operators/gather.h"
+#include "paddle/operators/scatter.h"
+
+namespace paddle {
+namespace operators {
+
+using namespace paddle::framework;
+
+class CondOp : public OperatorBase {
+ public:
+  CondOp(const std::string& type, const VariableNameMap& inputs,
+         const VariableNameMap& outputs, const AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {
+    index_.resize(2);
+    sub_net_op_.resize(2);
+  }
+
+  CondOp(const CondOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    // TODO(yuyang18): Implement copy ctor well.
+    PADDLE_THROW("Not implemented");
+  }
+
+  void CreateScope(const Scope& scope) const {
+    auto sub_scopes_var = scope.FindVar("SubScopes");
+    PADDLE_ENFORCE(sub_scopes_var != nullptr,
+                   "Output(SubScopes) of CondOp should not be null.");
+    auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
+    auto& sub_scope = scope.NewScope();
+    sub_scopes->push_back(&sub_scope);
+  }
+
+  void CreateIndexTensor(const Scope& scope) const {
+    auto index_tensors_var = scope.FindVar("IndexTensors");
+    PADDLE_ENFORCE(index_tensors_var != nullptr,
+                   "Output(IndexTensors) of CondOp should not be null.");
+    auto& index_tensors =
+        *index_tensors_var->GetMutable<std::vector<Tensor>>();
+    // Store the index tensors by value: pushing the address of a stack-local
+    // Tensor would leave a dangling pointer once this method returns.
+    index_tensors.push_back(Tensor());
+  }
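+
+  // Layout of the two bookkeeping outputs after the calls above (a sketch,
+  // following the conventions used throughout this file):
+  //   SubScopes    = {scope for the true branch, scope for the false branch}
+  //   IndexTensors = {indices where Cond is true, indices where Cond is false}
+  // A sub-scope's FindVar falls back to its parent scope, so shared weights
+  // can stay in the parent while only tailored inputs/outputs are duplicated.
+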
+ */ + void InferShape(const framework::Scope& scope) const override { + auto sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); + auto& sub_scopes = *sub_scopes_var->GetMutable>(); + // auto& index_tensors = + // *scope.FindVar("IndexTensors")->GetMutable>(); + + for (int i = 0; i < 2; ++i) { + // Create two sub scopes for true and false branches + // sub_scopes[0] for the true branch and sub_scopes[1] for the false + // branch + CreateScope(scope); + + // Create two tensors for true and false indices + // index_tensors[0] for the true branch and index_tensors[1] for the false + // branch + CreateIndexTensor(scope); + + for (auto& input : Inputs("Xs")) { + // Create a new tensor in sub-scope for input-type tensor + Variable* v = sub_scopes[i]->NewVar(input); + Tensor* sub_input = v->GetMutable(); + sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); + } + + // Inputs that do not require tailoring + /*for (auto& input : (*sub_net_op_[i]).Inputs()) { + // weights are located in the parent scope rather than sub scope + for (auto& var_name : input.second) { + if (!sub_scopes[i]->FindVar(var_name)) { + sub_scopes[i]->NewVar(var_name)->GetMutable(); + } + } + }*/ + + // Outputs + for (auto& output : (*sub_net_op_[i]).Outputs()) { + for (auto& var_name : output.second) { + sub_scopes[i]->NewVar(var_name); + } + } + + // each net calls InferShape + LOG(INFO) << "OK 3"; + sub_net_op_[i]->InferShape(*sub_scopes[i]); + LOG(INFO) << "OK 4"; + } + + for (auto& output : Outputs("Outs")) { + Tensor* tensor_t_out = + sub_scopes[0]->FindVar(output)->GetMutable(); + Tensor* tensor_f_out = + sub_scopes[1]->FindVar(output)->GetMutable(); + Tensor* tensor_out = scope.FindVar(output)->GetMutable(); + // check output size should be same + PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), + "Outputs not of the same shape"); + tensor_out->Resize(tensor_t_out->dims()); + } + LOG(INFO) << "OK 5"; + } + + // Set True Block + void set_truenet(std::unique_ptr net) { + sub_net_op_[0] = std::move(net); + } + + // Set False Block + void set_falsenet(std::unique_ptr net) { + sub_net_op_[1] = std::move(net); + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + auto sub_scopes = scope.FindVar("SubScopes")->Get>(); + auto index_tensors = + scope.FindVar("IndexTensors")->Get>(); + + std::string cond_name = Input("Cond"); + Variable* cond_var = scope.FindVar(cond_name); + PADDLE_ENFORCE_NOT_NULL(cond_var) + const Tensor* cond = cond_var->GetMutable(); + + // Step 1: get the true/false index at runtime + // index_[0]: vector, contains all index for cond[i] == true + // index_[1]: vector, contains all index for cond[i] == false + for (int i = 0; i < 2; ++i) index_[i].clear(); + + const bool* cond_data = cond->data(); + for (int i = 0; i < cond->dims()[0]; ++i) { + if (cond_data[i]) + index_[0].push_back(i); + else + index_[1].push_back(i); + } + // put index_[0] and index_[1] into two tensors: + // index_tensor_[0] and index_tensor_[1] + framework::DDim dim = paddle::framework::make_ddim({0}); + for (int i = 0; i < 2; ++i) { + dim[0] = index_[i].size(); + int* tmp_ptr = + index_tensors[i]->mutable_data(dim, platform::CPUPlace()); + index_tensors[i]->Resize(dim); + memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); + } + + // Step 2: collect data by calling gather + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& input : Inputs("Xs")) { + // find Tensor + // 
+    // Step 2: collect data by calling gather
+    for (int i = 0; i < 2; ++i) {
+      // i == 0 for the true branch, i == 1 for the false branch
+      for (auto& input : Inputs("Xs")) {
+        // find the input tensor in the parent scope and its tailored
+        // counterpart in the sub-scope
+        Variable* v = scope.FindVar(input);
+        Tensor* tensor_parent = v->GetMutable<Tensor>();
+        v = sub_scopes[i]->FindVar(input);
+        Tensor* tensor_child = v->GetMutable<Tensor>();
+        Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
+                      tensor_child);
+      }
+    }
+
+    // Step 3: run each sub-net on its sub-scope
+    for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
+
+    // Step 4: merge the branch outputs by calling scatter
+    for (int i = 0; i < 2; ++i) {
+      // i == 0 for the true branch, i == 1 for the false branch
+      for (auto& output : Outputs("Outs")) {
+        Variable* v = scope.FindVar(output);
+        Tensor* tensor_parent = v->GetMutable<Tensor>();
+        v = sub_scopes[i]->FindVar(output);
+        Tensor* tensor_child = v->GetMutable<Tensor>();
+        ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child,
+                             &index_tensors[i], tensor_parent);
+      }
+    }
+  }
+
+ private:
+  // sub_net_op_[0]: subnet_t, run on the true-branch sub-scope
+  // sub_net_op_[1]: subnet_f, run on the false-branch sub-scope
+  std::vector<std::unique_ptr<OperatorBase>> sub_net_op_;
+
+  // index_[0]: indices where Cond is true
+  // index_[1]: indices where Cond is false
+  mutable std::vector<std::vector<int>> index_;
+};
+
+/*
+class CondGradientOp final : public OperatorBase {
+ public:
+  void Init() override;
+
+  virtual void InferShape(const std::shared_ptr<Scope>& scope) const override;
+
+  virtual void Run(const std::shared_ptr<Scope>& scope,
+                   const platform::DeviceContext& dev_ctx) const override;
+};*/
+
+}  // namespace operators
+}  // namespace paddle
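Steps 2 and 4 above are row-wise gather and scatter keyed by the index tensors built in Step 1. A numpy sketch of the round trip for the true branch (illustrative; the real Gather/ScatterUpdate from paddle/operators/gather.h and scatter.h operate on Tensors and a Place):

    import numpy as np

    x = np.array([[1.], [2.], [3.], [4.], [5.]])
    index_t = np.array([0, 2, 4])   # rows where Cond is true

    x_sub = x[index_t]              # Step 2: gather rows into the sub-scope
    y_sub = x_sub * 2.              # Step 3: the sub-net transforms them
    y = np.zeros_like(x)
    y[index_t] = y_sub              # Step 4: scatter results back by index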
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 16a2368aae..3eeae856fb 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -41,6 +41,7 @@ USE_OP(softmax);
 USE_OP(rowwise_add);
 USE_OP(fill_zeros_like);
 USE_NO_KERNEL_OP(recurrent);
+USE_NO_KERNEL_OP(cond_op);
 USE_OP(gaussian_random);
 USE_OP(uniform_random);
 USE_OP(lookup_table);
@@ -324,6 +325,28 @@ All parameter, weight, gradient are variables in Paddle.
            [](operators::RecurrentOp &self, const operators::NetOp &net)
                -> void { self.set_stepnet(net.Clone()); });

+  // cond_op
+  py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
+      .def_static("create",
+                  [](py::bytes protobin) -> operators::CondOp * {
+                    OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    auto cond_op = OpRegistry::CreateOp(desc);
+                    return static_cast<operators::CondOp *>(cond_op.release());
+                  })
+      .def("set_truenet",
+           [](operators::CondOp &self, const operators::NetOp &net) -> void {
+             self.set_truenet(net.Clone());
+           })
+      .def("set_falsenet",
+           [](operators::CondOp &self, const operators::NetOp &net) -> void {
+             self.set_falsenet(net.Clone());
+           });
+
   m.def("unique_integer", UniqueIntegerGenerator);

   m.def("is_compile_gpu", IsCompileGPU);
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 9e665adad2..bddd4d8908 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -215,5 +215,27 @@ class __RecurrentOp__(object):
         return core.RecurrentOp.create(proto.SerializeToString())


+class __CondOp__(object):
+    __proto__ = None
+    type = 'cond_op'
+
+    def __init__(self):
+        # cache cond_op's proto
+        if self.__proto__ is None:
+            for op_proto in get_all_op_protos():
+                if op_proto.type == self.type:
+                    self.__proto__ = op_proto
+
+    def __call__(self, *args, **kwargs):
+        if self.type not in args and 'type' not in kwargs:
+            kwargs['type'] = self.type
+        # create proto
+        create_method = OpDescCreationMethod(self.__proto__)
+        proto = create_method(*args, **kwargs)
+        # create cond_op
+        return core.CondOp.create(proto.SerializeToString())
+
+
 Operator = OperatorFactory()  # The default global factory
 RecurrentOp = __RecurrentOp__()
+CondOp = __CondOp__()
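A minimal sketch of how the factory above is meant to be used, mirroring the unit test that follows (assumes a built paddle.v2.framework runtime; the variable names match those the test creates in its scope):

    import paddle.v2.framework.core as core
    from paddle.v2.framework.op import Operator, CondOp

    cond_op = CondOp(
        Cond="cond", Xs=["X"], Outs=["Out"],
        SubScopes="SubScopes", IndexTensors="IndexTensors")

    truenet = core.Net.create()
    truenet.append_op(Operator("scale", X="X", Out="Out", scale=2.))
    truenet.complete_add_op(True)
    cond_op.set_truenet(truenet)  # set_falsenet works the same way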
diff --git a/python/paddle/v2/framework/tests/test_cond_op.py b/python/paddle/v2/framework/tests/test_cond_op.py
new file mode 100644
index 0000000000..1fe5889b7f
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_cond_op.py
@@ -0,0 +1,114 @@
+import unittest
+import numpy as np
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator, CondOp
+
+
+class PySimpleCond(object):
+    '''
+    A simple implementation of dynamic if-else based on numpy
+    '''
+
+    def __init__(self):
+        array = [True] * 10
+        for i in range(1, 10, 2):
+            array[i] = False
+        self.cond = np.array(array)
+        self.x = np.ones(shape=(10, 1)).astype("float32")
+
+    def forward(self):
+        self.index_t = np.where(self.cond)
+        self.index_f = np.where(self.cond == False)
+        y_t = self.x[self.index_t]
+        y_f = self.x[self.index_f]
+        y_t = y_t * 2.
+        y_f = y_f * (-2.)
+        output = np.zeros(shape=(10, 1))
+        output[self.index_t] = y_t
+        output[self.index_f] = y_f
+        return output
+
+
+class PySimpleCondTest(unittest.TestCase):
+    def setUp(self):
+        self.condnn = PySimpleCond()
+
+    def test_forward(self):
+        output = self.condnn.forward()
+        print 'output', output
+
+
+def create_tensor(scope, name, shape, np_data):
+    tensor = scope.new_var(name).get_tensor()
+    tensor.set_dims(shape)
+    tensor.set(np_data, core.CPUPlace())
+    return tensor
+
+
+class TestCondOp(unittest.TestCase):
+    '''
+    Test CondOp
+
+    equation:
+        cond = [True, False, True, False, ...]
+        y[index_t] = x[index_t] * 2.
+        y[index_f] = x[index_f] * -2.
+    outputs:
+        y
+    '''
+
+    def setUp(self):
+        self.py_cond = PySimpleCond()
+
+    def forward(self):
+        self.scope = core.Scope()
+        self.create_global_variables()
+        self.create_cond_op()
+        self.create_sub_net()
+        ctx = core.DeviceContext.create(core.CPUPlace())
+        self.condop.infer_shape(self.scope)
+        self.condop.run(self.scope, ctx)
+        return np.array(self.scope.find_var("Out").get_tensor())
+
+    def create_global_variables(self):
+        x_np_data = self.py_cond.x
+        create_tensor(self.scope, "X", [10, 1], x_np_data)
+        cond_np_data = self.py_cond.cond
+        create_tensor(self.scope, "cond", [10], cond_np_data)
+        self.scope.new_var("SubScopes")
+        self.scope.new_var("IndexTensors")
+        self.scope.new_var("Out")
+
+    def create_cond_op(self):
+        self.condop = CondOp(
+            Cond="cond",
+            Xs=["X"],
+            Outs=["Out"],
+            SubScopes="SubScopes",
+            IndexTensors="IndexTensors")
+
+    def create_sub_net(self):
+        truenet = core.Net.create()
+        scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
+        truenet.append_op(scale_op_t)
+        truenet.complete_add_op(True)
+        self.condop.set_truenet(truenet)
+
+        falsenet = core.Net.create()
+        scale_op_f = Operator("scale", X='X', Out='Out', scale=-2.)
+        falsenet.append_op(scale_op_f)
+        falsenet.complete_add_op(True)
+        self.condop.set_falsenet(falsenet)
+
+    def test_forward(self):
+        op_output = self.forward()
+        py_output = self.py_cond.forward()
+        # the CondOp result must match the numpy reference implementation
+        np.testing.assert_array_almost_equal(op_output, py_output)
+
+
+if __name__ == "__main__":
+    unittest.main()
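For reference, the output the test expects from the alternating condition and x = ones((10, 1)) can be computed directly in numpy (a standalone sketch of what test_forward asserts):

    import numpy as np

    cond = np.array([i % 2 == 0 for i in range(10)])
    x = np.ones((10, 1))
    expected = np.where(cond.reshape(-1, 1), x * 2., x * -2.)
    # rows alternate: [2., -2., 2., -2., ...]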