diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index daee55b7f9adfffdf709ed2b5b0d957c7ca1aea4..ec7f1446cfb74842af7d0c7152bebf58619f3861 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -198,6 +198,10 @@ identity_projection .. autoclass:: paddle.v2.layer.identity_projection :noindex: +slice_projection +------------------- +.. autoclass:: paddle.v2.layer.slice_projection + :noindex: table_projection ---------------- diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py index 85cb399590f7a5e7e73285ca87c49ea5f24afb32..572a61e4ccaa9ef3d03a60d916e80eab907c6d88 100644 --- a/go/pserver/client/c/test/test_train.py +++ b/go/pserver/client/c/test/test_train.py @@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing import paddle.v2.master as master import os import cPickle as pickle +from paddle.v2.reader.creator import cloud_reader etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") -etcd_endpoint = "http://" + etcd_ip + ":2379" -print "connecting to master, etcd endpoints: ", etcd_endpoint -master_client = master.client(etcd_endpoint, 5, 64) - - -def cloud_reader(): - global master_client - master_client.set_dataset( - ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30) - while 1: - r, e = master_client.next_record() - if not r: - if e != -2: # other errors - print "get record error:", e - break - yield pickle.loads(r) +etcd_endpoints = "http://" + etcd_ip + ":2379" +print "etcd endpoints: ", etcd_endpoints def main(): @@ -49,7 +36,7 @@ def main(): parameters=parameters, update_equation=optimizer, is_local=False, - pserver_spec=etcd_endpoint, + pserver_spec=etcd_endpoints, use_etcd=True) # event_handler to print training and testing info @@ -75,7 +62,11 @@ def main(): trainer.train( reader=paddle.batch( paddle.reader.shuffle( - cloud_reader, buf_size=500), batch_size=2), + cloud_reader( + ["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"], + etcd_endpoints), + buf_size=500), + batch_size=2), feeding={'x': 0, 'y': 1}, event_handler=event_handler, diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9d172640493d3c884bb5c142b37829eafe9ee928..12a3a00bba35d476fca9c9fb47ac20b87e6f53f2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -32,4 +32,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) cc_library(net SRCS net.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) +cc_test(net_op_test SRCS net_op_test.cc DEPS net) + +cc_library(backward SRCS backward.cc DEPS net) +cc_test(backward_test SRCS backward_test.cc DEPS backward) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc new file mode 100644 index 0000000000000000000000000000000000000000..0da11b91a7fe4a98e0832f70095c3200956ff001 --- /dev/null +++ b/paddle/framework/backward.cc @@ -0,0 +1,178 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/backward.h" +#include +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace framework { + +static bool AllInSet(const std::vector& names, + const std::string& suffix, + const std::unordered_set& set) { + for (auto& name : names) { + if (set.find(name + suffix) == set.end()) { + return false; + } + } + return true; +} + +static std::shared_ptr NOP() { + auto net_op = std::make_shared(); + net_op->type_ = "@NOP@"; + net_op->CompleteAddOp(); + return net_op; +} + +// Get backward operator from a forward operator, recursively implementation. +// +// no_grad_names the gradient variable names without gradient calculating. +// +// uniq_id is a unique index used inside recursively calling BackwardRecursive. +// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through +// recursive calling. +// +// returns The backward operator. For simple situation, it is a simple +// operator. For complex situation, it is a NetOp. +// +// See Backward.h for details +static std::shared_ptr BackwardRecursive( + const OperatorBase& forwardOp, + std::unordered_set& no_grad_names, size_t& uniq_id); +std::shared_ptr BackwardRecursive( + const OperatorBase& forwardOp, + std::unordered_set& no_grad_names, size_t& uniq_id) { + // If all input gradients of forwarding operator do not need to calculate, + // just return an NOP. Not return null ptr because NOP does not take + // too much time for calculation, but it is useful for simplifying logic. + if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), + no_grad_names)) { + return NOP(); + } + + // All output gradients of forwarding operator do not need to calculate. Then + // all input gradients cannot be computed at all, and we put them into + // `no_grad_names` set. Return an NOP. + if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(), + no_grad_names)) { + for (auto& name : forwardOp.inputs_) { + // Mark all input is not need + no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); + } + return NOP(); + } + + // Returned gradient network + auto net = std::make_shared(); + + if (forwardOp.IsNetOp()) { + // Because forwardOp is a net op, it can static_cast. + auto& forwardNet = static_cast(forwardOp); + + // Map from output gradient variable name to operator's indices in backward + // net. That operator generates that variable. + std::unordered_map> dup_output_ops; + + size_t local_op_id = 0; + // reversely travel forwardNet + for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); + ++it, ++local_op_id) { + auto fwd = *it; + auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); + net->AddOp(bwd); + for (auto& out : bwd->outputs_) { + dup_output_ops[out].emplace_back(local_op_id); + } + } + // Get unique ID for this method. + auto uid = uniq_id++; + // TODO(dzh): more comment + using Pos = std::pair>; + std::list insert_position; + for (auto& dup_output_op : dup_output_ops) { + const std::string& name = dup_output_op.first; + auto& dup_op = dup_output_op.second; + if (dup_op.size() == 1) continue; + std::vector dup_outputs; + + for (size_t i = 0; i < dup_op.size(); ++i) { + auto op_offset = dup_op[i]; + dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + + std::to_string(i)); + net->ops_[op_offset]->Rename(name, dup_outputs.back()); + } + insert_position.push_back( + {dup_op.back(), + OpRegistry::CreateOp( + "add", {dup_outputs}, {name}, + {{"input_format", + std::vector{0, static_cast(dup_outputs.size())}}})}); + } + + insert_position.sort( + [](const Pos& l, const Pos& r) { return l.first > r.first; }); + + for (auto& pos : insert_position) { + net->InsertOp(pos.first + 1, pos.second); + } + + } else { + std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); + for (std::string& grad_input : grad_op->inputs_) { + if (no_grad_names.count(grad_input)) { + std::string prefix = grad_input.substr( + 0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size()); + grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX(); + + // If part of input gradient of that operator is not calculated, fill + // zero variables to that input gradient. + net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix}, + {grad_input}, {})); + } + } + + for (std::string& grad_output : grad_op->outputs_) { + if (no_grad_names.count(grad_output)) { + grad_output = OperatorBase::EMPTY_VAR_NAME(); + } + } + + if (net->ops_.empty()) { // Current no aux op is added to network + return grad_op; + } + net->AddOp(grad_op); + } + net->type_ = "@GENERATED_BACKWARD@"; + net->CompleteAddOp(); + return net; +} + +// See header for comments +std::shared_ptr Backward( + const OperatorBase& forwardOp, + const std::unordered_set& no_grad_vars) { + std::unordered_set no_grad_names; + no_grad_names.reserve(no_grad_vars.size()); + + for (auto& name : no_grad_vars) { + no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); + } + size_t uid = 0; + return BackwardRecursive(forwardOp, no_grad_names, uid); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h new file mode 100644 index 0000000000000000000000000000000000000000..c181919dc165cf0b49362f85e22ceb4131bbd387 --- /dev/null +++ b/paddle/framework/backward.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "operator.h" +namespace paddle { +namespace framework { + +// Create the backward operator from a forward operator. +// TODO(yuyang18): Add more API reference comment. +extern std::shared_ptr Backward( + const OperatorBase& forwardOp, + const std::unordered_set& no_grad_vars); +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md new file mode 100644 index 0000000000000000000000000000000000000000..74c001b06a9e7b2279abf998604f2acf1b1168e4 --- /dev/null +++ b/paddle/framework/backward.md @@ -0,0 +1,38 @@ +## Operator/expression 's Backward + +### Motivation + +In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation lineage, the operator/ expression's Backward feature will generate the backward pass respect to forward pass. + +### Implement : gradient operator registry + +| | forward operator | backward operator | +| ---------------------- | ---------------- | -------------------------------- | +| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | +| **Operator::outputs_** | Outputs | InputGradients | + +Inputs/Outputs means the input/output of the operator, InputGradients/OutputGradients is the gradient respect to forward opeartor. Forward operator and Backward operator are isomorphic, save their corresponding needs into member attribute. + +We use a global hash map record the gradient operators available, follow the philosophy of minimum core, make operator pluggable unit. Each gradient is an operator and it needs to regist itself. + +grad_op_builder(fengjiayi) + +### Implement : Backward network + +given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`. + +1. bla bla bla (yuyang) + +2. NetOp + + when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively and ensure them done. During the process, we need to collect the `OutputGradients` name. + + We share variable in the same scope, as a result, duplicate operator `OutputGradients` will overwirte then duplicate variable. + + ![./images/duplicate_op]() + + Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator instead. + +![./images/duplicate_op2]() + +​ Then collect the sub graph OutputGradients/InputGradients as the NetOp's and return it. diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b095c2c3d5dbf21b5ea70e17475a4aaad9b1db44 --- /dev/null +++ b/paddle/framework/backward_test.cc @@ -0,0 +1,389 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/backward.h" + +#include +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace framework { + +class EmptyOp : public OperatorBase { + public: + void InferShape(const Scope &scope) const override {} + void Run(const Scope &scope, + const platform::DeviceContext &dev_ctx) const override {} +}; + +class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { + public: + RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input X of Add").IgnoreGradient(); + AddInput("b", "Bias of Add").IgnoreGradient(); + AddOutput("Out", "Out of Add").IgnoreGradient(); + AddComment("Add Op"); + } +}; + +class MulOpMaker : public OpProtoAndCheckerMaker { + public: + MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("A", "A"); + AddInput("B", "B"); + AddOutput("Out", "Out"); + AddComment("Mul"); + } +}; + +class SigmoidOpMaker : public OpProtoAndCheckerMaker { + public: + SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "X"); + AddOutput("Y", "Y"); + AddComment("Sigmoid"); + } +}; + +class NoGradOpMaker : public OpProtoAndCheckerMaker { + public: + NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "X input"); + AddOutput("Y", "Y output"); + AddComment("NoGradOp, same input output. no Grad"); + } +}; + +class FcOp : public NetOp { + public: + void Init() override { + AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, + {Output("mul_result")}, {})); + auto b_name = Input("b"); + std::string before_act = "mul_result"; + if (b_name != EMPTY_VAR_NAME()) { + AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name}, + {Output("add_result")}, {})); + before_act = "add_result"; + } else { + auto out_varname = Output("add_result"); + if (out_varname != EMPTY_VAR_NAME()) { + this->Rename(out_varname, EMPTY_VAR_NAME()); + } + } + + AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")}, + {})); + CompleteAddOp(false); + } +}; + +class FcOpMaker : public OpProtoAndCheckerMaker { + public: + FcOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "x"); + AddInput("W", "w"); + AddInput("b", "b"); + AddOutput("mul_result", "").SetTemporary(); + AddOutput("add_result", "").SetTemporary(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class ManyOutputOpMaker : public OpProtoAndCheckerMaker { + public: + ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("x", "x"); + AddOutput("y", "y"); + AddOutput("z", "z"); + AddComment(""); + } +}; + +class FillZeroOpMaker : public OpProtoAndCheckerMaker { + public: + FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("x", "x"); + AddOutput("out", "out"); + AddComment(""); + } +}; + +class AddOpMaker : public OpProtoAndCheckerMaker { + public: + AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "x").SetMultiple(); + AddOutput("Y", "y"); + AddComment(""); + } +}; +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; +using EnforceNotMet = paddle::platform::EnforceNotMet; +REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); +REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); +REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); +REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); +REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); +REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); +REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); +REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); +REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); +REGISTER_OP(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); +REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); + +TEST(Backward, simple_op_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); + auto gop = f::OpRegistry::CreateGradOp(*fwd); + ASSERT_EQ(1UL, gop->inputs_.size()); + ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); + ASSERT_EQ("rowwise_add_grad", gop->type_); + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); + ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]); + + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), + gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, simple_op_not_need_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); + auto gop = f::Backward(*fwd, {"X"}); + ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), + "X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + gop->outputs_.end()); + + auto no_input_gop = f::Backward(*fwd, {"X", "b"}); + ASSERT_NE(no_input_gop, nullptr); + ASSERT_TRUE(no_input_gop->IsNetOp()); + ASSERT_EQ(0UL, std::static_pointer_cast(no_input_gop)->ops_.size()); +} + +TEST(Backward, net_fc_backward_normal) { + std::shared_ptr fwd = f::OpRegistry::CreateOp( + "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(3UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_add = *net->ops_[1]; + ASSERT_EQ("rowwise_add_grad", d_add.type_); + + f::OperatorBase &d_mul = *net->ops_[2]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, net_fc_backward_not_have_b) { + std::shared_ptr fwd = f::OpRegistry::CreateOp( + "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, + {"mul_result", "add_result", "tmp"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(2UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_mul = *net->ops_[1]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, net_input_of_network_not_need_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, + {"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, + {"mul_tmp_1", "add_tmp_1", "hidden1"}, {})); + net.CompleteAddOp(); + auto bwd = Backward(net, {"X"}); // X@GRAD is not need. + ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + + std::unordered_set all_output = std::unordered_set( + bwd_net->outputs_.begin(), bwd_net->outputs_.end()); + all_output.erase(f::OperatorBase::EMPTY_VAR_NAME()); + + for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { + ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + } + + // Not Generated X + ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + + ASSERT_EQ(2UL, bwd_net->ops_.size()); + ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); + auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); + ASSERT_EQ(3UL, first_fc_grad->ops_.size()); + ASSERT_EQ( + f::OperatorBase::EMPTY_VAR_NAME(), + first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, net_shared_weight) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); + net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); + net.CompleteAddOp(); + + auto bwd = f::Backward(net, {}); + ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + ASSERT_EQ(3UL, bwd_net->ops_.size()); + ASSERT_EQ("add", bwd_net->ops_[2]->type_); +} + +TEST(Backward, op_register_grad_not_for_network) { + auto fwd = f::OpRegistry::CreateOp( + "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"}, + {{"temporary_index", std::vector{0, 1}}}); + + ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); +} + +TEST(Backward, op_all_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"X", "b"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, op_all_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"Out"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, op_part_of_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); + auto backward = f::Backward(*fwd, {"Z"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(net->ops_.size(), 2UL); + + auto &fill_zero = *net->ops_[0]; + ASSERT_EQ("fill_zeros_like", fill_zero.type_); + ASSERT_EQ(1UL, fill_zero.inputs_.size()); + ASSERT_EQ("Z", fill_zero.inputs_[0]); + ASSERT_EQ(1UL, fill_zero.outputs_.size()); + ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]); + + auto &d_many_out = *net->ops_[1]; + ASSERT_EQ("many_output_op_grad", d_many_out.type_); + ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), + d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(), + d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), + d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, op_part_of_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); + auto backward = f::Backward(*fwd, {"a"}); + auto &grad_mul = *backward; + ASSERT_EQ(grad_mul.type_, "mul_grad"); + ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.outputs_.size(), 2UL); + ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("A"), "a"); + ASSERT_EQ(grad_mul.Input("B"), "b"); + ASSERT_EQ(grad_mul.Input("Out"), "out"); +} + +TEST(Backward, linear_net_intermediate_variable_has_no_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, + {"mul_out1", "add_out1", "out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, + {"mul_out2", "tmp_out2", "out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, + {"mul_out3", "tmp_out3", "out3"}, {})); + net.CompleteAddOp(); + auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); + ASSERT_TRUE(backward->IsNetOp()); + auto bwd_net = static_cast(backward.get()); + ASSERT_EQ(bwd_net->ops_.size(), 3UL); + auto &grad_fc = *bwd_net->ops_[0]; + EXPECT_EQ(grad_fc.inputs_.size(), + 3UL /* external input number */ + + 1UL /* external output number*/ + + 1UL /* number of gradient of external output*/ + - 1UL /*ignoreGradient varable number*/ + + 2U /* internal variable number*/); + EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ + + 2UL /* input number of rowwise_add */ + + 1UL /* input number of sigmod */); + EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); + + /* + EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + + EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ(grad_fc.Input("X"), "out2"); + EXPECT_EQ(grad_fc.Input("W"), "w3"); + EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3"); + EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3"); + EXPECT_EQ(grad_fc.Input("Out"), "out3"); + */ +} diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 6235be75f27dadb65de663ff1b3caf26a649f6cb..dd686cc78246f06cdc3ec7d013086863d7e8fac0 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,7 +20,7 @@ namespace framework { OperatorBase* GradOpBuilder::Build() { BuildOpInOutArgList(); - std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_); + std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); grad_op->type_ = grad_op_type; CompleteGradOp(grad_op); @@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, } void GradOpBuilder::BuildOpInOutArgList() { - const OpProto& op_proto = OpRegistry::protos().at(op_->type_); - const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); + const OpProto& op_proto = OpRegistry::protos().at(op_.type_); + const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_)); const std::vector& in_format = - op_->attrs_.count("input_format") - ? op_->GetAttr>("input_format") + op_.attrs_.count("input_format") + ? op_.GetAttr>("input_format") : std::vector(); const std::vector& out_format = - op_->attrs_.count("output_format") - ? op_->GetAttr>("output_format") + op_.attrs_.count("output_format") + ? op_.GetAttr>("output_format") : std::vector(); for (const auto& var : op_proto.inputs()) { arg_list_.emplace_back( @@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } (*varmap)[var_name] = idx++; size_t pre_sz = in_out.size(); - auto base_it = - arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin(); + auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin(); std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::back_inserter(in_out)); if (is_grad) { @@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { - grad_op->attrs_ = op_->attrs_; + grad_op->attrs_ = op_.attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); VarIndexMap* grad_varmap = new VarIndexMap(); diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h index 2ecf39479b4f4a51f89cd500caf851897df0e599..cc7a76f3726e00a08fbe06bca4c9b9f5bad466b4 100644 --- a/paddle/framework/grad_op_builder.h +++ b/paddle/framework/grad_op_builder.h @@ -29,7 +29,7 @@ class GradOpBuilder { using VarIndexMap = std::unordered_map; public: - GradOpBuilder(const OperatorBase* op) : op_(op) {} + GradOpBuilder(const OperatorBase& op) : op_(op) {} OperatorBase* Build(); private: @@ -40,7 +40,7 @@ class GradOpBuilder { std::vector& format, VarIndexMap* varmap, int& idx, bool is_grad) const; void CompleteGradOp(OperatorBase* grad_op) const; - const OperatorBase* op_; + const OperatorBase& op_; std::vector> arg_list_; }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 288a7841cd7c9212d8fa230e38d49dfc26e76256..e9cf3b9798db2cbfb8d26259ae9a6741fbae8278 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -11,7 +11,7 @@ namespace framework { TEST(GradOpBuilder, AddTwo) { std::shared_ptr add_op( OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(add_op); + std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(*add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); diff --git a/paddle/framework/images/duplicate_op.graffle b/paddle/framework/images/duplicate_op.graffle new file mode 100644 index 0000000000000000000000000000000000000000..5979f792e252f028a615729215529c2be42d9165 Binary files /dev/null and b/paddle/framework/images/duplicate_op.graffle differ diff --git a/paddle/framework/images/duplicate_op.png b/paddle/framework/images/duplicate_op.png new file mode 100644 index 0000000000000000000000000000000000000000..f299c5d37f260a1bb0daec886f0a4ee1c1f31c92 Binary files /dev/null and b/paddle/framework/images/duplicate_op.png differ diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle new file mode 100644 index 0000000000000000000000000000000000000000..2b658085d6a55d368c320051ba7f94ec2900f13c Binary files /dev/null and b/paddle/framework/images/duplicate_op2.graffle differ diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png new file mode 100644 index 0000000000000000000000000000000000000000..c5588015d1450fd8c1bda3580680d884494868bb Binary files /dev/null and b/paddle/framework/images/duplicate_op2.png differ diff --git a/paddle/framework/net.h b/paddle/framework/net.h index fc98080b178bbe2af794e7a9387802f8ed764712..acf1a69da9fd8adce1bd89367c882eade052e725 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -68,9 +68,18 @@ class NetOp : public OperatorBase { */ void AddOp(const std::shared_ptr& op) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); ops_.push_back(op); } + void InsertOp(size_t pos, const std::shared_ptr& op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot InsertOp when this network is sealed"); + PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); + PADDLE_ENFORCE(pos <= ops_.size(), "Out of range"); + ops_.insert(ops_.begin() + pos, op); + } + void CompleteAddOp(bool calculate = true); std::string DebugString() const override; diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 3392b033c50da51f1912437e0c45bef62c0d11d4..f32e456e5d142bf8203f9ec03e8059772c4f5c99 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -3,11 +3,6 @@ #include #include -USE_OP(add_two); -USE_OP(mul); -USE_OP(sigmoid); -USE_OP(softmax); - namespace paddle { namespace framework { @@ -25,6 +20,13 @@ class TestOp : public OperatorBase { } }; +class EmptyOp : public OperatorBase { + public: + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + template void AssertSameVectorWithoutOrder(const std::vector& expected, const std::vector& actual) { @@ -71,20 +73,17 @@ TEST(OpKernel, all) { ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -//! TODO(yuyang18): Refine Backward Op. -// TEST(AddBackwardOp, TestGradOp) { -// auto net = std::make_shared(); -// ASSERT_NE(net, nullptr); -// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); -// net->AddOp( -// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); -// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, -// {})); -// auto grad_ops = AddBackwardOp(net); -// for (auto& op : grad_ops->ops_) { -// op->DebugString(); -// } -//} +TEST(Net, insert_op) { + NetOp net; + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net.AddOp(op1); + net.InsertOp(0, op1); + ASSERT_EQ(2UL, net.ops_.size()); + net.InsertOp(2, op1); + ASSERT_EQ(3UL, net.ops_.size()); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 384f0f631dd9b9a4dd7c0c628340afe668bc248f..f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { } protected: - void AddInput(const std::string& name, const std::string& comment, - bool multiple = false, bool ignore_gradient = false) { + struct VariableBuilder { + VarProto* var_; + std::function on_multiple_; + std::function on_temporary_; + + VariableBuilder& SetMultiple() { + var_->set_multiple(true); + on_multiple_(); + return *this; + } + + VariableBuilder& SetTemporary() { + PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); + var_->set_temporary(true); + on_temporary_(); + return *this; + } + + VariableBuilder& IgnoreGradient() { + var_->set_ignore_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, + const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; - input->set_ignore_gradient(ignore_gradient); - input->set_multiple(multiple); - if (multiple) { - SetHasMultipleInput(); - } - } - - void AddInputs(const std::string& name, const std::string& comment, - bool ignore_gradient = false) { - AddInput(name, comment, true, ignore_gradient); + return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, + nullptr}; } - void AddOutput(const std::string& name, const std::string& comment, - bool temporary = false, bool multiple = false, - bool ignore_gradient = false) { + VariableBuilder AddOutput(const std::string& name, + const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; - output->set_ignore_gradient(ignore_gradient); - output->set_multiple(multiple); - if (multiple) { - SetHasMultipleOutput(); - } - output->set_temporary(temporary); - if (temporary) { - SetHasTemporaryOutput(); - } - } - - void AddOutputs(const std::string& name, const std::string& comment, - bool temporary = false, bool ignore_gradient = false) { - AddOutput(name, comment, temporary, true, ignore_gradient); + return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, + [=] { this->SetHasTemporaryOutput(); }}; } template @@ -300,9 +303,10 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static std::shared_ptr CreateGradOp( - std::shared_ptr op) { - GradOpBuilder builder(op.get()); + static std::shared_ptr CreateGradOp(const OperatorBase& op) { + PADDLE_ENFORCE(!op.IsNetOp(), + "Use framework::Backward to get backward ops"); + GradOpBuilder builder(op); std::shared_ptr grad_op(builder.Build()); grad_op->Init(); return grad_op; diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index d8ae3d07227f1dfab05a98adfc24072aca41ec4d..9894928a7aa19bc6c7ad8b230562fb9a681cfebd 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("input", "input of cosine op"); - AddOutput("output", "output of cosine op", - /*temporary*/ true); + AddInput("input", "input of cosine op").SetMultiple(); + AddOutput("output", "output of cosine op").SetTemporary(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9bf60b7b11636df97031d111031ef782a173006b..cfe9cba308556475ef64b45e7178dfc418761598 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -52,7 +52,8 @@ std::vector OperatorBase::Inputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), "Input Out Of Range"); return std::vector{ @@ -78,7 +79,8 @@ std::vector OperatorBase::Outputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), @@ -105,5 +107,11 @@ std::string OperatorBase::DebugString() const { return ss.str(); } +void OperatorBase::Rename(const std::string& old_name, + const std::string& new_name) { + std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); + std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index b3a66de7530b942208fb5e8dd5f2cf46fb7f3530..0832a663dd01fe2921366d70599bc867e73af47c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -54,6 +55,9 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + /// Variables with this suffix are supposed to be filled up with zeros. + static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } + virtual ~OperatorBase() {} template @@ -79,8 +83,12 @@ class OperatorBase { virtual bool IsNetOp() const { return false; } + /// rename inputs outputs name + void Rename(const std::string& old_name, const std::string& new_name); + //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; + //! Get a input which has multiple variables. //! TODO add a vector_view to prevent memory copy. std::vector Inputs(const std::string& name) const; @@ -92,7 +100,13 @@ class OperatorBase { public: std::string type_; + // NOTE: in case of OpGrad, inputs_ contains: + // I (Inputs) + // O (Outputs) + // OG (Output Gradients) std::vector inputs_; + // NOTE: in case of OpGrad, outputs_ contains + // IG (Inputs Gradients) std::vector outputs_; AttributeMap attrs_; // store the arguments' offset described in op_desc. @@ -147,22 +161,30 @@ class OperatorContext { template const T* Input(const size_t index) const { - return &(InputVar(index)->Get()); + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); } template T* Output(const size_t index) const { - return OutputVar(index)->GetMutable(); + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); } template const T* Input(const std::string& name) const { - return &(InputVar(name)->Get()); + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); } template T* Output(const std::string& name) const { - return OutputVar(name)->GetMutable(); + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); } template @@ -171,8 +193,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return &scope_.FindVar(name)->Get(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) should not be nullptr", + name, sub_name); + return &var->Get(); }); return res; } @@ -183,8 +209,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return scope_.FindVar(name)->GetMutable(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); }); return res; } diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 401bce0fbbc953afedf5f352f9d3ae8bedab7cbe..6a6a802b7da05c37a317540030836baa28a89cd7 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("xs", "inputs of test op"); + AddInput("xs", "inputs of test op").SetMultiple(); AddInput("k", "input of test op"); - AddOutputs("ys", "outputs of test op"); + AddOutput("ys", "outputs of test op").SetMultiple(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); diff --git a/paddle/gserver/layers/SliceProjection.cpp b/paddle/gserver/layers/SliceProjection.cpp new file mode 100644 index 0000000000000000000000000000000000000000..267dd6154b1b21cc9b936384d438a2c3bdf0c246 --- /dev/null +++ b/paddle/gserver/layers/SliceProjection.cpp @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Projection.h" + +namespace paddle { + +/** + * SliceProjection can slice the input value into multiple parts, + * and then select some of them to merge into a new output. + * + * First, calculate the slices that need to be merged into the output. + * slices = input.slices().for_output() + * + * Second, merge each slice into the output. + * for(auto slice: slices) { + * out.addAtOffset(slice, offset); + * } + * + * Input slices as output: s0, s1, ...: + * ----------------------- + * |///| |//////| | + * |/s0| |//s1//| | + * |///| |//////| | + * ----------------------- + * Output, merge s0, s1, ... into one output: + * ---------------- + * |///|//////| | + * |/s0|//s1//|...| + * |///|//////| | + * ---------------- + * + * The config file api is slice_projection. + */ +class SliceProjection : public Projection { +public: + SliceProjection(const ProjectionConfig& config, + const ParameterPtr& parameter, + bool useGpu); + virtual void forward(); + virtual void backward(const UpdateCallback& callback); + +protected: + std::vector> slices_; +}; + +REGISTER_PROJECTION(slice, SliceProjection); + +/** + * Constructed function. + * @note SliceProjection should not have any parameter. + */ +SliceProjection::SliceProjection(const ProjectionConfig& config, + const ParameterPtr& parameter, + bool useGpu) + : Projection(config, parameter, useGpu) { + CHECK(!parameter) << "'slice' projection should not have any parameter"; + + slices_.reserve(config.slices_size()); + for (const auto& slice : config.slices()) { + slices_.push_back(std::make_pair(slice.start(), slice.end())); + } +} + +void SliceProjection::forward() { + size_t offset = 0; + for (auto& slice : slices_) { + auto slice_out = in_->value->subColMatrix(slice.first, slice.second); + out_->value->addAtOffset(*slice_out, offset); + offset += slice_out->getWidth(); + } +} + +void SliceProjection::backward(const UpdateCallback& callback) { + if (in_->grad) { + size_t offset = 0; + for (auto& slice : slices_) { + auto slice_out = in_->grad->subColMatrix(slice.first, slice.second); + slice_out->addAtOffset(*out_->grad, offset); + offset += slice_out->getWidth(); + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/concat_slice_a.conf b/paddle/gserver/tests/concat_slice_a.conf new file mode 100644 index 0000000000000000000000000000000000000000..dccf911089e16f4f97b1470ee39d192d4557d4bd --- /dev/null +++ b/paddle/gserver/tests/concat_slice_a.conf @@ -0,0 +1,41 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) + +data = data_layer(name ="input", size=8*16*16) + +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +proj1 = slice_projection(input=conv1, slices=[(0, 4), (4, 12)]) + +proj2 = slice_projection(input=conv2, slices=[(1, 5), (5, 15)]) + +concat = concat_layer(input=[proj1, proj2]) + +outputs(concat) + diff --git a/paddle/gserver/tests/concat_slice_b.conf b/paddle/gserver/tests/concat_slice_b.conf new file mode 100644 index 0000000000000000000000000000000000000000..29686ef2810370af3f84b60b2450d5c7d2e7663d --- /dev/null +++ b/paddle/gserver/tests/concat_slice_b.conf @@ -0,0 +1,41 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) + +data = data_layer(name ="input", size=8*16*16) + +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +proj1 = slice_projection(input=conv1, slices=[(0, 12)]) + +proj2 = slice_projection(input=conv2, slices=[(1, 15)]) + +concat = concat_layer(input=[proj1, proj2]) + +outputs(concat) + diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0975c3bc9573c6ccb8f0ac98c41586d322d2465e..8ce8600c6743779899b2685c1c12053922265411 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -152,6 +152,26 @@ TEST(Projection, identity) { } } +TEST(Projection, slice) { + ProjectionConfig conf; + conf.set_type("slice"); + conf.set_input_size(100); + SliceConfig& slice1 = *conf.add_slices(); + slice1.set_start(10); + slice1.set_end(20); + SliceConfig& slice2 = *conf.add_slices(); + slice2.set_start(50); + slice2.set_end(70); + conf.set_output_size(30); + for (auto useGpu : {false, true}) { + testProjectionGrad(conf, + INPUT_DATA, + /* parameterSize */ 0, + /* batchSize */ 10, + useGpu); + } +} + TEST(Projection, scaling) { ProjectionConfig conf; conf.set_type("scaling"); diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index 40e662b22bac0a2d22aea31fe99b11695bac3f57..f930c72fde3f5e0a6a45cb6bfd3507a4f48028fc 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -237,6 +237,12 @@ TEST(Compare, concat_table) { compareNetwork(config_file_a, config_file_b); } +TEST(Compare, concat_slice) { + std::string config_file_a = "./gserver/tests/concat_slice_a.conf"; + std::string config_file_b = "./gserver/tests/concat_slice_b.conf"; + compareNetwork(config_file_a, config_file_b); +} + #ifndef PADDLE_ONLY_CPU TEST(Compare, img_pool) { std::string config_file_a = "./gserver/tests/img_pool_a.conf"; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 9d28404f687dca0e36ee24d4712bbae82d60d620..701d356ee8a6febca3c19a9003fec6b49e853099 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) +op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index c4a9f5937f4fa8c60989bea1726cedbb73330156..71ceda958770796693265c08cb1fcae27e79bcd9 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -50,8 +50,8 @@ public: AddInput("b", "the bias of fc operator"); AddOutput("Y", "the output of fc operator"); - AddOutput( - "before_act", "the before activation output of fc operator", true); + AddOutput("before_act", "the before activation output of fc operator") + .SetTemporary(); AddAttr("activation", "The activation key for fc layer") .SetDefault("sigmoid") .InEnum({"sigmoid", "softmax"}); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..79a0e3d7e911b728a7a96ceff573976ba2b2e37f --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class FillZerosLikeOp : public framework::OperatorWithKernel { +protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1UL, + "Input size of FillZerosLikeOp must be one."); + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, + "Output size of AddOp must be one."); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, + "Input of FillZerosLikeOp must be set."); + PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, + "Output of FillZerosLikeOp must be set."); + ctx.Output(0)->Resize( + ctx.Input(0)->dims()); + } +}; + +class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FillZerosLikeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Src", "The input of fill-zeros-like op."); + AddOutput("Dst", "The varibale will be filled up with zeros."); + AddComment(R"DOC( +Fill up a vriable with zeros. + +The output will have the same size with input. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(fill_zeros_like, + paddle::operators::FillZerosLikeOp, + paddle::operators::FillZerosLikeOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..55ad58f4f17cd4a3e737c01b001675d2690d273e --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cu @@ -0,0 +1,6 @@ +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_zeros_like_op.h" + +REGISTER_OP_GPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..05272964abd43bdc2bd5c3cae8b128099e1c888c --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class FillZerosLikeKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output = context.Output(0); + output->mutable_data(context.GetPlace()); + framework::EigenVector::Flatten(*output).setZero(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index df0d94cbdd2d3b14e68845031389ba1b719168dc..e5b76e3724b5b0287071c90d26235b8e1a1d80cf 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -291,13 +291,15 @@ public: : OpProtoAndCheckerMaker(proto, op_checker) { const auto& name = RecurrentOp::kArgName; // inputs and outputs stored in proto - AddInputs(name.inlinks, - "the inputs that need to be segmented for each step."); - AddInputs(name.boot_memories, "variables to initialize memories."); + AddInput(name.inlinks, + "the inputs that need to be segmented for each step.") + .SetMultiple(); + AddInput(name.boot_memories, "variables to initialize memories.") + .SetMultiple(); AddInput(name.step_net, "network shared by all steps."); - AddOutputs(name.outlinks, - "the outputs that need to concated for all steps."); + AddOutput(name.outlinks, "the outputs that need to concated for all steps.") + .SetMultiple(); AddOutput(name.step_scopes, "step scopes"); // Attributes stored in AttributeMap diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index ff69a0f8ab944495ef4232fbde253b734abdf9fd..26c8eb78e614a68ec9728aad727d8fe3e08547ae 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -14,8 +14,9 @@ limitations under the License. */ #pragma once -#include +#include #include +#include #include #include #include @@ -40,12 +41,22 @@ namespace platform { struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; - EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { + static constexpr int TRACE_STACK_LIMIT = 100; try { std::rethrow_exception(exp_); } catch (const std::exception& exp) { - err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l); + std::ostringstream sout; + sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; + sout << "Call Stacks: " << std::endl; + void* call_stack[TRACE_STACK_LIMIT]; + int sz = backtrace(call_stack, TRACE_STACK_LIMIT); + auto line = backtrace_symbols(call_stack, sz); + for (int i = 0; i < sz; ++i) { + sout << line[i] << std::endl; + } + free(line); + err_str_ = sout.str(); } } diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 83f72c137bdf5e55f28be908321bd2ccd6c906fe..3bee5b572ae42750332b69e28af980ae325532da 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -198,6 +198,11 @@ message RowConvConfig { required uint32 context_length = 1; } +message SliceConfig { + required uint32 start = 1; + required uint32 end = 2; +} + message ProjectionConfig { required string type = 1; required string name = 2; @@ -218,6 +223,10 @@ message ProjectionConfig { // For pool optional PoolConfig pool_conf = 12; + + // For slice + // Each slice output is the input[start, end) + repeated SliceConfig slices = 13; } message OperatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 5477158ecb8646992ebdded0b15cce50720ebf36..f71fefffb59d4a53dda092ff83a61d9eec4b601f 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -565,6 +565,35 @@ class IdentityOffsetProjection(Projection): return [] +@config_class +class SliceProjection(Projection): + type = 'slice' + + def __init__(self, input_layer_name, slices, **xargs): + super(SliceProjection, self).__init__(input_layer_name, **xargs) + input = g_layer_map[input_layer_name] + if input.type in ["exconv", "cudnn_conv"]: + # the slice operator is for the channel dimension + assert input.num_filters is not None + channels = input.num_filters + image_size = input.size / channels + assert slices[len(slices) - 1][1] <= channels + for i in xrange(len(slices)): + slice = self.proj_conf.slices.add() + slice.start = slices[i][0] * image_size + slice.end = slices[i][1] * image_size + self.size += slice.end - slice.start + else: + config_assert(False, + 'Currently the input should be convolution layer') + + def calc_parameter_size(self, input_size, output_size): + return 0 + + def calc_parameter_dims(self, input_size, output_size): + return [] + + # DotMulProjection performs element-wise multiplication with weight @config_class class DotMulProjection(Projection): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 14f072fc55109d770edf469ad7c574b8dda8a434..965874ddf632a83d00065c2d40037930a6e604a8 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -128,6 +128,7 @@ __all__ = [ 'prelu_layer', 'gated_unit_layer', 'crop_layer', + 'slice_projection', ] @@ -536,6 +537,45 @@ def identity_projection(input, offset=None, size=None): return proj +def slice_projection(input, slices): + """ + slice_projection can slice the input value into multiple parts, + and then select some of them to merge into a new output. + + .. math:: + output = [input.slices()] + + The example usage is: + + .. code-block:: python + + proj = slice_projection(input=layer, slices=[(0, 10), (20, 30)]) + + Note that slice_projection should not have any parameter. + + :param input: Input Layer. + :type input: LayerOutput + :param slices: An array of slice parameters. + Each slice contains the start and end offsets based + on the input. + :type slices: pair of int + :return: A SliceProjection object + :rtype: SliceProjection + """ + assert len(slices) >= 1 + start = 0 + for i in xrange(len(slices)): + assert len(slices[i]) == 2 + # The start position of the next slice needs to be greater than + # or equal to the end position of the previous slice. + assert slices[i][0] >= start + assert slices[i][1] >= slices[i][0] + start = slices[i][1] + proj = SliceProjection(input_layer_name=input.name, slices=slices) + proj.origin = input + return proj + + @wrap_param_attr_default() def scaling_projection(input, param_attr=None): """ diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index b658a81630733fea3976b812afe819d76de4cb25..fc718f031e2267e737adbc340226e145bf614bf2 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -76,3 +76,6 @@ class client(object): # Memory created from C should be freed. get_c_lib().mem_free(ret.contents) return record, 0 + + def paddle_start_get_records(self, pass_id): + get_c_lib().paddle_start_get_records(self.c, pass_id) diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 55a0fcdf56af7a8c9bee3255ea6f1d1ae1b34893..d0f18e4b6611fa56654e7f2a0144758339cb9e19 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could be used in user program. """ -__all__ = ['np_array', 'text_file', "recordio"] +__all__ = ['np_array', 'text_file', "cloud_reader"] def np_array(x): @@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100): return dec.buffered(reader, buf_size) -def recordio(paths, buf_size=100): +pass_num = 0 + + +def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): """ - Creates a data reader that outputs record one one by one - from given local or cloud recordio path. + Create a data reader that yield a record one by one from + the paths: :path: path of recordio files. + :etcd_endpoints: the endpoints for etcd cluster :returns: data reader of recordio files. + + .. code-block:: python + from paddle.v2.reader.creator import cloud_reader + etcd_endpoints = "http://127.0.0.1:2379" + trainer.train.( + reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints), + ) """ import os - import paddle.v2.master.client as cloud - - if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): - return recordio_local(paths) - - host_name = "MASTER_SERVICE_HOST" - if host_name not in os.environ.keys(): - raise Exception('not find ' + host_name + ' in environment variable.') - - addr = os.environ(host) + import cPickle as pickle + import paddle.v2.master as master + c = master.client(etcd_endpoints, timeout_sec, buf_size) + c.set_dataset(paths) def reader(): - c = cloud(addr, buf_size) - c.set_dataset(paths) + global pass_num + c.paddle_start_get_records(pass_num) + pass_num += 1 while True: - r, err = client.next_record() - if err < 0: + r, e = c.next_record() + if not r: + if e != -2: + print "get record error: ", e break - yield r - - c.release() + yield pickle.loads(r) return reader diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index b42d273ecfe6c4bc5706ec52617960b83496d70d..359f3eeefbe8efeb343cc875c707c9251a7087fb 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase): self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) -class TestRecordIO(unittest.TestCase): - def test_recordio(self): - path = os.path.join( - os.path.dirname(__file__), "test_recordio_creator.dat") - reader = paddle.v2.reader.creator.recordio([path]) - for idx, r in enumerate(reader()): - self.assertSequenceEqual(r, str(idx)) - - if __name__ == '__main__': unittest.main()