提交 317eb0aa 编写于 作者: D dongdaxiang

add incubate for unified API

上级 39449ba0
......@@ -38,9 +38,8 @@ std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
#ifdef PADDLE_WITH_PSLIB
template<class AR>
paddle::ps::Archive<AR>& operator << (
paddle::ps::Archive<AR>& ar,
template <class AR>
paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
ar << ins.GetOffset();
......@@ -49,9 +48,8 @@ paddle::ps::Archive<AR>& operator << (
return ar;
}
template<class AR>
paddle::ps::Archive<AR>& operator >> (
paddle::ps::Archive<AR>& ar,
template <class AR>
paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
ar >> ins.MutableOffset();
......@@ -205,6 +203,10 @@ void FleetWrapper::PullDenseVarsSync(
#endif
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
......@@ -324,8 +326,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++,
(uint64_t)(cur_time * 1000)};
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
};
......@@ -333,7 +334,7 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
return r.engine;
}
template<typename T>
template <typename T>
void FleetWrapper::Serialize(const T& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
......@@ -344,7 +345,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
#endif
}
template<typename T>
template <typename T>
void FleetWrapper::Deserialize(T* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
......@@ -357,8 +358,8 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) {
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize(
std::vector<MultiSlotType>*, const std::string&);
template void FleetWrapper::Deserialize(std::vector<MultiSlotType>*,
const std::string&);
} // end namespace framework
} // end namespace paddle
......@@ -16,12 +16,12 @@ limitations under the License. */
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#include <archive.h>
#include <pslib.h>
#endif
#include <random>
#include <atomic>
#include <ctime>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
......@@ -79,6 +79,9 @@ class FleetWrapper {
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* push_sparse_status);
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in Async mode
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad_names,
......@@ -121,9 +124,9 @@ class FleetWrapper {
const std::string& msg);
std::default_random_engine& LocalRandomEngine();
template<typename T>
template <typename T>
void Serialize(const T& t, std::string* str);
template<typename T>
template <typename T>
void Deserialize(T* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
......
......@@ -43,6 +43,7 @@ namespace pybind {
void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet")
.def(py::init())
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("stop_server", &framework::FleetWrapper::StopServer)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# incubate directory is mainly for internal use
# after we have tested incubate APIs in industrial application for a period
# we will move stable functions into fluid
__version__ = '0.1.0'
......@@ -142,4 +142,4 @@ class DistributedOptimizer(paddle.fluid.Optimizer):
no_grad_set)
fleet_instance._set_opt_info(opt_info)
return [a, b]
return [optimize_ops, param_grads]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册