提交 99b835cc 编写于 作者: D dangqingqing

resolve conficts.

...@@ -198,6 +198,10 @@ identity_projection ...@@ -198,6 +198,10 @@ identity_projection
.. autoclass:: paddle.v2.layer.identity_projection .. autoclass:: paddle.v2.layer.identity_projection
:noindex: :noindex:
slice_projection
-------------------
.. autoclass:: paddle.v2.layer.slice_projection
:noindex:
table_projection table_projection
---------------- ----------------
......
...@@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing ...@@ -3,24 +3,11 @@ import paddle.v2.dataset.uci_housing as uci_housing
import paddle.v2.master as master import paddle.v2.master as master
import os import os
import cPickle as pickle import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379" etcd_endpoints = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint print "etcd endpoints: ", etcd_endpoints
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader():
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
def main(): def main():
...@@ -49,7 +36,7 @@ def main(): ...@@ -49,7 +36,7 @@ def main():
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
is_local=False, is_local=False,
pserver_spec=etcd_endpoint, pserver_spec=etcd_endpoints,
use_etcd=True) use_etcd=True)
# event_handler to print training and testing info # event_handler to print training and testing info
...@@ -75,7 +62,11 @@ def main(): ...@@ -75,7 +62,11 @@ def main():
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
cloud_reader, buf_size=500), batch_size=2), cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0, feeding={'x': 0,
'y': 1}, 'y': 1},
event_handler=event_handler, event_handler=event_handler,
......
...@@ -32,4 +32,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch ...@@ -32,4 +32,7 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(net SRCS net.cc DEPS op_registry) cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net)
cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
static bool AllInSet(const std::vector<std::string>& names,
const std::string& suffix,
const std::unordered_set<std::string>& set) {
for (auto& name : names) {
if (set.find(name + suffix) == set.end()) {
return false;
}
}
return true;
}
static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<NetOp>();
net_op->type_ = "@NOP@";
net_op->CompleteAddOp();
return net_op;
}
// Get backward operator from a forward operator, recursively implementation.
//
// no_grad_names the gradient variable names without gradient calculating.
//
// uniq_id is a unique index used inside recursively calling BackwardRecursive.
// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
// recursive calling.
//
// returns The backward operator. For simple situation, it is a simple
// operator. For complex situation, it is a NetOp.
//
// See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
return NOP();
}
// All output gradients of forwarding operator do not need to calculate. Then
// all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
for (auto& name : forwardOp.inputs_) {
// Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
}
return NOP();
}
// Returned gradient network
auto net = std::make_shared<NetOp>();
if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// Map from output gradient variable name to operator's indices in backward
// net. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
size_t local_op_id = 0;
// reversely travel forwardNet
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
for (auto& out : bwd->outputs_) {
dup_output_ops[out].emplace_back(local_op_id);
}
}
// Get unique ID for this method.
auto uid = uniq_id++;
// TODO(dzh): more comment
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second;
if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) {
auto op_offset = dup_op[i];
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {dup_outputs}, {name},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
}
insert_position.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second);
}
} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
{grad_input}, {}));
}
}
for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME();
}
}
if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(grad_op);
}
net->type_ = "@GENERATED_BACKWARD@";
net->CompleteAddOp();
return net;
}
// See header for comments
std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size());
for (auto& name : no_grad_vars) {
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
}
size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_set>
#include "operator.h"
namespace paddle {
namespace framework {
// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
} // namespace paddle
## Operator/expression 's Backward
### Motivation
In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation lineage, the operator/ expression's Backward feature will generate the backward pass respect to forward pass.
### Implement : gradient operator registry
| | forward operator | backward operator |
| ---------------------- | ---------------- | -------------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
Inputs/Outputs means the input/output of the operator, InputGradients/OutputGradients is the gradient respect to forward opeartor. Forward operator and Backward operator are isomorphic, save their corresponding needs into member attribute.
We use a global hash map record the gradient operators available, follow the philosophy of minimum core, make operator pluggable unit. Each gradient is an operator and it needs to regist itself.
grad_op_builder(fengjiayi)
### Implement : Backward network
given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
1. bla bla bla (yuyang)
2. NetOp
when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively and ensure them done. During the process, we need to collect the `OutputGradients` name.
We share variable in the same scope, as a result, duplicate operator `OutputGradients` will overwirte then duplicate variable.
![./images/duplicate_op]()
Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator instead.
![./images/duplicate_op2]()
​ Then collect the sub graph OutputGradients/InputGradients as the NetOp's and return it.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/backward.h"
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class EmptyOp : public OperatorBase {
public:
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient();
AddOutput("Out", "Out of Add").IgnoreGradient();
AddComment("Add Op");
}
};
class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("A", "A");
AddInput("B", "B");
AddOutput("Out", "Out");
AddComment("Mul");
}
};
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X");
AddOutput("Y", "Y");
AddComment("Sigmoid");
}
};
class NoGradOpMaker : public OpProtoAndCheckerMaker {
public:
NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X input");
AddOutput("Y", "Y output");
AddComment("NoGradOp, same input output. no Grad");
}
};
class FcOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("mul_result")}, {}));
auto b_name = Input("b");
std::string before_act = "mul_result";
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
{Output("add_result")}, {}));
before_act = "add_result";
} else {
auto out_varname = Output("add_result");
if (out_varname != EMPTY_VAR_NAME()) {
this->Rename(out_varname, EMPTY_VAR_NAME());
}
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
{}));
CompleteAddOp(false);
}
};
class FcOpMaker : public OpProtoAndCheckerMaker {
public:
FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_result", "").SetTemporary();
AddOutput("add_result", "").SetTemporary();
AddOutput("Out", "");
AddComment("");
}
};
class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
public:
ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("y", "y");
AddOutput("z", "z");
AddComment("");
}
};
class FillZeroOpMaker : public OpProtoAndCheckerMaker {
public:
FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("out", "out");
AddComment("");
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple();
AddOutput("Y", "y");
AddComment("");
}
};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"X"});
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
gop->outputs_.end());
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL, std::static_pointer_cast<f::NetOp>(no_input_gop)->ops_.size());
}
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
ASSERT_EQ(3UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
f::OperatorBase &d_add = *net->ops_[1];
ASSERT_EQ("rowwise_add_grad", d_add.type_);
f::OperatorBase &d_mul = *net->ops_[2];
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
{"mul_result", "add_result", "tmp"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
ASSERT_EQ(2UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
f::OperatorBase &d_mul = *net->ops_[1];
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
{"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
}
// Not Generated X
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(
f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, net_shared_weight) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->type_);
}
TEST(Backward, op_register_grad_not_for_network) {
auto fwd = f::OpRegistry::CreateOp(
"fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
{{"temporary_index", std::vector<int>{0, 1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, op_all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, op_part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2UL);
auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_);
ASSERT_EQ(1UL, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1UL, fill_zero.outputs_.size());
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"});
auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out");
}
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
{"mul_out1", "add_out1", "out1"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
{"mul_out2", "tmp_out2", "out2"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
{"mul_out3", "tmp_out3", "out3"}, {}));
net.CompleteAddOp();
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(),
3UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
- 1UL /*ignoreGradient varable number*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */
+ 1UL /* input number of sigmod */);
EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
/*
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("X"), "out2");
EXPECT_EQ(grad_fc.Input("W"), "w3");
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
EXPECT_EQ(grad_fc.Input("Out"), "out3");
*/
}
...@@ -20,7 +20,7 @@ namespace framework { ...@@ -20,7 +20,7 @@ namespace framework {
OperatorBase* GradOpBuilder::Build() { OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList(); BuildOpInOutArgList();
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_); std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type; grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op); CompleteGradOp(grad_op);
...@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, ...@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
} }
void GradOpBuilder::BuildOpInOutArgList() { void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_); const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
const std::vector<int>& in_format = const std::vector<int>& in_format =
op_->attrs_.count("input_format") op_.attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format") ? op_.GetAttr<std::vector<int>>("input_format")
: std::vector<int>(); : std::vector<int>();
const std::vector<int>& out_format = const std::vector<int>& out_format =
op_->attrs_.count("output_format") op_.attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format") ? op_.GetAttr<std::vector<int>>("output_format")
: std::vector<int>(); : std::vector<int>();
for (const auto& var : op_proto.inputs()) { for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back( arg_list_.emplace_back(
...@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, ...@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
} }
(*varmap)[var_name] = idx++; (*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size(); size_t pre_sz = in_out.size();
auto base_it = auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out)); std::back_inserter(in_out));
if (is_grad) { if (is_grad) {
...@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, ...@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
} }
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_; grad_op->attrs_ = op_.attrs_;
grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format"); grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap(); VarIndexMap* grad_varmap = new VarIndexMap();
......
...@@ -29,7 +29,7 @@ class GradOpBuilder { ...@@ -29,7 +29,7 @@ class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
public: public:
GradOpBuilder(const OperatorBase* op) : op_(op) {} GradOpBuilder(const OperatorBase& op) : op_(op) {}
OperatorBase* Build(); OperatorBase* Build();
private: private:
...@@ -40,7 +40,7 @@ class GradOpBuilder { ...@@ -40,7 +40,7 @@ class GradOpBuilder {
std::vector<int>& format, VarIndexMap* varmap, int& idx, std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const; bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const; void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_; const OperatorBase& op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_; std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
}; };
......
...@@ -11,7 +11,7 @@ namespace framework { ...@@ -11,7 +11,7 @@ namespace framework {
TEST(GradOpBuilder, AddTwo) { TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op( std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op); std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2); EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
......
...@@ -68,9 +68,18 @@ class NetOp : public OperatorBase { ...@@ -68,9 +68,18 @@ class NetOp : public OperatorBase {
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
ops_.push_back(op); ops_.push_back(op);
} }
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE(pos <= ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op);
}
void CompleteAddOp(bool calculate = true); void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugString() const override;
......
...@@ -3,11 +3,6 @@ ...@@ -3,11 +3,6 @@
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h> #include <paddle/framework/operator.h>
USE_OP(add_two);
USE_OP(mul);
USE_OP(sigmoid);
USE_OP(softmax);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,6 +20,13 @@ class TestOp : public OperatorBase { ...@@ -25,6 +20,13 @@ class TestOp : public OperatorBase {
} }
}; };
class EmptyOp : public OperatorBase {
public:
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
template <typename T> template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected, void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
const std::vector<T>& actual) { const std::vector<T>& actual) {
...@@ -71,20 +73,17 @@ TEST(OpKernel, all) { ...@@ -71,20 +73,17 @@ TEST(OpKernel, all) {
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
//! TODO(yuyang18): Refine Backward Op. TEST(Net, insert_op) {
// TEST(AddBackwardOp, TestGradOp) { NetOp net;
// auto net = std::make_shared<NetOp>(); auto op1 = std::make_shared<EmptyOp>();
// ASSERT_NE(net, nullptr); op1->inputs_ = {"x", "w1", "b1"};
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {})); op1->outputs_ = {"y"};
// net->AddOp( net.AddOp(op1);
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {})); net.InsertOp(0, op1);
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, ASSERT_EQ(2UL, net.ops_.size());
// {})); net.InsertOp(2, op1);
// auto grad_ops = AddBackwardOp(net); ASSERT_EQ(3UL, net.ops_.size());
// for (auto& op : grad_ops->ops_) { }
// op->DebugString();
// }
//}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { ...@@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker {
} }
protected: protected:
void AddInput(const std::string& name, const std::string& comment, struct VariableBuilder {
bool multiple = false, bool ignore_gradient = false) { VarProto* var_;
std::function<void()> on_multiple_;
std::function<void()> on_temporary_;
VariableBuilder& SetMultiple() {
var_->set_multiple(true);
on_multiple_();
return *this;
}
VariableBuilder& SetTemporary() {
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
return *this;
}
VariableBuilder& IgnoreGradient() {
var_->set_ignore_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name,
const std::string& comment) {
auto input = proto_->mutable_inputs()->Add(); auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name; *input->mutable_name() = name;
*input->mutable_comment() = comment; *input->mutable_comment() = comment;
input->set_ignore_gradient(ignore_gradient); return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
input->set_multiple(multiple); nullptr};
if (multiple) {
SetHasMultipleInput();
}
}
void AddInputs(const std::string& name, const std::string& comment,
bool ignore_gradient = false) {
AddInput(name, comment, true, ignore_gradient);
} }
void AddOutput(const std::string& name, const std::string& comment, VariableBuilder AddOutput(const std::string& name,
bool temporary = false, bool multiple = false, const std::string& comment) {
bool ignore_gradient = false) {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name; *output->mutable_name() = name;
*output->mutable_comment() = comment; *output->mutable_comment() = comment;
output->set_ignore_gradient(ignore_gradient); return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
output->set_multiple(multiple); [=] { this->SetHasTemporaryOutput(); }};
if (multiple) {
SetHasMultipleOutput();
}
output->set_temporary(temporary);
if (temporary) {
SetHasTemporaryOutput();
}
}
void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false, bool ignore_gradient = false) {
AddOutput(name, comment, temporary, true, ignore_gradient);
} }
template <typename T> template <typename T>
...@@ -300,9 +303,10 @@ class OpRegistry { ...@@ -300,9 +303,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
static std::shared_ptr<OperatorBase> CreateGradOp( static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
std::shared_ptr<OperatorBase> op) { PADDLE_ENFORCE(!op.IsNetOp(),
GradOpBuilder builder(op.get()); "Use framework::Backward to get backward ops");
GradOpBuilder builder(op);
std::shared_ptr<OperatorBase> grad_op(builder.Build()); std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init(); grad_op->Init();
return grad_op; return grad_op;
......
...@@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("input", "input of cosine op"); AddInput("input", "input of cosine op").SetMultiple();
AddOutput("output", "output of cosine op", AddOutput("output", "output of cosine op").SetTemporary();
/*temporary*/ true);
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
}; };
......
...@@ -52,7 +52,8 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -52,7 +52,8 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(inputs_.size()),
"Input Out Of Range"); "Input Out Of Range");
return std::vector<std::string>{ return std::vector<std::string>{
...@@ -78,7 +79,8 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { ...@@ -78,7 +79,8 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(outputs_.size()),
"Output Out of Range"); "Output Out of Range");
return std::vector<std::string>{ return std::vector<std::string>{
outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset),
...@@ -105,5 +107,11 @@ std::string OperatorBase::DebugString() const { ...@@ -105,5 +107,11 @@ std::string OperatorBase::DebugString() const {
return ss.str(); return ss.str();
} }
void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) {
std::replace(inputs_.begin(), inputs_.end(), old_name, new_name);
std::replace(outputs_.begin(), outputs_.end(), old_name, new_name);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -54,6 +55,9 @@ class OperatorBase { ...@@ -54,6 +55,9 @@ class OperatorBase {
/// e.g. Variable "x@GRAD" is the gradient of varibale "x". /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
...@@ -79,8 +83,12 @@ class OperatorBase { ...@@ -79,8 +83,12 @@ class OperatorBase {
virtual bool IsNetOp() const { return false; } virtual bool IsNetOp() const { return false; }
/// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name);
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const; std::vector<std::string> Inputs(const std::string& name) const;
...@@ -92,7 +100,13 @@ class OperatorBase { ...@@ -92,7 +100,13 @@ class OperatorBase {
public: public:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// store the arguments' offset described in op_desc. // store the arguments' offset described in op_desc.
...@@ -147,22 +161,30 @@ class OperatorContext { ...@@ -147,22 +161,30 @@ class OperatorContext {
template <typename T> template <typename T>
const T* Input(const size_t index) const { const T* Input(const size_t index) const {
return &(InputVar(index)->Get<T>()); auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const size_t index) const { T* Output(const size_t index) const {
return OutputVar(index)->GetMutable<T>(); auto var = OutputVar(index);
PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index);
return var->GetMutable<T>();
} }
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>()); auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const std::string& name) const { T* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<T>(); auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
} }
template <typename T> template <typename T>
...@@ -171,8 +193,12 @@ class OperatorContext { ...@@ -171,8 +193,12 @@ class OperatorContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [&](const std::string& sub_name) {
return &scope_.FindVar(name)->Get<T>(); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiInput(%s:%s) should not be nullptr",
name, sub_name);
return &var->Get<T>();
}); });
return res; return res;
} }
...@@ -183,8 +209,12 @@ class OperatorContext { ...@@ -183,8 +209,12 @@ class OperatorContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [&](const std::string& sub_name) {
return scope_.FindVar(name)->GetMutable<T>(); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiOutput(%s:%s) should not be nullptr",
name, sub_name);
return var->GetMutable<T>();
}); });
return res; return res;
} }
......
...@@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker) OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("xs", "inputs of test op"); AddInput("xs", "inputs of test op").SetMultiple();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutputs("ys", "outputs of test op"); AddOutput("ys", "outputs of test op").SetMultiple();
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.LargerThan(0.0); .LargerThan(0.0);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Projection.h"
namespace paddle {
/**
* SliceProjection can slice the input value into multiple parts,
* and then select some of them to merge into a new output.
*
* First, calculate the slices that need to be merged into the output.
* slices = input.slices().for_output()
*
* Second, merge each slice into the output.
* for(auto slice: slices) {
* out.addAtOffset(slice, offset);
* }
*
* Input slices as output: s0, s1, ...:
* -----------------------
* |///| |//////| |
* |/s0| |//s1//| |
* |///| |//////| |
* -----------------------
* Output, merge s0, s1, ... into one output:
* ----------------
* |///|//////| |
* |/s0|//s1//|...|
* |///|//////| |
* ----------------
*
* The config file api is slice_projection.
*/
class SliceProjection : public Projection {
public:
SliceProjection(const ProjectionConfig& config,
const ParameterPtr& parameter,
bool useGpu);
virtual void forward();
virtual void backward(const UpdateCallback& callback);
protected:
std::vector<std::pair<size_t, size_t>> slices_;
};
REGISTER_PROJECTION(slice, SliceProjection);
/**
* Constructed function.
* @note SliceProjection should not have any parameter.
*/
SliceProjection::SliceProjection(const ProjectionConfig& config,
const ParameterPtr& parameter,
bool useGpu)
: Projection(config, parameter, useGpu) {
CHECK(!parameter) << "'slice' projection should not have any parameter";
slices_.reserve(config.slices_size());
for (const auto& slice : config.slices()) {
slices_.push_back(std::make_pair(slice.start(), slice.end()));
}
}
void SliceProjection::forward() {
size_t offset = 0;
for (auto& slice : slices_) {
auto slice_out = in_->value->subColMatrix(slice.first, slice.second);
out_->value->addAtOffset(*slice_out, offset);
offset += slice_out->getWidth();
}
}
void SliceProjection::backward(const UpdateCallback& callback) {
if (in_->grad) {
size_t offset = 0;
for (auto& slice : slices_) {
auto slice_out = in_->grad->subColMatrix(slice.first, slice.second);
slice_out->addAtOffset(*out_->grad, offset);
offset += slice_out->getWidth();
}
}
}
} // namespace paddle
#edit-mode: -*- python -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
settings(batch_size=10)
data = data_layer(name ="input", size=8*16*16)
conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
proj1 = slice_projection(input=conv1, slices=[(0, 4), (4, 12)])
proj2 = slice_projection(input=conv2, slices=[(1, 5), (5, 15)])
concat = concat_layer(input=[proj1, proj2])
outputs(concat)
#edit-mode: -*- python -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
settings(batch_size=10)
data = data_layer(name ="input", size=8*16*16)
conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
proj1 = slice_projection(input=conv1, slices=[(0, 12)])
proj2 = slice_projection(input=conv2, slices=[(1, 15)])
concat = concat_layer(input=[proj1, proj2])
outputs(concat)
...@@ -152,6 +152,26 @@ TEST(Projection, identity) { ...@@ -152,6 +152,26 @@ TEST(Projection, identity) {
} }
} }
TEST(Projection, slice) {
ProjectionConfig conf;
conf.set_type("slice");
conf.set_input_size(100);
SliceConfig& slice1 = *conf.add_slices();
slice1.set_start(10);
slice1.set_end(20);
SliceConfig& slice2 = *conf.add_slices();
slice2.set_start(50);
slice2.set_end(70);
conf.set_output_size(30);
for (auto useGpu : {false, true}) {
testProjectionGrad(conf,
INPUT_DATA,
/* parameterSize */ 0,
/* batchSize */ 10,
useGpu);
}
}
TEST(Projection, scaling) { TEST(Projection, scaling) {
ProjectionConfig conf; ProjectionConfig conf;
conf.set_type("scaling"); conf.set_type("scaling");
......
...@@ -237,6 +237,12 @@ TEST(Compare, concat_table) { ...@@ -237,6 +237,12 @@ TEST(Compare, concat_table) {
compareNetwork(config_file_a, config_file_b); compareNetwork(config_file_a, config_file_b);
} }
TEST(Compare, concat_slice) {
std::string config_file_a = "./gserver/tests/concat_slice_a.conf";
std::string config_file_b = "./gserver/tests/concat_slice_b.conf";
compareNetwork(config_file_a, config_file_b);
}
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
TEST(Compare, img_pool) { TEST(Compare, img_pool) {
std::string config_file_a = "./gserver/tests/img_pool_a.conf"; std::string config_file_a = "./gserver/tests/img_pool_a.conf";
......
...@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) ...@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net) softmax_op net)
......
...@@ -50,8 +50,8 @@ public: ...@@ -50,8 +50,8 @@ public:
AddInput("b", "the bias of fc operator"); AddInput("b", "the bias of fc operator");
AddOutput("Y", "the output of fc operator"); AddOutput("Y", "the output of fc operator");
AddOutput( AddOutput("before_act", "the before activation output of fc operator")
"before_act", "the before activation output of fc operator", true); .SetTemporary();
AddAttr<std::string>("activation", "The activation key for fc layer") AddAttr<std::string>("activation", "The activation key for fc layer")
.SetDefault("sigmoid") .SetDefault("sigmoid")
.InEnum({"sigmoid", "softmax"}); .InEnum({"sigmoid", "softmax"});
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output size of AddOp must be one.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
"Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims());
}
};
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillZerosLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Src", "The input of fill-zeros-like op.");
AddOutput("Dst", "The varibale will be filled up with zeros.");
AddComment(R"DOC(
Fill up a vriable with zeros.
The output will have the same size with input.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(fill_zeros_like,
paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
REGISTER_OP_GPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>(0);
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).setZero();
}
};
} // namespace operators
} // namespace paddle
...@@ -291,13 +291,15 @@ public: ...@@ -291,13 +291,15 @@ public:
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName; const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto // inputs and outputs stored in proto
AddInputs(name.inlinks, AddInput(name.inlinks,
"the inputs that need to be segmented for each step."); "the inputs that need to be segmented for each step.")
AddInputs(name.boot_memories, "variables to initialize memories."); .SetMultiple();
AddInput(name.boot_memories, "variables to initialize memories.")
.SetMultiple();
AddInput(name.step_net, "network shared by all steps."); AddInput(name.step_net, "network shared by all steps.");
AddOutputs(name.outlinks, AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
"the outputs that need to concated for all steps."); .SetMultiple();
AddOutput(name.step_scopes, "step scopes"); AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap // Attributes stored in AttributeMap
......
...@@ -14,8 +14,9 @@ limitations under the License. */ ...@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include <execinfo.h>
#include <paddle/string/printf.h> #include <paddle/string/printf.h>
#include <iomanip>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
...@@ -40,12 +41,22 @@ namespace platform { ...@@ -40,12 +41,22 @@ namespace platform {
struct EnforceNotMet : public std::exception { struct EnforceNotMet : public std::exception {
std::exception_ptr exp_; std::exception_ptr exp_;
std::string err_str_; std::string err_str_;
EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
static constexpr int TRACE_STACK_LIMIT = 100;
try { try {
std::rethrow_exception(exp_); std::rethrow_exception(exp_);
} catch (const std::exception& exp) { } catch (const std::exception& exp) {
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l); std::ostringstream sout;
sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl;
sout << "Call Stacks: " << std::endl;
void* call_stack[TRACE_STACK_LIMIT];
int sz = backtrace(call_stack, TRACE_STACK_LIMIT);
auto line = backtrace_symbols(call_stack, sz);
for (int i = 0; i < sz; ++i) {
sout << line[i] << std::endl;
}
free(line);
err_str_ = sout.str();
} }
} }
......
...@@ -198,6 +198,11 @@ message RowConvConfig { ...@@ -198,6 +198,11 @@ message RowConvConfig {
required uint32 context_length = 1; required uint32 context_length = 1;
} }
message SliceConfig {
required uint32 start = 1;
required uint32 end = 2;
}
message ProjectionConfig { message ProjectionConfig {
required string type = 1; required string type = 1;
required string name = 2; required string name = 2;
...@@ -218,6 +223,10 @@ message ProjectionConfig { ...@@ -218,6 +223,10 @@ message ProjectionConfig {
// For pool // For pool
optional PoolConfig pool_conf = 12; optional PoolConfig pool_conf = 12;
// For slice
// Each slice output is the input[start, end)
repeated SliceConfig slices = 13;
} }
message OperatorConfig { message OperatorConfig {
......
...@@ -565,6 +565,35 @@ class IdentityOffsetProjection(Projection): ...@@ -565,6 +565,35 @@ class IdentityOffsetProjection(Projection):
return [] return []
@config_class
class SliceProjection(Projection):
type = 'slice'
def __init__(self, input_layer_name, slices, **xargs):
super(SliceProjection, self).__init__(input_layer_name, **xargs)
input = g_layer_map[input_layer_name]
if input.type in ["exconv", "cudnn_conv"]:
# the slice operator is for the channel dimension
assert input.num_filters is not None
channels = input.num_filters
image_size = input.size / channels
assert slices[len(slices) - 1][1] <= channels
for i in xrange(len(slices)):
slice = self.proj_conf.slices.add()
slice.start = slices[i][0] * image_size
slice.end = slices[i][1] * image_size
self.size += slice.end - slice.start
else:
config_assert(False,
'Currently the input should be convolution layer')
def calc_parameter_size(self, input_size, output_size):
return 0
def calc_parameter_dims(self, input_size, output_size):
return []
# DotMulProjection performs element-wise multiplication with weight # DotMulProjection performs element-wise multiplication with weight
@config_class @config_class
class DotMulProjection(Projection): class DotMulProjection(Projection):
......
...@@ -128,6 +128,7 @@ __all__ = [ ...@@ -128,6 +128,7 @@ __all__ = [
'prelu_layer', 'prelu_layer',
'gated_unit_layer', 'gated_unit_layer',
'crop_layer', 'crop_layer',
'slice_projection',
] ]
...@@ -536,6 +537,45 @@ def identity_projection(input, offset=None, size=None): ...@@ -536,6 +537,45 @@ def identity_projection(input, offset=None, size=None):
return proj return proj
def slice_projection(input, slices):
"""
slice_projection can slice the input value into multiple parts,
and then select some of them to merge into a new output.
.. math::
output = [input.slices()]
The example usage is:
.. code-block:: python
proj = slice_projection(input=layer, slices=[(0, 10), (20, 30)])
Note that slice_projection should not have any parameter.
:param input: Input Layer.
:type input: LayerOutput
:param slices: An array of slice parameters.
Each slice contains the start and end offsets based
on the input.
:type slices: pair of int
:return: A SliceProjection object
:rtype: SliceProjection
"""
assert len(slices) >= 1
start = 0
for i in xrange(len(slices)):
assert len(slices[i]) == 2
# The start position of the next slice needs to be greater than
# or equal to the end position of the previous slice.
assert slices[i][0] >= start
assert slices[i][1] >= slices[i][0]
start = slices[i][1]
proj = SliceProjection(input_layer_name=input.name, slices=slices)
proj.origin = input
return proj
@wrap_param_attr_default() @wrap_param_attr_default()
def scaling_projection(input, param_attr=None): def scaling_projection(input, param_attr=None):
""" """
......
...@@ -76,3 +76,6 @@ class client(object): ...@@ -76,3 +76,6 @@ class client(object):
# Memory created from C should be freed. # Memory created from C should be freed.
get_c_lib().mem_free(ret.contents) get_c_lib().mem_free(ret.contents)
return record, 0 return record, 0
def paddle_start_get_records(self, pass_id):
get_c_lib().paddle_start_get_records(self.c, pass_id)
...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could ...@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could
be used in user program. be used in user program.
""" """
__all__ = ['np_array', 'text_file', "recordio"] __all__ = ['np_array', 'text_file', "cloud_reader"]
def np_array(x): def np_array(x):
...@@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100): ...@@ -81,35 +81,41 @@ def recordio_local(paths, buf_size=100):
return dec.buffered(reader, buf_size) return dec.buffered(reader, buf_size)
def recordio(paths, buf_size=100): pass_num = 0
def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
""" """
Creates a data reader that outputs record one one by one Create a data reader that yield a record one by one from
from given local or cloud recordio path. the paths:
:path: path of recordio files. :path: path of recordio files.
:etcd_endpoints: the endpoints for etcd cluster
:returns: data reader of recordio files. :returns: data reader of recordio files.
.. code-block:: python
from paddle.v2.reader.creator import cloud_reader
etcd_endpoints = "http://127.0.0.1:2379"
trainer.train.(
reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints),
)
""" """
import os import os
import paddle.v2.master.client as cloud import cPickle as pickle
import paddle.v2.master as master
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): c = master.client(etcd_endpoints, timeout_sec, buf_size)
return recordio_local(paths) c.set_dataset(paths)
host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environment variable.')
addr = os.environ(host)
def reader(): def reader():
c = cloud(addr, buf_size) global pass_num
c.set_dataset(paths) c.paddle_start_get_records(pass_num)
pass_num += 1
while True: while True:
r, err = client.next_record() r, e = c.next_record()
if err < 0: if not r:
if e != -2:
print "get record error: ", e
break break
yield r yield pickle.loads(r)
c.release()
return reader return reader
...@@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase): ...@@ -34,14 +34,5 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def test_recordio(self):
path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册