diff --git a/Dockerfile b/Dockerfile
index ed5910d93b41dba8d50b2ba01c59c635797edd29..8cfb16928c95dcbfac08383d32562ff67933d873 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/
 RUN apt-get update && \
     apt-get install -y \
     git python-pip python-dev openssh-server bison \
-    wget unzip tar xz-utils bzip2 gzip coreutils ntp \
+    wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
     curl sed grep graphviz libjpeg-dev zlib1g-dev \
     python-numpy python-matplotlib gcc g++ \
     automake locales clang-format-3.8 swig doxygen cmake \
diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go
index 54fa254863156455f66fa87de9077042a45f9735..9eaf8c04ae01fe7eebc92c51803bfcf977995ee3 100644
--- a/go/cmd/master/master.go
+++ b/go/cmd/master/master.go
@@ -11,6 +11,7 @@ import (
 
     "github.com/namsral/flag"
     log "github.com/sirupsen/logrus"
+    "github.com/topicai/candy"
 
     "github.com/PaddlePaddle/Paddle/go/master"
     "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
@@ -20,11 +21,18 @@ func main() {
     port := flag.Int("port", 8080, "port of the master server.")
     ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
     endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
-    taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
-    taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
-    chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
+    taskTimeoutDur := flag.Duration("task-timeout-dur", 20*time.Minute, "task timeout duration.")
+    taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timeout count for each task before it is declared a failed task.")
+    chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
+    logLevel := flag.String("log-level", "info",
+        "log level, possible values: debug, info, warning, error, fatal, panic")
     flag.Parse()
 
+    level, e := log.ParseLevel(*logLevel)
+    candy.Must(e)
+
+    log.SetLevel(level)
+
     if *endpoints == "" {
         log.Warningln("-endpoints not set, fault tolerance not be enabled.")
     }
diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go
index b331b8126cadc2c5df516fb241913415b2e3e73d..652d7ba315d72ff19931b82a4b0d1c30b2ff8f37 100644
--- a/go/cmd/pserver/pserver.go
+++ b/go/cmd/pserver/pserver.go
@@ -40,7 +40,7 @@ func main() {
         idx = *index
     } else {
         e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
-        idx, err = e.Register()
+        idx, err = e.Register(*port)
         candy.Must(err)
 
         cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
diff --git a/go/master/client.go b/go/master/client.go
index a2ca3f3ef8ce300e3df09a302d74b56ee23c6d10..de883bf4b9a3de8d6d6e35e8e808dcf7ba54cb46 100644
--- a/go/master/client.go
+++ b/go/master/client.go
@@ -2,6 +2,7 @@ package master
 
 import (
     "os"
+    "time"
 
     "github.com/PaddlePaddle/Paddle/go/connection"
     "github.com/PaddlePaddle/recordio"
@@ -36,9 +37,9 @@ func (c *Client) getRecords() {
     for {
         t, err := c.getTask()
         if err != nil {
-            // TODO(helin): wait before move on with next
-            // getTask call.
-            log.Errorln(err)
+            log.Errorf("getTask failed, sleeping 3 seconds before retrying: %s", err)
+            time.Sleep(3 * time.Second)
             continue
         }
diff --git a/go/master/service.go b/go/master/service.go
index a6050ab99437244dade83f2943f6649453d47fff..9cef2270ce6a51425e40b9281f93f2f9c9981329 100644
--- a/go/master/service.go
+++ b/go/master/service.go
@@ -215,6 +215,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
         }
 
         count := index.NumChunks()
+        log.Infof("readChunks: file %s has %d chunks", path, count)
         for i := 0; i < count; i++ {
             chunk := Chunk{
                 Path: path,
diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py
index d6922672f4c1253e62cfe54965f6c2f3b5e6c7bf..e9264592b4f18fddf68b198d73bf907206e77a3f 100644
--- a/go/pserver/client/c/test/test_train.py
+++ b/go/pserver/client/c/test/test_train.py
@@ -1,5 +1,23 @@
 import paddle.v2 as paddle
 import paddle.v2.dataset.uci_housing as uci_housing
+import paddle.v2.master as master
+import os
+import cPickle as pickle
+
+etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
+etcd_endpoint = "http://" + etcd_ip + ":2379"
+
+
+def cloud_reader():
+    print "connecting to master, etcd endpoints: ", etcd_endpoint
+    master_client = master.client(etcd_endpoint, 5, 64)
+    master_client.set_dataset(
+        ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
+    while 1:
+        r, e = master_client.next_record()
+        if not r:
+            break
+        yield pickle.loads(r)
 
 
 def main():
@@ -22,13 +40,13 @@ def main():
     # create optimizer of new remote updater to pserver
     optimizer = paddle.optimizer.Momentum(momentum=0)
 
-    #TODO(zhihong) : replace optimizer with new OptimizerConfig
-
+    print "etcd endpoint: ", etcd_endpoint
     trainer = paddle.trainer.SGD(cost=cost,
                                  parameters=parameters,
                                  update_equation=optimizer,
                                  is_local=False,
-                                 pserver_spec="localhost:3000")
+                                 pserver_spec=etcd_endpoint,
+                                 use_etcd=True)
 
     # event_handler to print training and testing info
     def event_handler(event):
@@ -47,11 +65,11 @@ def main():
             print "Test %d, %.2f" % (event.pass_id, result.cost)
 
     # training
+    # NOTE: use uci_housing.train() as the reader for non-paddlecloud training
     trainer.train(
         reader=paddle.batch(
             paddle.reader.shuffle(
-                uci_housing.train(), buf_size=500),
-            batch_size=2),
+                cloud_reader, buf_size=500), batch_size=2),
         feeding={'x': 0,
                  'y': 1},
         event_handler=event_handler,
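Note: the hunk above keeps the cloud path hard-wired; per its NOTE comment, non-paddlecloud runs should swap cloud_reader for uci_housing.train(). A minimal sketch of that switch, assuming a hypothetical NO_PADDLE_CLOUD environment variable as the toggle (the variable name is illustrative, not part of this change):

    import os

    import paddle.v2 as paddle
    import paddle.v2.dataset.uci_housing as uci_housing

    def pick_reader():
        # cloud_reader is the generator defined in test_train.py above;
        # uci_housing.train() returns a local reader with the same record
        # layout, so the rest of the training script is unchanged.
        if os.getenv("NO_PADDLE_CLOUD"):
            return uci_housing.train()
        return cloud_reader

    reader = paddle.batch(
        paddle.reader.shuffle(pick_reader(), buf_size=500), batch_size=2)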
diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go
index 1fd3479aa88ccbbe7c5067da1e9886b65352e847..8eb2a4f4511fc7139a55a2cd47ad73a82137b260 100644
--- a/go/pserver/client/etcd_client.go
+++ b/go/pserver/client/etcd_client.go
@@ -12,6 +12,7 @@ import (
 )
 
 const (
+    // DefaultEtcdTimeout is the default etcd timeout
     DefaultEtcdTimeout time.Duration = 5 * time.Second
 )
 
@@ -66,12 +67,12 @@ func (p *EtcdClient) List() []Server {
     for {
         for i := 0; i < psDesired; i++ {
             ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
-            cancel()
             psKey := pserver.PsPath + strconv.Itoa(i)
             log.Debugf("checking %s", psKey)
             resp, err := p.client.Get(ctx, psKey)
+            cancel()
             if err != nil {
-                log.Infof("Get psKey= %s error, %v", psKey, err)
+                log.Infof("Get psKey=%s error, %v", psKey, err)
                 time.Sleep(p.timeout)
                 continue
             }
diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go
index 4a694b97f47b2ab85d1e109ef7545d104194b5cf..66af4fa0b483f1caea385df61e54d871072a0375 100644
--- a/go/pserver/etcd_client.go
+++ b/go/pserver/etcd_client.go
@@ -49,7 +49,7 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
 // Register registers the pserver on etcd
 //
 // Register returns the index of the current pserver.
-func (e *EtcdClient) Register() (int, error) {
+func (e *EtcdClient) Register(port int) (int, error) {
     var err error
     e.externalIP, err = networkhelper.GetExternalIP()
@@ -116,7 +116,7 @@
     for {
         ctx, cancel := context.WithTimeout(context.Background(), time.Second)
         var err error
-        pserverIdx, err = e.registerPserverEtcd(ctx)
+        pserverIdx, err = e.registerPserverEtcd(ctx, port)
         cancel()
         if err != nil {
             log.Warn(err)
@@ -140,7 +140,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
 }
 
 // registerPserverEtcd registers pserver node on etcd using transaction.
-func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
+func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
     var idx int
     _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
         registered := false
@@ -156,8 +156,9 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
                 log.Fatal(err)
             }
             // find the first id and write info
-            c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
-            log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
+            pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
+            c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
+            log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
             ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
             if kaerr != nil {
                 log.Errorf("keepalive etcd node error: %v", kaerr)
diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index 5fb3d1c73bc56e921f13aafd27c25224e259b3fe..0b9b83d42974151d49250bdf0e7c397f59bf6a62 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -843,7 +843,8 @@ public:
                                              bool useSparseUpdater);
   static ParameterUpdater* createNewRemoteUpdater(
       OptimizationConfig* config,
-      const std::string pserverSpec) throw(UnsupportError);
+      const std::string pserverSpec,
+      const bool useEtcd) throw(UnsupportError);
   ~ParameterUpdater();
 
   /**
diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp
index 1aaefdfb8107a2eaa0432211fd7df4f5f12d537f..5934cb898b5f6adc74c237b1733a7459d8437a28 100644
--- a/paddle/api/ParameterUpdater.cpp
+++ b/paddle/api/ParameterUpdater.cpp
@@ -33,11 +33,12 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
 
 ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
     OptimizationConfig *config,
-    const std::string pserverSpec) throw(UnsupportError) {
+    const std::string pserverSpec,
+    const bool useEtcd) throw(UnsupportError) {
 #ifndef PADDLE_WITHOUT_GOLANG
   auto updater = new ParameterUpdater();
   updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
-      config->m->getConfig(), pserverSpec));
+      config->m->getConfig(), pserverSpec, useEtcd));
   return updater;
 #else
   throw UnsupportError();
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 8415ce67e90397b31864a90cda54b81c19b3b34e..cc5b05ff0d550ca4536b303426f77d9314b26478 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -11,8 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
 cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
 cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
+
 cc_library(operator SRCS operator.cc DEPS op_desc device_context)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
+
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
@@ -21,4 +23,5 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
 add_dependencies(framework_py_proto framework_py_proto_init)
 
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
-cc_library(net SRCS net.cc DEPS net_proto)
+cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net)
diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc
index a0e8788846d795429ba35715e60d4e3421e2f5ff..7311cda9a9ad282b21711d8eb0b9ba1cf9542296 100644
--- a/paddle/framework/net.cc
+++ b/paddle/framework/net.cc
@@ -19,18 +19,41 @@
 namespace paddle {
 namespace framework {
 
-PlainNet::PlainNet(const NetDesc& def) {}
-
-void PlainNet::InferShape(const ScopePtr& scope) const {
+void PlainNet::CompleteAddOp() {
+  std::unordered_set<std::string> input_set;
+  std::unordered_set<std::string> output_set;
+  std::unordered_set<std::string> temp_output;
   for (auto& op : ops_) {
-    op.InferShape();
+    for (auto& ipt : op->inputs_) {
+      if (!Contains(output_set, ipt)) {  // Not other op's output
+        input_set.insert(ipt);
+      } else {
+        temp_output.insert(ipt);
+      }
+    }
+
+    for (auto& opt : op->outputs_) {
+      output_set.insert(opt);
+    }
   }
-}
-
-void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const {
-  for (auto& op : ops_) {
-    op.Run(ctx);
+  inputs_.reserve(input_set.size());
+  std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_));
+
+  outputs_.reserve(output_set.size());
+  std::vector<int> tmp_index;
+  tmp_index.reserve(temp_output.size());
+  int idx = 0;
+  for (auto& opt : output_set) {
+    if (Contains(temp_output, opt)) {
+      tmp_index.push_back(idx);
+    }
+    outputs_.push_back(opt);
+    ++idx;
   }
+
+  attrs_["temporary_index"] = tmp_index;
+  add_op_done_ = true;
 }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/net.h b/paddle/framework/net.h
index 0481d8f47ccbc171c0b752d48493498371ed8d4d..19a1620e29b86fbccfc112a5f85a1784a197dd0b 100644
--- a/paddle/framework/net.h
+++ b/paddle/framework/net.h
@@ -1,99 +1,51 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #pragma once
 
+#include <memory>
+#include <vector>
 #include "paddle/framework/net_proto.pb.h"
 #include "paddle/framework/op_proto.pb.h"
+#include "paddle/framework/op_registry.h"
 #include "paddle/framework/scope.h"
 #include "paddle/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
-using namespace paddle::platform;
-
-// operator's index stored in a network.
-typedef int OpIndex;
-/**
- * NOTE following codes are some definitions of unimplemented concepts.
- * We write some basic implementation to make Net compilable. These APIs will
- * keep updating if the concepts related are implemented.
- */
-
-struct OpDesc;
-struct OpAttrs {};
-
-class Operator {
- public:
-  Operator(const OpDesc &def) {}
-  void InferShape() const {}
-  void Run(const DeviceContext &ctx) const {}
-};
-
 /**
- * @brief Network that manage the operators it has.
+ * @brief Network is also a type of Operator
+ *
+ * It will manage the operators it has.
  *
- * Network is the container and controller of a set of operators, user can build
- * a real network from a NetDesc which is a protobuf message and use
- * Network.Run() to run all the operators in the network.
+ * Network is the container and controller of a set of operators.
  * A network object knows all Operators belonging to this network. Variables,
  * which are inputs and outputs of these operators, are created and managed by a
  * hierarchy of Scope objects.
  *
- * This is the base class of network, all the networks should implement the apis
+ * This is the base class of network, all the networks should implement the APIs
  * it defines.
  */
-class Net {
+class Net : public OperatorBase {
  public:
-  /**
-   * @brief Infer shapes of all inputs and outputs of operators.
-   */
-  virtual void InferShape(const ScopePtr &scope) const = 0;
-  /**
-   * @brief Run the network.
-   *
-   * Run all the operators and return success(true) or not, with all the
-   * variables are located in `scope`. `context` describes the detail execution
-   * environment for ops. `begin` and `end` specify the scope of `ops_` to run,
-   * If no positive indexes are provided, all operators in `ops_` will run.
-   */
-  virtual void Run(const ScopePtr &scope, const DeviceContext &ctx) const = 0;
-
-  /**
-   * @brief Add an Operator according to `def`.
-   */
-  virtual OpIndex AddOp(const OpProto &def) = 0;
-
-  /**
-   * @brief Add optimizer operators acctording to `attrs`.
-   */
-  virtual void AddOptimizerOps(const OpAttrs &attrs) = 0;
-
-  /**
-   * @brief Add backward operators.
-   */
-  virtual void AddBackwardOps() = 0;
-
-  /**
-   * @brief Create a network.
-   */
-  static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc());
-
-  virtual ~Net() {}
+  virtual void AddOp(const OperatorPtr& op) = 0;
+  virtual void CompleteAddOp() = 0;
 };
 
+using NetPtr = std::shared_ptr<Net>;
+
 /**
  * @brief a basic implementation of Net.
  *
@@ -103,18 +55,14 @@ class Net {
 class PlainNet : public Net {
  public:
   /**
-   * @brief Initialize a PlainNet.
-   *
-   * Initialize from a network describe by `def`. NetDesc is the definition of
-   * a network.
-   */
-  PlainNet(const NetDesc &def);
-
-  /**
-   * Infer all the operators' input and output varialbes' shapes, will be called
+   * Infer all the operators' input and output variables' shapes, will be called
    * before every mini-batch
    */
-  virtual void InferShape(const ScopePtr &scope) const override;
+  void InferShape(const ScopePtr& scope) const override {
+    for (auto& op : ops_) {
+      op->InferShape(scope);
+    }
+  }
 
   /**
    * @brief Run the network.
@@ -123,49 +71,32 @@ class PlainNet : public Net {
    * scope will be used instead. If no OpContext is provicded, default context
    * will be used.
    */
-  virtual void Run(const ScopePtr &scope,
-                   const DeviceContext &ctx) const override;
+  void Run(const ScopePtr& scope,
+           const platform::DeviceContext& dev_ctx) const override {
+    for (auto& op : ops_) {
+      op->Run(scope, dev_ctx);
+    }
+  }
 
   /**
-   * @brief Add an operator to this network.
+   * @brief Add an operator by ptr
    */
-  virtual OpIndex AddOp(const OpProto &def) override;
+  void AddOp(const OperatorPtr& op) override {
+    PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
+    ops_.push_back(op);
+  }
 
-  /**
-   * @brief Add all optimizer operators related into the network.
-   */
-  virtual void AddOptimizerOps(const OpAttrs &attrs) override;
+  void CompleteAddOp() override;
 
-  /**
-   * @brief Add all backward operators related into the network.
-   */
-  virtual void AddBackwardOps() override;
-
-  virtual ~PlainNet() override {}
-
- protected:
-  /**
-   * @brief Build the network.
-   *
-   * Create operators accordding to `def`, will be called by the constructor.
-   */
-  void BuildNet(const NetDesc &def);
-
-  /**
-   * @brief Add an operator into this network.
-   *
-   * Add a operator which is identified as `type` and has attributes described
-   * in `attrs`, the `inputs` are the keys of readonly input variables,
-   * `outputs` are keys of mutable output variables. An `OpIndex` will be
-   * returned to indicate the offset of the new operator in `ops_`.
-   */
-  OpIndex AddOp(const std::string &type, const std::vector<std::string> &inputs,
-                const std::vector<std::string> &outputs,
-                const OpAttrs &attrs = OpAttrs());
+  std::vector<OperatorPtr> ops_;
 
  private:
-  // the operators owned by `Network`.
-  std::vector<Operator> ops_;
+  bool add_op_done_{false};
+
+  template <typename T, typename KeyType>
+  static bool Contains(T container, KeyType key) {
+    return container.find(key) != container.end();
+  }
 };
 
 }  // namespace framework
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f5e1c22400a73c3aa09839ef9654f87def99bc77
--- /dev/null
+++ b/paddle/framework/net_op_test.cc
@@ -0,0 +1,67 @@
+#include <gtest/gtest.h>
+#include <paddle/framework/net.h>
+#include <paddle/framework/op_registry.h>
+#include <paddle/framework/operator.h>
+
+namespace pd = paddle::framework;
+
+static int infer_shape_cnt = 0;
+static int run_cnt = 0;
+
+class TestOp : public pd::OperatorBase {
+ public:
+  void InferShape(const paddle::framework::ScopePtr& scope) const override {
+    ++infer_shape_cnt;
+  }
+  void Run(const paddle::framework::ScopePtr& scope,
+           const paddle::platform::DeviceContext& dev_ctx) const override {
+    ++run_cnt;
+  }
+};
+
+template <typename T>
+void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
+                                  const std::vector<T>& actual) {
+  ASSERT_EQ(expected.size(), actual.size());
+  std::unordered_set<T> expected_set;
+  for (auto& tmp : expected) {
+    expected_set.insert(tmp);
+  }
+  for (auto& act : actual) {
+    ASSERT_NE(expected_set.end(), expected_set.find(act));
+  }
+}
+
+TEST(OpKernel, all) {
+  auto net = std::make_shared<pd::PlainNet>();
+  ASSERT_NE(net, nullptr);
+
+  auto op1 = std::make_shared<TestOp>();
+  op1->inputs_ = {"x", "w1", "b1"};
+  op1->outputs_ = {"y"};
+  net->AddOp(op1);
+
+  auto op2 = std::make_shared<TestOp>();
+  op2->inputs_ = {"y", "w2", "b2"};
+  op2->outputs_ = {"z"};
+  net->AddOp(op2);
+
+  net->CompleteAddOp();
+  AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_);
+  AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_);
+  auto tmp_idx_iter = net->attrs_.find("temporary_index");
+  ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
+  auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
+  ASSERT_EQ(1UL, tmp_idx.size());
+  ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);
+
+  auto scope = std::make_shared<pd::Scope>();
+  paddle::platform::CPUDeviceContext dev_ctx;
+
+  net->InferShape(scope);
+  net->Run(scope, dev_ctx);
+  ASSERT_EQ(2, infer_shape_cnt);
+  ASSERT_EQ(2, run_cnt);
+
+  ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet);
+}
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 3d67541db2022935c152bf355176e1631a7e9741..b627b4a60a728cde6467b63784a5464f7c0f1bdc 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -203,10 +203,9 @@ class OpRegistry {
     //! Create a OpPtr by type.
     std::string op_type = op_desc.type();
     OperatorPtr op(creators().at(op_type)());
-
     //! Fill op's data member. Not use constructor because it will be noising
     //! for Op developer.
-    op->desc_ = op_desc;
+    op->type_ = op_desc.type();
     op->inputs_.reserve((size_t)op_desc.inputs_size());
     std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
               std::back_inserter(op->inputs_));
@@ -239,7 +238,7 @@ class OpRegistry {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
       if (outname == OperatorBase::TMP_VAR_NAME()) {
-        outname += op->Type();
+        outname += op->type_;
         outname += "@";
         outname += std::to_string(gUniqId.fetch_add(1));
       }
@@ -265,12 +264,18 @@ class OpRegisterHelper {
   }
 };
 
+/**
+ * check if MACRO is used in GLOBAL NAMESPACE.
+ */
 #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
   struct __test_global_namespace_##uniq_name##__ {};                          \
   static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                              __test_global_namespace_##uniq_name##__>::value, \
                 msg)
 
+/**
+ * Macro to Register Operator.
+ */
 #define REGISTER_OP(__op_type, __op_class, __op_maker_class)                 \
   STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type,                      \
                                  "REGISTER_OP must be in global namespace"); \
   static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
       __op_register_##__op_type##__(#__op_type);                             \
   int __op_register_##__op_type##_handle__() { return 0; }
 
-#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType)       \
+/**
+ * Macro to Register OperatorKernel.
+ */
+#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType)      \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
-      __reg_op_kernel_##type##_##GPU_OR_CPU##__,                          \
+      __reg_op_kernel_##type##_##DEVICE_TYPE##__,                         \
       "REGISTER_OP_KERNEL must be in global namespace");                  \
   struct __op_kernel_register__##type##__ {                               \
     __op_kernel_register__##type##__() {                                  \
@@ -291,7 +299,7 @@ class OpRegisterHelper {
     }                                                                     \
   };                                                                      \
   static __op_kernel_register__##type##__ __reg_kernel_##type##__;        \
-  int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
+  int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
 
 #define REGISTER_OP_GPU_KERNEL(type, KernelType) \
   REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
 
 #define REGISTER_OP_CPU_KERNEL(type, KernelType) \
   REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
 
+/**
+ * Macro to mark what Operator and Kernel we will use and tell the compiler to
+ * link them into target.
+ */
 #define USE_OP_WITHOUT_KERNEL(op_type) \
   STATIC_ASSERT_GLOBAL_NAMESPACE(      \
       __use_op_without_kernel_##op_type, \
@@ -316,15 +328,16 @@ class OpRegisterHelper {
       __attribute__((unused)) = \
       __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
 
-#ifdef PADDLE_ONLY_CPU
-#define USE_OP(op_type)           \
+// use Operator with only cpu kernel.
+#define USE_OP_CPU(op_type)       \
   USE_OP_WITHOUT_KERNEL(op_type); \
-  USE_OP_KERNEL(op_type, CPU);
+  USE_OP_KERNEL(op_type, CPU)
 
+#ifdef PADDLE_ONLY_CPU
+#define USE_OP(op_type) USE_OP_CPU(op_type)
 #else
-#define USE_OP(op_type)           \
-  USE_OP_WITHOUT_KERNEL(op_type); \
-  USE_OP_KERNEL(op_type, CPU);    \
+#define USE_OP(op_type) \
+  USE_OP_CPU(op_type);  \
   USE_OP_KERNEL(op_type, GPU)
 #endif
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index a467d328e1d57cfb896356ef7d883ade77968dbf..43c9886a8d3c1a51e26bef184ff42154795cff83 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -19,7 +19,7 @@ namespace framework {
 
 std::string OperatorBase::DebugString() const {
   std::stringstream ss;
-  ss << "Op(" << Type() << "), inputs:(";
+  ss << "Op(" << type_ << "), inputs:(";
   for (size_t i = 0; i < inputs_.size(); ++i) {
     ss << inputs_[i];
     if (i != inputs_.size() - 1) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index cc166048b7bf83be3115e6469e5fde36a88b6d6d..6340c9dc960d4fb5eabea55b777261408a072dae 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -69,10 +69,8 @@ class OperatorBase {
   virtual void Run(const ScopePtr& scope,
                    const platform::DeviceContext& dev_ctx) const = 0;
 
-  std::string Type() const { return desc_.type(); }
-
  public:
-  OpDesc desc_;
+  std::string type_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
   AttributeMap attrs_;
@@ -148,7 +146,7 @@ class OperatorWithKernel : public OperatorBase {
 
   void Run(const ScopePtr& scope,
            const platform::DeviceContext& dev_ctx) const final {
-    auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
+    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
     opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
   }
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index d0c3153faef6537cb4260c9d7fe093acb8f839f0..19ac4ecafa21d0a6fde57ef5e867670d7823fde0 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -19,14 +19,18 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-class OperatorTest : public OperatorBase {
+static int op_run_num = 0;
+
+class OpWithoutKernelTest : public OperatorBase {
  public:
   void Init() override { x = 1; }
   void InferShape(const ScopePtr& scope) const override {}
   void Run(const ScopePtr& scope,
            const platform::DeviceContext& dev_ctx) const override {
-    float scale = GetAttr<float>("scale");
-    ASSERT_NEAR(scale, 3.14, 1e-5);
+    op_run_num++;
+    ASSERT_EQ((int)inputs_.size(), 1);
+    ASSERT_EQ((int)outputs_.size(), 1);
+    ASSERT_NEAR(GetAttr<float>("scale"), 3.14, 1e-5);
     ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
     ASSERT_EQ(x, 1);
     ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
@@ -36,15 +40,14 @@ class OperatorTest : public OperatorBase {
   float x = 0;
 };
 
-class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
+class OpWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
-  OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
+  OpWithoutKernelTestProtoAndCheckerMaker(OpProto* proto,
+                                          OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("input", "input of test op");
     AddOutput("output", "output of test op");
-    AddAttr<float>("scale", "scale of cosine op")
-        .SetDefault(1.0)
-        .LargerThan(0.0);
+    AddAttr<float>("scale", "scale of cosine op");
     AddComment("This is test op");
   }
 };
@@ -52,8 +55,8 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_OP(test_operator, paddle::framework::OperatorTest,
-            paddle::framework::OperatorTestProtoAndCheckerMaker);
+REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
+            paddle::framework::OpWithoutKernelTestProtoAndCheckerMaker);
 
 TEST(OperatorBase, all) {
   paddle::framework::OpDesc op_desc;
@@ -63,18 +66,17 @@ TEST(OperatorBase, all) {
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
   attr->set_type(paddle::framework::AttrType::FLOAT);
-  float scale = 3.14;
-  attr->set_f(scale);
+  attr->set_f(3.14);
 
   paddle::platform::CPUDeviceContext device_context;
   auto scope = std::make_shared<paddle::framework::Scope>();
 
   paddle::framework::OperatorPtr op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
-  ASSERT_EQ(op->GetAttr<float>("scale"), scale);
 
   scope->CreateVariable("OUT1");
+  ASSERT_EQ(paddle::framework::op_run_num, 0);
   op->Run(scope, device_context);
-  std::cout << op->DebugString() << std::endl;
+  ASSERT_EQ(paddle::framework::op_run_num, 1);
 }
 
 namespace paddle {
@@ -86,13 +88,13 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("input", "input of test op");
     AddOutput("output", "output of test op");
-    AddAttr<float>("scale", "scale of cosine op")
-        .SetDefault(1.0)
-        .LargerThan(0.0);
+    AddAttr<float>("scale", "scale of cosine op");
     AddComment("This is test op");
   }
 };
 
+static int cpu_kernel_run_num = 0;
+
 class OpWithKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(const std::vector<const Tensor*>& inputs,
@@ -102,10 +104,10 @@ class OpWithKernelTest : public OperatorWithKernel {
 class CPUKernelTest : public OpKernel {
  public:
   void Compute(const KernelContext& context) const {
-    float scale = context.op_.GetAttr<float>("scale");
-    ASSERT_NEAR(scale, 3.14, 1e-5);
-    std::cout << "this is cpu kernel" << std::endl;
-    std::cout << context.op_.DebugString() << std::endl;
+    cpu_kernel_run_num++;
+    ASSERT_EQ((int)context.op_.inputs_.size(), 1);
+    ASSERT_EQ((int)context.op_.outputs_.size(), 1);
+    ASSERT_NEAR(context.op_.GetAttr<float>("scale"), 3.14, 1e-5);
   }
 };
 
@@ -131,5 +133,7 @@ TEST(OpKernel, all) {
   paddle::framework::OperatorPtr op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
 
+  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
   op->Run(scope, cpu_device_context);
+  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
 }
diff --git a/paddle/function/RowConvOpGpu.cu b/paddle/function/RowConvOpGpu.cu
index c0b947e224313abaf4fadfb8293dc78ca085ff84..d9dcc7d59d1e3c222f5a7ce448daa8d7edb6c978 100644
--- a/paddle/function/RowConvOpGpu.cu
+++ b/paddle/function/RowConvOpGpu.cu
@@ -32,7 +32,7 @@ __global__ void KeRowConv(real* y, const real* x, const real* w,
   for (int i = tidy; i < context; i += blky) {
     sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
   }
-  
+
   __syncthreads();
 
   for (int i = 0; i < numSeq; ++i) {
@@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
       int yoff = start + j;
 
       // transpose
-      sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
-      sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0;
+      sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
+        x[yoff * width + xoff] : 0.0;
+      sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ?
+        dy[yoff * width + xoff] : 0.0;
       __syncthreads();
 
       if (tidy < (context - 1)) {
         yoff = yoff - context + 1;
-        sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0;
+        sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ?
+          dy[yoff * width + xoff] : 0.0;
       }
       __syncthreads();
 
@@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
       int yoff = start + j;
 
       // transpose
-      sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
+      sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
+        x[yoff * width + xoff] : 0.0;
       __syncthreads();
 
       for (int t = 0; t < context; t++) {
-        sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
+        sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start &&
+          yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
         __syncthreads();
 
         real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
@@ -239,7 +244,7 @@ __global__ void KeRowConvBwData(real* dx, const real* w, const real* dy,
   for (int i = tidy; i < context; i += blky) {
     sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
   }
-  
+
   __syncthreads();
 
   for (int i = 0; i < numSeq; ++i) {
@@ -312,7 +317,7 @@ void RowConvGrad(const GpuMatrix& outG,
   dim3 dimBlock(32, 32);
   dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
   real* dw = filterG.getData();
-  if (contextLength <= 32) {  
+  if (contextLength <= 32) {
     KeRowConvBwWeight<32, 32, 32>
       <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
       (dw, x, dy, starts, height, width, numSeq, contextLength);
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index ab60f1a38dd4cd1d9799c0019dccae5f1c7d4310..3860facb099950a5287d3f6b89c3de38f588f568 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -155,7 +155,8 @@ RUN apt-get update &&\
     paddle version
 ${DOCKERFILE_CUDNN_DSO}
 ${DOCKERFILE_GPU_ENV}
-
+ADD go/cmd/pserver/pserver /usr/bin/
+ADD go/cmd/master/master /usr/bin/
 # default command shows the paddle version and exit
 CMD ["paddle", "version"]
 EOF
diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp
index b359d9da2167bf459504e15c3140b3d956f417f3..a830ceba5772846cd9255a3eeb26e8d6a17dcfbc 100644
--- a/paddle/trainer/NewRemoteParameterUpdater.cpp
+++ b/paddle/trainer/NewRemoteParameterUpdater.cpp
@@ -28,6 +28,17 @@ NewRemoteParameterUpdater::NewRemoteParameterUpdater(
       newGradients_(nullptr),
       pserverSpec_(pserverSpec) {}
 
+NewRemoteParameterUpdater::NewRemoteParameterUpdater(
+    const OptimizationConfig &config,
+    const std::string pserverSpec,
+    const bool useEtcd)
+    : trainerConfig_(config),
+      parameterClient_(-1),
+      newParameters_(nullptr),
+      newGradients_(nullptr),
+      pserverSpec_(pserverSpec),
+      useEtcd_(useEtcd) {}
+
 void NewRemoteParameterUpdater::init(
     const std::vector<ParameterPtr> &parameters) {
   ParameterUpdater::init(parameters);
@@ -38,8 +49,13 @@ void NewRemoteParameterUpdater::init(
   }
 
   // create parameter server client.
-  parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
-                                               FLAGS_trainer_id == 0);
+  if (useEtcd_) {
+    parameterClient_ = paddle_new_etcd_pserver_client(
+        (char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0);
+  } else {
+    parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
+                                                 FLAGS_trainer_id == 0);
+  }
 
   // init new parameter and gradient.
   newParameters_ = initNewParameter(PARAMETER_VALUE);
diff --git a/paddle/trainer/NewRemoteParameterUpdater.h b/paddle/trainer/NewRemoteParameterUpdater.h
index dfed00bc216b1d41bb7520619b76702f9fe650f2..6223ba427c9b94494c2bee8f0847442f1b0574c9 100644
--- a/paddle/trainer/NewRemoteParameterUpdater.h
+++ b/paddle/trainer/NewRemoteParameterUpdater.h
@@ -32,6 +32,9 @@ class NewRemoteParameterUpdater : public ParameterUpdater {
 public:
   NewRemoteParameterUpdater(const OptimizationConfig& config,
                             const std::string pserverSpec);
+  NewRemoteParameterUpdater(const OptimizationConfig& config,
+                            const std::string pserverSpec,
+                            const bool useEtcd);
   ~NewRemoteParameterUpdater() {
     releaseNewParameter(newParameters_);
     releaseNewParameter(newGradients_);
@@ -111,6 +114,8 @@ protected:
   paddle_parameter** newGradients_;
   /// the specification of parameter server "host1:port,host1:port"
   std::string pserverSpec_;
+  /// true if pserverSpec_ is etcd endpoint, else pserverSpec_ is pserver addr
+  bool useEtcd_;
 };
 
 }  // namespace paddle
diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py
index 4a2eb59c340f5d0d3818170e56d730330e0bab29..645f3cc0dce70752c20569523e4bab440861f6a1 100644
--- a/python/paddle/v2/dataset/common.py
+++ b/python/paddle/v2/dataset/common.py
@@ -22,6 +22,8 @@ import importlib
 import paddle.v2.dataset
 import cPickle
 import glob
+import cPickle as pickle
+import random
 
 __all__ = [
     'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
@@ -170,8 +172,6 @@ def convert(output_path,
             name_prefix,
             max_lines_to_shuffle=1000):
     import recordio
-    import cPickle as pickle
-    import random
     """
     Convert data from reader to recordio format files.
@@ -201,8 +201,10 @@ def convert(output_path,
     def write_data(w, lines):
         random.shuffle(lines)
         for i, d in enumerate(lines):
-            d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
-            w[i % num_shards].write(d)
+            # FIXME(Yancey1989):
+            # dumps with protocol: pickle.HIGHEST_PROTOCOL
+            o = pickle.dumps(d)
+            w[i % num_shards].write(o)
 
     w = open_writers()
     lines = []
diff --git a/python/paddle/v2/dataset/mq2007.py b/python/paddle/v2/dataset/mq2007.py
index fd71b341662ca6f540ce44a86348e782561a97d7..cffb319ad8f56ccddba3fef63e1b6ec68e5bac1e 100644
--- a/python/paddle/v2/dataset/mq2007.py
+++ b/python/paddle/v2/dataset/mq2007.py
@@ -212,19 +212,19 @@ def gen_pair(querylist, partial_order="full"):
         for j in range(i + 1, len(querylist)):
             query_right = querylist[j]
             if query_left.relevance_score > query_right.relevance_score:
-                labels.append(1)
+                labels.append([1])
                 docpairs.append([
                     np.array(query_left.feature_vector),
                     np.array(query_right.feature_vector)
                 ])
             elif query_left.relevance_score < query_right.relevance_score:
-                labels.append(1)
+                labels.append([1])
                 docpairs.append([
                     np.array(query_right.feature_vector),
                     np.array(query_left.feature_vector)
                 ])
     for label, pair in zip(labels, docpairs):
-        yield label, pair[0], pair[1]
+        yield np.array(label), pair[0], pair[1]
 
 
 def gen_list(querylist):
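Note: the write_data change above drops pickle.HIGHEST_PROTOCOL (see the FIXME). A small round-trip sketch of the serialization contract between convert() here and cloud_reader in test_train.py; pickle.loads() detects the protocol from the record itself, so the protocol choice only affects record size and speed, not compatibility. The sample record is illustrative:

    import cPickle as pickle

    sample = ([0.1, 0.2, 0.3], 1)  # illustrative (features, label) record
    for proto in (0, pickle.HIGHEST_PROTOCOL):
        record = pickle.dumps(sample, proto)
        assert pickle.loads(record) == sample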
""" - def __init__(self, addr, buf_size): - self.c = lib.paddle_new_master_client(addr, buf_size) + def __init__(self, etcd_endpoints, timeout, buf_size): + self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout, + buf_size) def close(self): lib.paddle_release_master_client(self.c) diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index 173a30a41181cd0f600f5626d9c5d1858b455fad..260a5094695d4f2318be5ca02c9b3c8f947372a7 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -1,3 +1,4 @@ +import py_paddle.swig_paddle as swig_api import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils import paddle.trainer_config_helpers.optimizers as v1_optimizers """ @@ -16,7 +17,6 @@ __all__ = [ class Optimizer(object): def __init__(self, **kwargs): - import py_paddle.swig_paddle as swig_api if 'batch_size' in kwargs: del kwargs['batch_size'] # not important for python library. @@ -48,12 +48,12 @@ class Optimizer(object): return swig_api.ParameterUpdater.createRemoteUpdater( self.__opt_conf__, pass_num, use_sparse_updater) - def __create_new_remote_updater__(self, pserver_spec): + def __create_new_remote_updater__(self, pserver_spec, use_etcd): return swig_api.ParameterUpdater.createNewRemoteUpdater( - self.__opt_conf__, pserver_spec) + self.__opt_conf__, pserver_spec, use_etcd) def create_updater(self, is_local, num_passes, use_sparse_updater, - pserver_spec): + pserver_spec, use_etcd): """ create proper parameter_updater by configuration. :param is_local: create local or remote parameter updater @@ -79,7 +79,7 @@ class Optimizer(object): num_passes, use_sparse_updater) else: parameter_updater = self.__create_new_remote_updater__( - pserver_spec) + pserver_spec, use_etcd) return parameter_updater diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 92fdf98e9030993cc9f250b2f9e6317073cb49de..76bae0bb12b6c33f88530386f9cc19ae9b59f457 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -45,7 +45,8 @@ class SGD(object): update_equation, extra_layers=None, is_local=True, - pserver_spec=None): + pserver_spec=None, + use_etcd=True): if not isinstance(parameters, v2_parameters.Parameters): raise TypeError('parameters should be parameters') @@ -61,6 +62,7 @@ class SGD(object): self.__topology_in_proto__ = topology.proto() self.__is_local__ = is_local self.__pserver_spec__ = pserver_spec + self.__use_etcd__ = use_etcd self.__use_sparse_updater__ = self.__topology__.use_sparse_updater() # # In local mode, disable sparse_remote_update. @@ -127,7 +129,7 @@ class SGD(object): self.__parameter_updater__ = self.__optimizer__.create_updater( self.__is_local__, num_passes, self.__use_sparse_updater__, - self.__pserver_spec__) + self.__pserver_spec__, self.__use_etcd__) self.__parameter_updater__.init(self.__gradient_machine__) self.__gradient_machine__.start()