diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index cb7fed7ebd128037b6e3d76a043e705113b838e5..a3e4a2506f72e524a9d2a91fe8a849f863b5fb0d 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -13,15 +13,175 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/cond_op.h" + +#include +#include + #include "paddle/framework/op_registry.h" +#include "paddle/operators/gather.h" #include "paddle/operators/net_op.h" +#include "paddle/operators/scatter.h" namespace paddle { namespace operators { -class CondOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { +using Scope = framework::Scope; +using Variable = framework::Variable; +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +void CondOp::CreateScope(const Scope& scope) const { + auto sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE(sub_scopes_var != nullptr, ""); + auto sub_scopes = sub_scopes_var->GetMutable>(); + auto& sub_scope = scope.NewScope(); + sub_scopes->push_back(&sub_scope); +} + +void CondOp::CreateIndexTensor(const Scope& scope) const { + auto index_tensors_var = scope.FindVar("IndexTensors"); + PADDLE_ENFORCE(index_tensors_var != nullptr, ""); + auto& index_tensors = *index_tensors_var->GetMutable>(); + index_tensors.push_back(Tensor()); +} + +void CondOp::InferShape(const Scope& scope) const { + auto sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); + auto& sub_scopes = *sub_scopes_var->GetMutable>(); + + for (int i = 0; i < 2; ++i) { + // Create two sub scopes for true and false branches + // sub_scopes[0] for the true branch and sub_scopes[1] for the false + // branch + CreateScope(scope); + + // Create two tensors for true and false indices + // index_tensors[0] for the true branch and index_tensors[1] for the false + // branch + CreateIndexTensor(scope); + + PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty"); + for (auto& input : Inputs("Xs")) { + // Create a new tensor in sub-scope for input-type tensor + Variable* v = sub_scopes[i]->NewVar(input); + Tensor* sub_input = v->GetMutable(); + sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); + } + + for (auto& output : (*sub_net_op_[i]).Outputs()) { + for (auto& var_name : output.second) { + sub_scopes[i]->NewVar(var_name); + } + } + + // each net calls InferShape + sub_net_op_[i]->InferShape(*sub_scopes[i]); + } + + for (auto& output : Outputs("Outs")) { + Tensor* tensor_t_out = sub_scopes[0]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); + Tensor* tensor_f_out = sub_scopes[1]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "True output should be NULL"); + + auto* tensor_out_var = scope.FindVar(output); + PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); + Tensor* tensor_out = tensor_out_var->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should be NULL"); + // check output size should be same + PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), + "Outputs not of the same shape"); + tensor_out->Resize(tensor_t_out->dims()); + tensor_out->mutable_data(tensor_out->dims(), platform::CPUPlace()); + } +} + +void CondOp::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + auto sub_scopes = scope.FindVar("SubScopes")->Get>(); + auto index_tensors = + scope.FindVar("IndexTensors")->Get>(); + + std::string cond_name = Input("Cond"); + Variable* cond_var = scope.FindVar(cond_name); + PADDLE_ENFORCE_NOT_NULL(cond_var); + const Tensor* cond = cond_var->GetMutable(); + + // Step 1: get the true/false index at runtime + // index_[0]: vector, contains all index for cond[i] == true + // index_[1]: vector, contains all index for cond[i] == false + for (int i = 0; i < 2; ++i) index_[i].clear(); + + const int* cond_data = cond->data(); + for (int i = 0; i < cond->dims()[0]; ++i) { + if (cond_data[i]) + index_[0].push_back(i); + else + index_[1].push_back(i); + } + + // put index_[0] and index_[1] into two tensors: + // index_tensor_[0] and index_tensor_[1] + DDim dim = paddle::framework::make_ddim({0}); + for (int i = 0; i < 2; ++i) { + dim[0] = index_[i].size(); + int* tmp_ptr = + index_tensors[i].mutable_data(dim, platform::CPUPlace()); + index_tensors[i].Resize(dim); + memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); + } + + // Step 2: collect data by calling gather + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& input : Inputs("Xs")) { + // find Tensor + Variable* v = scope.FindVar(input); + PADDLE_ENFORCE_NOT_NULL(v); + Tensor* tensor_parent = v->GetMutable(); + + v = sub_scopes[i]->FindVar(input); + PADDLE_ENFORCE_NOT_NULL(v); + Tensor* tensor_child = v->GetMutable(); + + // Resize child + DDim dim = tensor_child->dims(); + dim[0] = index_[i].size(); + tensor_child->Resize(dim); + tensor_child->mutable_data(dim, platform::CPUPlace()); + + Gather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], + tensor_child); + } + } + + // Step 3: run + for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); + + // Step 4: merge output results + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& output : Outputs("Outs")) { + // find Tensor + Variable* v = scope.FindVar(output); + PADDLE_ENFORCE_NOT_NULL(v); + Tensor* tensor_parent = v->GetMutable(); + + v = sub_scopes[i]->FindVar(output); + PADDLE_ENFORCE_NOT_NULL(v); + Tensor* tensor_child = v->GetMutable(); + + ScatterUpdate(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], + tensor_parent); + } + } +} + +class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { public: - CondOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + CondOpProtoAndCheckerMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Cond", "The condition, which is a bool vector"); AddInput("Xs", "Inputs of Subnets").AsDuplicable(); @@ -41,5 +201,5 @@ Out[i] = subnet_t[i], if Cond[i] == false } // namespace operators } // namespace paddle -REGISTER_OP_WITHOUT_GRADIENT(cond_op, paddle::operators::CondOp, +REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp, paddle::operators::CondOpProtoAndCheckerMaker); diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h index b776f8ccd92a8be643199221d3b687e3ed36b6a2..27a6e9e3c30a25f16388a801bd8138534311bd11 100644 --- a/paddle/operators/cond_op.h +++ b/paddle/operators/cond_op.h @@ -19,22 +19,19 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" -#include "paddle/operators/gather.h" -#include "paddle/operators/scatter.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace operators { -using namespace paddle::framework; - -class CondOp : public OperatorBase { +class CondOp : public framework::OperatorBase { public: - CondOp(const std::string& type, const VariableNameMap& inputs, - const VariableNameMap& outputs, const AttributeMap& attrs) + CondOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { index_.resize(2); sub_net_op_.resize(2); - LOG(INFO) << "Initialization Done."; } CondOp(const CondOp& o) @@ -44,87 +41,14 @@ class CondOp : public OperatorBase { PADDLE_THROW("Not implemented"); } - void CreateScope(const Scope& scope) const { - auto sub_scopes_var = scope.FindVar("SubScopes"); - PADDLE_ENFORCE(sub_scopes_var != nullptr, ""); - auto sub_scopes = sub_scopes_var->GetMutable>(); - auto& sub_scope = scope.NewScope(); - sub_scopes->push_back(&sub_scope); - } + void CreateScope(const framework::Scope& scope) const; - void CreateIndexTensor(const Scope& scope) const { - auto index_tensors_var = scope.FindVar("IndexTensors"); - PADDLE_ENFORCE(index_tensors_var != nullptr, ""); - auto& index_tensors = - *index_tensors_var->GetMutable>(); - Tensor index_tensor; - index_tensors.push_back(&index_tensor); - } + void CreateIndexTensor(const framework::Scope& scope) const; /** * InferShape must be called before Run. */ - void InferShape(const framework::Scope& scope) const override { - auto sub_scopes_var = scope.FindVar("SubScopes"); - PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); - auto& sub_scopes = *sub_scopes_var->GetMutable>(); - // auto& index_tensors = - // *scope.FindVar("IndexTensors")->GetMutable>(); - - for (int i = 0; i < 2; ++i) { - // Create two sub scopes for true and false branches - // sub_scopes[0] for the true branch and sub_scopes[1] for the false - // branch - CreateScope(scope); - - // Create two tensors for true and false indices - // index_tensors[0] for the true branch and index_tensors[1] for the false - // branch - CreateIndexTensor(scope); - - for (auto& input : Inputs("Xs")) { - // Create a new tensor in sub-scope for input-type tensor - Variable* v = sub_scopes[i]->NewVar(input); - Tensor* sub_input = v->GetMutable(); - sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); - } - - // Inputs that do not require tailoring - /*for (auto& input : (*sub_net_op_[i]).Inputs()) { - // weights are located in the parent scope rather than sub scope - for (auto& var_name : input.second) { - if (!sub_scopes[i]->FindVar(var_name)) { - sub_scopes[i]->NewVar(var_name)->GetMutable(); - } - } - }*/ - - // Outputs - for (auto& output : (*sub_net_op_[i]).Outputs()) { - for (auto& var_name : output.second) { - sub_scopes[i]->NewVar(var_name); - } - } - - // each net calls InferShape - LOG(INFO) << "OK 3"; - sub_net_op_[i]->InferShape(*sub_scopes[i]); - LOG(INFO) << "OK 4"; - } - - for (auto& output : Outputs("Outs")) { - Tensor* tensor_t_out = - sub_scopes[0]->FindVar(output)->GetMutable(); - Tensor* tensor_f_out = - sub_scopes[1]->FindVar(output)->GetMutable(); - Tensor* tensor_out = scope.FindVar(output)->GetMutable(); - // check output size should be same - PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), - "Outputs not of the same shape"); - tensor_out->Resize(tensor_t_out->dims()); - } - LOG(INFO) << "OK 5"; - } + void InferShape(const framework::Scope& scope) const override; // Set True Block void set_truenet(std::unique_ptr net) { @@ -137,74 +61,7 @@ class CondOp : public OperatorBase { } void Run(const framework::Scope& scope, - const platform::DeviceContext& dev_ctx) const override { - auto sub_scopes = scope.FindVar("SubScopes")->Get>(); - auto index_tensors = - scope.FindVar("IndexTensors")->Get>(); - - std::string cond_name = Input("Cond"); - Variable* cond_var = scope.FindVar(cond_name); - PADDLE_ENFORCE_NOT_NULL(cond_var) - const Tensor* cond = cond_var->GetMutable(); - - // Step 1: get the true/false index at runtime - // index_[0]: vector, contains all index for cond[i] == true - // index_[1]: vector, contains all index for cond[i] == false - for (int i = 0; i < 2; ++i) index_[i].clear(); - - const bool* cond_data = cond->data(); - for (int i = 0; i < cond->dims()[0]; ++i) { - if (cond_data[i]) - index_[0].push_back(i); - else - index_[1].push_back(i); - } - // put index_[0] and index_[1] into two tensors: - // index_tensor_[0] and index_tensor_[1] - framework::DDim dim = paddle::framework::make_ddim({0}); - for (int i = 0; i < 2; ++i) { - dim[0] = index_[i].size(); - int* tmp_ptr = - index_tensors[i]->mutable_data(dim, platform::CPUPlace()); - index_tensors[i]->Resize(dim); - memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); - } - - // Step 2: collect data by calling gather - for (int i = 0; i < 2; ++i) { - // i= 0/i for True and False branches respectively - for (auto& input : Inputs("Xs")) { - // find Tensor - // Tensor* tensor_parent = scope.FindVar(input)->GetMutable(); - Variable* v = scope.FindVar(input); - Tensor* tensor_parent = v->GetMutable(); - // Tensor* tensor_child = - // sub_scope_[i].FindVar(input)->GetMutable(); - v = sub_scopes[i]->FindVar(input); - Tensor* tensor_child = v->GetMutable(); - Gather(dev_ctx.GetPlace(), tensor_parent, index_tensors[i], - tensor_child); - } - } - - // Step 3: run - for (int i = 0; i < 2; ++i) sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); - - // Step 4: merge output results - for (int i = 0; i < 2; ++i) { - // i= 0/i for True and False branches respectively - // for (auto& output : GetAttr>("sub_outputs")) { - for (auto& output : Outputs("Outs")) { - // find Tensor - Variable* v = scope.FindVar(output); - Tensor* tensor_parent = v->GetMutable(); - v = sub_scopes[i]->FindVar(output); - Tensor* tensor_child = v->GetMutable(); - ScatterUpdate(dev_ctx.GetPlace(), tensor_child, index_tensors[i], - tensor_parent); - } - } - } + const platform::DeviceContext& dev_ctx) const override; private: // sub_net_op_[0]: subnet_t @@ -216,17 +73,5 @@ class CondOp : public OperatorBase { mutable std::vector> index_; }; -/* -class CondGradientOp final : public OperatorBase { -public: - void Init() override; - - virtual void InferShape(const std::shared_ptr& scope) const -override; - - virtual void Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const override; -};*/ - } // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3eeae856fb9b368f234f278cdee60b2a9df4eb8d..34214ad2b3eb8c39724dd479f0428c5ef9039300 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/cond_op.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index bddd4d8908ae914a79aa9b7a46c63d92fcd9b97b..1469d207d474092decbf88c76c6f09cdf25f81fd 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -217,7 +217,7 @@ class __RecurrentOp__(object): class __CondOp__(object): __proto__ = None - type = 'cond_op' + type = "cond" def __init__(self): # cache recurrent_op's proto @@ -227,8 +227,8 @@ class __CondOp__(object): self.__proto__ = op_proto def __call__(self, *args, **kwargs): - if self.type not in args and 'type' not in kwargs: - kwargs['type'] = self.type + if self.type not in args and "type" not in kwargs: + kwargs["type"] = self.type # create proto create_method = OpDescCreationMethod(self.__proto__) proto = create_method(*args, **kwargs) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 6b22c0008210b492d00dee42e967ca14d0948b20..a2e3e978c72c3b0ed6a67342764ceb17abd73822 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -27,6 +27,7 @@ py_test(test_operator SRCS test_operator.py) py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) +py_test(test_cond_op SRCS test_cond_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_lookup_table SRCS test_lookup_table.py) diff --git a/python/paddle/v2/framework/tests/test_cond_op.py b/python/paddle/v2/framework/tests/test_cond_op.py index 1fe5889b7f23c7e9a8388d5d05e5fe2a94d8558f..37177ae0b2482517c4183969c8ef0670f2b3de89 100644 --- a/python/paddle/v2/framework/tests/test_cond_op.py +++ b/python/paddle/v2/framework/tests/test_cond_op.py @@ -11,15 +11,15 @@ class PySimpleCond(object): ''' def __init__(self): - array = [True] * 10 + array = [1] * 10 for i in range(1, 10, 2): - array[i] = False + array[i] = 0 self.cond = np.array(array) self.x = np.ones(shape=(10, 1)) def forward(self): - self.index_t = np.where(self.cond) - self.index_f = np.where(self.cond == False) + self.index_t = np.where(self.cond == 1) + self.index_f = np.where(self.cond == 0) y_t = self.x[self.index_t] y_f = self.x[self.index_f] y_t = y_t * 2. @@ -36,7 +36,6 @@ class PySimpleCondTest(unittest.TestCase): def test_forward(self): output = self.condnn.forward() - print 'output', output def create_tensor(scope, name, shape, np_data): @@ -67,47 +66,50 @@ class TestCondOp(unittest.TestCase): self.create_cond_op() self.create_sub_net() ctx = core.DeviceContext.create(core.CPUPlace()) - print 'running infer shape' - print self.scope.find_var("SubScopes") self.condop.infer_shape(self.scope) - print 'ok 2' self.condop.run(self.scope, ctx) - print 'ok 3' - return np.array(self.scope.find_var("Outs").get_tensor()) + return np.array(self.scope.find_var("Out").get_tensor()) def create_global_variables(self): x_np_data = self.py_cond.x - create_tensor(self.scope, "x", [10, 1], x_np_data) - cond_np_data = self.py_cond.cond - create_tensor(self.scope, "cond", [10, 1], x_np_data) + create_tensor(self.scope, "X", [10, 1], x_np_data) + cond_np_data = self.py_cond.cond.astype("int32") + create_tensor(self.scope, "cond", [10, 1], cond_np_data) self.scope.new_var("SubScopes") self.scope.new_var("IndexTensors") - self.scope.new_var("Outs") + self.scope.new_var("Out") def create_cond_op(self): self.condop = CondOp( Cond="cond", - Xs=["x"], - Outs=['Out_final'], + Xs=["X"], + Outs=["Out"], SubScopes="SubScopes", IndexTensors="IndexTensors") def create_sub_net(self): truenet = core.Net.create() - scale_op_t = Operator("scale", X='X', Y='Out', scale=2.) + scale_op_t = Operator("scale", X='X', Out='Out', scale=2.) truenet.append_op(scale_op_t) truenet.complete_add_op(True) self.condop.set_truenet(truenet) falsenet = core.Net.create() - scale_op_t = Operator("scale", X='X', Y='Out', scale=-2.) + scale_op_t = Operator("scale", X='X', Out='Out', scale=-2.) falsenet.append_op(scale_op_t) falsenet.complete_add_op(True) self.condop.set_falsenet(falsenet) def test_forward(self): print 'test cond op forward' - py_output = self.forward() + pd_output = self.forward() + py_output = self.py_cond.forward() + print 'pd_output', pd_output + print + print 'py_output', py_output + self.assertEqual(pd_output.shape, py_output.shape) + print 'test passed' + return 0 if __name__ == "__main__":