Unverified commit 284bae99, authored by Chengmo, committed by GitHub

【Cherry-Pick】Fix device_context & Save Tensor & Gloo (#30336)

* Fix server.h include device_context (#30243)

* fix cmake
Co-authored-by: seiriosPlus <tangwei12@baidu.com>

* 【Paddle.Fleet】Support local save sparse param (#30175)

* add save tensor support
Co-authored-by: seiriosPlus <tangwei12@baidu.com>

* add sparse embedding & load vars for 2.0 & gloo bug fix (#30306)

* add sparse embedding & load vars for 2.0

Change-Id: I36b59ed5f015189dc9d9d2e34a9357722d369f1b

* fix hdfs gloo

Change-Id: Ia84d579053720ad804183e54c9a04b4f031c79c6

* fix gloo hdfs

Change-Id: I5ab982fd483cddc10adcdef0b8aa83aca976cb9e

* move loadvar/sparse embedding from incubate to static

Change-Id: I57081d3545ad2efab78c72420d2162c0eacaf3a0
Co-authored-by: tangwei12 <tangwei12@baidu.com>
Parent commit: df67b317
@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
  }
}

+void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
+                                    const std::string& path) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret = communicator->_worker_ptr->recv_and_save_table(table_id, path);
+  if (ret != 0) {
+    LOG(ERROR) << "save model of table id: " << table_id
+               << ", to path: " << path << " failed";
+  }
+}
+
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
  auto* communicator = Communicator::GetInstance();
  auto ret = communicator->_worker_ptr->print_table_stat(table_id);
......
@@ -198,6 +198,10 @@ class FleetWrapper {
  // mode = 1, save delta feature, which means save diff
  void SaveModelOneTable(const uint64_t table_id, const std::string& path,
                         const int mode);
+
+  // recv table from server and save it in LodTensor
+  void RecvAndSaveTable(const uint64_t table_id, const std::string& path);
+
  // clear all models, release their memory
  void ClearModel();
  // clear one table
......
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
-set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
+set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
brpc_library(sendrecv_rpc SRCS
    ${BRPC_SRCS}
......
@@ -14,6 +14,7 @@
#include <algorithm>
#include <memory>
+#include <sstream>
#include <string>
#include <vector>
@@ -21,6 +22,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/string/string_helper.h"
const static int max_port = 65535;
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
+
+namespace paddle {
+namespace framework {
+class Scope;
+class Variable;
+}  // namespace framework
+namespace platform {
+class DeviceContext;
+}  // namespace platform
+}  // namespace paddle
+
namespace paddle {
namespace distributed {
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
  return fut;
}
+int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
+                                          const std::string &path) {
+  // get var information
+  std::string var_name = "";
+  int64_t var_num = 0;
+  int64_t var_shape = 0;
+  const auto &worker_param = _config.worker_param().downpour_worker_param();
+  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
+    if (worker_param.downpour_table_param(i).table_id() == table_id) {
+      var_name = worker_param.downpour_table_param(i).common().table_name();
+      var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
+      var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
+      break;
+    }
+  }
+
+  PADDLE_ENFORCE_NE(
+      var_name, "",
+      platform::errors::InvalidArgument(
+          "Cannot find table id %d to save variables.", table_id));
+
+  std::string var_store = string::Sprintf("%s", path);
+  MkDirRecursively(var_store.c_str());
+
+  // pull sparse from server
+  std::vector<float> save_huge_vec(var_num * var_shape);
+  std::vector<uint64_t> save_key(var_num);
+  std::vector<float *> save_vec;
+  for (size_t i = 0; i < save_key.size(); ++i) {
+    save_key[i] = i;
+    save_vec.push_back(save_huge_vec.data() + i * var_shape);
+  }
+
+  auto status = pull_sparse((float **)save_vec.data(), table_id,
+                            save_key.data(), save_key.size());
+  status.wait();
+
+  // create lod tensor
+  std::shared_ptr<framework::Scope> scope;
+  scope.reset(new framework::Scope());
+  auto place = platform::CPUPlace();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+
+  framework::Variable *var = scope->Var(var_name);
+  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
+
+  std::vector<int64_t> vec_dim = {var_num, var_shape};
+  var_tensor->Resize(framework::make_ddim(vec_dim));
+
+  // copy and save
+  float *tensor_data = var_tensor->mutable_data<float>(place);
+  memcpy(tensor_data, save_huge_vec.data(),
+         var_num * var_shape * sizeof(float));
+
+  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
+  std::ofstream fout(file_name, std::ios::binary);
+  PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
+                    platform::errors::Unavailable(
+                        "Cannot open %s to save variables.", file_name));
+
+  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
+  fout.close();
+
+  return 0;
+}
}  // namespace distributed
}  // namespace paddle
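The file written by recv_and_save_table is a single LoDTensor serialized with SerializeToStream and named after the table's variable, so it should be readable through the standard static-graph load path. A rough Python sketch under that assumption; the variable name, shape, and directory below are placeholders, not values taken from this patch:

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
startup = paddle.static.default_startup_program()
main = paddle.static.default_main_program()

with paddle.static.program_guard(main, startup):
    # A persistable variable whose name and shape must match the saved table
    # (fea_dim x embedx_dim); both values here are illustrative.
    emb = paddle.static.create_parameter(
        shape=[1000, 8], dtype="float32", name="embedding_0.w_0")
exe.run(startup)

# load_vars reads "<dirname>/<var name>", the same layout recv_and_save_table writes.
paddle.static.load_vars(executor=exe, dirname="./saved_sparse", vars=[emb])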
@@ -22,6 +22,9 @@
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
  virtual std::future<int32_t> send_client2client_msg(
      int msg_type, int to_client_id, const std::string &msg) override;
+
+  // for local save sparse
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path);
 private:
  virtual int32_t initialize() override;
......
@@ -134,6 +134,11 @@ class PSClient {
  virtual std::future<int32_t> push_global_step(int table_id,
                                                int64_t *total_send_data,
                                                void *done) = 0;
+
+  // recv table from server and save it in LodTensor
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path) = 0;
+
  virtual void finalize_worker() = 0;
  // client to client, message sending
  virtual std::future<int32_t> send_client2client_msg(int msg_type,
......
@@ -21,6 +21,7 @@
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
+#define PSERVER_SAVE_SUFFIX "_txt"
namespace paddle {
namespace distributed {
@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
  VLOG(0) << "sparse table save: " << dirname << " mode: " << mode;
  auto varname = _config.common().table_name();
-  std::string var_store = string::Sprintf("%s/%s", dirname, varname);
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
  MkDirRecursively(var_store.c_str());
  VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
......
@@ -229,18 +229,18 @@ void ParallelConnectContext::connectFullMesh(
  store.wait({key}, getTimeout());
  std::vector<char> allAddrs;
-  auto max_retry_times = 5;
+  auto max_retry_times = 10;
  // Connect to other side of this pair
  while (max_retry_times > 0) {
    allAddrs = store.get(key);
    VLOG(3) << "store get all address size: " << allAddrs.size()
            << " except: " << total_add_size;
    if (allAddrs.size() == static_cast<size_t>(total_add_size)) {
      break;
    }
+    sleep(5);
    --max_retry_times;
  }
@@ -272,11 +272,13 @@ void GlooWrapper::Init() {
  attr.iface = iface_;
  std::shared_ptr<gloo::rendezvous::HdfsStore> file_store = nullptr;
  std::shared_ptr<gloo::rendezvous::HTTPStore> http_store = nullptr;
-  auto context = std::make_shared<gloo::rendezvous::Context>(rank_, size_);
-  context->setTimeout(run_timeout_);
  auto dev = gloo::transport::tcp::CreateDevice(attr);
  switch (store_type_) {
    case GlooStoreType::HDFS: {
+      auto context = std::make_shared<gloo::rendezvous::ParallelConnectContext>(
+          rank_, size_);
+      context->setTimeout(run_timeout_);
      std::string cmd = std::string("${HADOOP_HOME}/bin/hadoop fs");
      cmd += " -D fs.default.name=" + hdfs_name_;
      cmd += " -D hadoop.job.ugi=" + hdfs_ugi_;
@@ -286,22 +288,25 @@ void GlooWrapper::Init() {
      auto prefix_store =
          std::make_shared<gloo::rendezvous::PrefixStore>(prefix_, *file_store);
      context->connectFullMesh(*prefix_store, dev);
+      context_ = std::move(context);
      break;
    }
    case GlooStoreType::HTTP: {
+      auto context = std::make_shared<gloo::rendezvous::Context>(rank_, size_);
+      context->setTimeout(run_timeout_);
      http_store = std::make_shared<gloo::rendezvous::HTTPStore>(
          http_ip_, http_port_, prefix_ + "_" + http_scope_, rank_);
      http_store->SetTimeoutSeconds(init_timeout_.count());
      context->connectFullMesh(*http_store, dev);
      http_store->Finalize();
      VLOG(3) << "after calling http_store->Finalize.";
+      context_ = std::move(context);
      break;
    }
    default:
      LOG(ERROR) << "unknown store type " << store_type_;
      exit(-1);
  }
-  context_ = std::move(context);
#endif
  is_initialized_ = true;
  VLOG(3) << "gloo initialized done.";
......
@@ -2,7 +2,7 @@ include(operators)
set(DISTRIBUTE_DEPS "")
-list(APPEND DISTRIBUTE_DEPS fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy)
+list(APPEND DISTRIBUTE_DEPS fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
      .def("pull_dense_params", &FleetWrapper::PullDenseVarsSync)
      .def("save_all_model", &FleetWrapper::SaveModel)
      .def("save_one_model", &FleetWrapper::SaveModelOneTable)
+      .def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable)
      .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
      .def("stop_server", &FleetWrapper::StopServer)
      .def("stop_worker", &FleetWrapper::FinalizeWorker)
......
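A minimal sketch of how this new binding is reached from Python, mirroring the _save_sparse_params change later in this diff; `worker` stands for the pybind DistFleetWrapper instance held by the runtime, and the table id and directory are placeholders:

# worker: the pybind-wrapped FleetWrapper object (illustrative name)
table_id = 0                  # id of the sparse table to pull from the servers
dirname = "./sparse_local"    # local directory that will receive the tensor file
worker.recv_and_save_model(table_id, dirname)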
@@ -545,7 +545,7 @@ class Fleet(object):
            executor, dirname, feeded_var_names, target_vars, main_program,
            export_for_deployment)

-    def save_persistables(self, executor, dirname, main_program=None, mode=1):
+    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """
        saves all persistable tensors from :code:`main_program` to
......
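For context, a hedged trainer-side sketch of the call whose default changed here; it assumes fleet has already been initialized in parameter-server mode and training has finished:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
# ... fleet.init(...), program construction, and training happen before this ...
exe = paddle.static.Executor(paddle.CPUPlace())

# mode now defaults to 0; pass mode=1 explicitly to keep the previous default behaviour.
fleet.save_persistables(exe, "./persistables_dir")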
@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
        _main = compiled_config.origin_main_program.clone()
        _startup = compiled_config.origin_startup_program.clone()

-        if not compiled_config.is_geo_mode():
-            from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
-            _add_lr_decay_table_pass(
-                _main, compiled_config,
-                self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
+        _add_lr_decay_table_pass(
+            _main, compiled_config,
+            self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
+        if not compiled_config.is_geo_mode():
            # for main program
            _main = worker.delete_optimizer_pass(_main, compiled_config)
            _main = worker.distributed_ops_pass(_main, compiled_config)
......
@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
        return is_valid

-    def _save_sparse_params(self, executor, dirname, context, main_program):
+    def _save_sparse_params(self, executor, dirname, context, main_program,
+                            mode):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
+        distributed_varnames = get_sparse_tablenames(
+            self.compiled_strategy.origin_main_program, True)
        values = []
        for id, names in context.items():
+            if names not in distributed_varnames:
+                # only save sparse param to local
+                self._worker.recv_and_save_model(id, dirname)
+            # save sparse & distributed param on server
+            self._worker.save_one_model(id, dirname, mode)
            values.extend(names)
-            self._worker.save_one_model(id, dirname, 0)
        return values

-    def _save_distributed_persistables(self, executor, dirname, main_program,
-                                       mode):
+    def _save_distributed_persistables(self,
+                                       executor,
+                                       dirname,
+                                       main_program,
+                                       mode=0):
        denses = self.compiled_strategy.get_the_one_recv_context(
            is_dense=True,
@@ -870,14 +881,14 @@ class TheOnePSRuntime(RuntimeBase):
            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            use_origin_program=True)

-        recv_sparse_varnames = self._save_sparse_params(executor, dirname,
-                                                        sparses, main_program)
+        sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
+                                                   main_program, mode)

        recv_dense_varnames = []
        for id, names in denses.items():
            recv_dense_varnames.extend(names)

-        saved_varnames = recv_sparse_varnames
+        saved_varnames = sparse_varnames

        remaining_vars = list(
            filter(
@@ -925,6 +936,7 @@ class TheOnePSRuntime(RuntimeBase):
                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

+        # Todo(MrChengmo): Save optimizer status
        self._save_distributed_persistables(executor, dirname, main_program,
                                            mode)
@@ -971,8 +983,7 @@ class TheOnePSRuntime(RuntimeBase):
        program = Program.parse_from_string(program_desc_str)
        program._copy_dist_param_info_from(fluid.default_main_program())
-        self._ps_inference_save_persistables(
-            executor, dirname, program, mode=0)
+        self._ps_inference_save_persistables(executor, dirname, program)

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)
......
@@ -976,7 +976,7 @@ def sparse_embedding(input,
                'fluid.contrib.layers.sparse_embedding')
    check_dtype(dtype, 'dtype', ['float32'],
-                'fluid.contrib.layers.sparse_embedding')
+                'paddle.static.nn.sparse_embedding')
    w = helper.create_parameter(
        attr=helper.param_attr,
......
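Since this series also re-exports the layer as paddle.static.nn.sparse_embedding (see the static/nn/__init__.py hunk below), a construction-only sketch of the 2.0 spelling; the input shape and table size are placeholders and the layer is intended for parameter-server training:

import paddle

paddle.enable_static()
ids = paddle.static.data(name="word_ids", shape=[None, 1], dtype="int64")
# size = [vocab_size, embedding_dim]; dtype must be float32 per the check above.
emb = paddle.static.nn.sparse_embedding(
    input=ids, size=[1000000, 64], dtype="float32")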
@@ -14,13 +14,37 @@
# TODO: import framework api under this directory
__all__ = [
-    'append_backward', 'gradients', 'Executor', 'global_scope', 'scope_guard',
-    'BuildStrategy', 'CompiledProgram', 'Print', 'py_func', 'ExecutionStrategy',
-    'name_scope', 'ParallelExecutor', 'program_guard', 'WeightNormParamAttr',
-    'default_main_program', 'default_startup_program', 'Program', 'data',
-    'InputSpec', 'save', 'load', 'save_inference_model', 'load_inference_model',
-    'load_program_state', 'set_program_state', 'cpu_places', 'cuda_places',
-    'xpu_places', 'Variable'
+    'append_backward',
+    'gradients',
+    'Executor',
+    'global_scope',
+    'scope_guard',
+    'BuildStrategy',
+    'CompiledProgram',
+    'Print',
+    'py_func',
+    'ExecutionStrategy',
+    'name_scope',
+    'ParallelExecutor',
+    'program_guard',
+    'WeightNormParamAttr',
+    'default_main_program',
+    'default_startup_program',
+    'Program',
+    'data',
+    'InputSpec',
+    'save',
+    'load',
+    'save_inference_model',
+    'load_inference_model',
+    'load_program_state',
+    'set_program_state',
+    'cpu_places',
+    'cuda_places',
+    'xpu_places',
+    'Variable',
+    'load_vars',
+    'save_vars',
]

from . import nn
@@ -61,5 +85,9 @@ from ..fluid.io import save #DEFINE_ALIAS
from ..fluid.io import load #DEFINE_ALIAS
from ..fluid.io import load_program_state #DEFINE_ALIAS
from ..fluid.io import set_program_state #DEFINE_ALIAS
+from ..fluid.io import load_vars #DEFINE_ALIAS
+from ..fluid.io import save_vars #DEFINE_ALIAS
from ..fluid.layers import create_parameter #DEFINE_ALIAS
from ..fluid.layers import create_global_var #DEFINE_ALIAS
@@ -38,6 +38,7 @@ __all__ = [
    'spectral_norm',
    'switch_case',
    'while_loop',
+    'sparse_embedding',
]

from .common import fc #DEFINE_ALIAS
@@ -67,3 +68,4 @@ from ...fluid.layers import switch_case #DEFINE_ALIAS
from ...fluid.layers import while_loop #DEFINE_ALIAS
from ...fluid.input import embedding #DEFINE_ALIAS
+from ...fluid.contrib.layers import sparse_embedding #DEFINE_ALIAS