提交 317eb0aa 编写于 作者: D dongdaxiang

add incubate for unified API

上级 39449ba0
...@@ -38,10 +38,9 @@ std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL; ...@@ -38,10 +38,9 @@ std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false; bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
template<class AR> template <class AR>
paddle::ps::Archive<AR>& operator << ( paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
paddle::ps::Archive<AR>& ar, const MultiSlotType& ins) {
const MultiSlotType& ins) {
ar << ins.GetType(); ar << ins.GetType();
ar << ins.GetOffset(); ar << ins.GetOffset();
ar << ins.GetFloatData(); ar << ins.GetFloatData();
...@@ -49,10 +48,9 @@ paddle::ps::Archive<AR>& operator << ( ...@@ -49,10 +48,9 @@ paddle::ps::Archive<AR>& operator << (
return ar; return ar;
} }
template<class AR> template <class AR>
paddle::ps::Archive<AR>& operator >> ( paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
paddle::ps::Archive<AR>& ar, MultiSlotType& ins) {
MultiSlotType& ins) {
ar >> ins.MutableType(); ar >> ins.MutableType();
ar >> ins.MutableOffset(); ar >> ins.MutableOffset();
ar >> ins.MutableFloatData(); ar >> ins.MutableFloatData();
...@@ -205,6 +203,10 @@ void FleetWrapper::PullDenseVarsSync( ...@@ -205,6 +203,10 @@ void FleetWrapper::PullDenseVarsSync(
#endif #endif
} }
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
void FleetWrapper::PushDenseVarsAsync( void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id, const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
...@@ -324,8 +326,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() { ...@@ -324,8 +326,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
clock_gettime(CLOCK_REALTIME, &tp); clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9; double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0); static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
(uint64_t)(cur_time * 1000)};
engine.seed(sseq); engine.seed(sseq);
} }
}; };
...@@ -333,7 +334,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() { ...@@ -333,7 +334,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
return r.engine; return r.engine;
} }
template<typename T> template <typename T>
void FleetWrapper::Serialize(const T& t, std::string* str) { void FleetWrapper::Serialize(const T& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar; paddle::ps::BinaryArchive ar;
...@@ -344,7 +345,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) { ...@@ -344,7 +345,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
#endif #endif
} }
template<typename T> template <typename T>
void FleetWrapper::Deserialize(T* t, const std::string& str) { void FleetWrapper::Deserialize(T* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar; paddle::ps::BinaryArchive ar;
...@@ -357,8 +358,8 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) { ...@@ -357,8 +358,8 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) {
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>( template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string*); const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize( template void FleetWrapper::Deserialize(std::vector<MultiSlotType>*,
std::vector<MultiSlotType>*, const std::string&); const std::string&);
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -16,12 +16,12 @@ limitations under the License. */ ...@@ -16,12 +16,12 @@ limitations under the License. */
#include <memory> #include <memory>
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#include <archive.h> #include <archive.h>
#include <pslib.h>
#endif #endif
#include <random>
#include <atomic> #include <atomic>
#include <ctime> #include <ctime>
#include <random>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -79,6 +79,9 @@ class FleetWrapper { ...@@ -79,6 +79,9 @@ class FleetWrapper {
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status); std::vector<::std::future<int32_t>>* push_sparse_status);
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in Async mode // Push sparse variables with labels to server in Async mode
// This is specially designed for click/show stats in server // This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names, // Param<in>: scope, table_id, var_grad_names,
...@@ -121,9 +124,9 @@ class FleetWrapper { ...@@ -121,9 +124,9 @@ class FleetWrapper {
const std::string& msg); const std::string& msg);
std::default_random_engine& LocalRandomEngine(); std::default_random_engine& LocalRandomEngine();
template<typename T> template <typename T>
void Serialize(const T& t, std::string* str); void Serialize(const T& t, std::string* str);
template<typename T> template <typename T>
void Deserialize(T* t, const std::string& str); void Deserialize(T* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
......
...@@ -43,6 +43,7 @@ namespace pybind { ...@@ -43,6 +43,7 @@ namespace pybind {
void BindFleetWrapper(py::module* m) { void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet") py::class_<framework::FleetWrapper>(*m, "Fleet")
.def(py::init()) .def(py::init())
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer) .def("init_server", &framework::FleetWrapper::InitServer)
.def("init_worker", &framework::FleetWrapper::InitWorker) .def("init_worker", &framework::FleetWrapper::InitWorker)
.def("stop_server", &framework::FleetWrapper::StopServer) .def("stop_server", &framework::FleetWrapper::StopServer)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# incubate directory is mainly for internal use
# after we have tested incubate APIs in industrial application for a period
# we will move stable functions into fluid
__version__ = '0.1.0'
...@@ -142,4 +142,4 @@ class DistributedOptimizer(paddle.fluid.Optimizer): ...@@ -142,4 +142,4 @@ class DistributedOptimizer(paddle.fluid.Optimizer):
no_grad_set) no_grad_set)
fleet_instance._set_opt_info(opt_info) fleet_instance._set_opt_info(opt_info)
return [a, b] return [optimize_ops, param_grads]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册