Commit 0c1d5408 authored by P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

...@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
  return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save path " << epoch;
...@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
  return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
if (save_context.table_id < 0) {
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
}
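(Review note: a minimal caller-side sketch of the new context-driven entry point; the client pointer and values are hypothetical, and the returned std::future carries the command's return code.)
// Hypothetical caller holding an initialized BrpcPsClient *client.
LoadSaveContext load_ctx;
load_ctx.table_id = -1;  // a negative id loads every table
load_ctx.epoch = "0";    // checkpoint path/epoch, as in load()
load_ctx.mode = "0";
auto status = client->Load(load_ctx);
CHECK_EQ(status.get(), 0);  // blocks until the RPC completes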
std::future<int32_t> BrpcPsClient::clear() {
  return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
...@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
  return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
  if (pull_context.value_type == Dense) {  // pull dense
    Region *dense_region =
        reinterpret_cast<Region *>(pull_context.dense_values);
    return pull_dense(dense_region, pull_context.num, pull_context.table);
  } else {  // pull sparse
    uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
    float **select_values =
        reinterpret_cast<float **>(pull_context.sparse_values);
    size_t table_id = pull_context.table;
    size_t num = pull_context.num;
    bool is_training = pull_context.is_training;
    if (pull_context.training_mode == Geo) {  // for geo
      return pull_sparse_param(select_values, table_id, keys, num,
                               is_training);
    } else {  // for async (also used as the default path)
      return pull_sparse(select_values, table_id, keys, num, is_training);
    }
  }
}
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
  if (push_context.value_type == Dense) {  // push dense
    const Region *dense_region = push_context.push_context.push_dense_values;
    return push_dense(dense_region, push_context.num, push_context.table);
  } else {  // push sparse
    size_t table_id = push_context.table;
    size_t num = push_context.num;
    if (push_context.training_mode == Geo) {  // for geo
      // TODO(zhaocaibei): geo sparse push is not implemented here yet
    } else if (push_context.training_mode == Async) {  // for async
      const uint64_t *keys = push_context.push_context.keys;
      const float **update_values = push_context.push_context.push_values;
      return push_sparse(table_id, keys, update_values, num);
    }
  }
}
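(Review note: a hedged sketch of driving the unified Push on the async sparse path; keys, grads, and key_count are hypothetical inputs with the types PushContext expects.)
RequestContext push_ctx;
push_ctx.value_type = Sparse;
push_ctx.training_mode = Async;
push_ctx.table = 1;                         // hypothetical table id
push_ctx.push_context.keys = keys;          // const uint64_t*
push_ctx.push_context.push_values = grads;  // const float**, one row per key
push_ctx.num = key_count;
push_ctx.is_training = true;
client->Push(push_ctx);  // forwards to push_sparse(table, keys, grads, num)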
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
                                                  std::vector<float> *values,
                                                  std::vector<uint64_t> *keys,
......
...@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
  std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
  std::future<int32_t> save(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
  std::future<int32_t> clear() override;
  std::future<int32_t> clear(uint32_t table_id) override;
...@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
                                            const uint64_t *keys,
                                            size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> Push(RequestContext &push_context) override;
  virtual std::future<int32_t> print_table_stat(uint32_t table_id);
  virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
......
...@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
    _server.Join();
    return 0;
  }
  int32_t port();

 private:
  virtual int32_t initialize();
......
...@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
    _server.Join();
    return 0;
  }
  int32_t port();
  std::condition_variable *export_cv() { return &cv_; }
......
...@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure { ...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
}; };
struct LoadSaveContext {
  int table_id;  // a negative id selects all tables
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
  TrainingMode training_mode;    // Async = 0, Sync = 1, Geo = 3
  TrainingPhase training_phase;  // Init = 0, Train = 1, Save = 2
  ValueType value_type;          // Sparse = 0, Dense = 1
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
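(Review note: a minimal sketch of a dense pull through RequestContext; client, regions, and region_num are hypothetical.)
RequestContext pull_ctx;
pull_ctx.value_type = Dense;
pull_ctx.training_mode = Async;
pull_ctx.table = 0;               // hypothetical table id
pull_ctx.dense_values = regions;  // Region* buffer the pull fills
pull_ctx.num = region_num;
pull_ctx.is_training = false;
client->Pull(pull_ctx);  // forwards to pull_dense(regions, num, table)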
class PSClient {
 public:
  PSClient() {}
...@@ -86,6 +122,9 @@ class PSClient {
  // load data for a specified table
  virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
                                    const std::string &mode) = 0;
  // load options are configured via the context
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
  // save all table data; depending on mode, the value_accessor may apply
  // different save conditions
  virtual std::future<int32_t> save(const std::string &epoch,
                                    const std::string &mode) = 0;
...@@ -93,6 +132,8 @@ class PSClient {
  virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                                    const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
  // clear table data
  virtual std::future<int32_t> clear() = 0;
  virtual std::future<int32_t> clear(uint32_t table_id) = 0;
...@@ -107,6 +148,8 @@ class PSClient {
  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
                                          size_t table_id) = 0;  // reserved
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
  // firstly push dense param for parameter server
  // this is necessary because dense weight initialized in trainer on cold
  // start
...@@ -117,6 +160,9 @@ class PSClient {
  virtual std::future<int32_t> push_dense(const Region *regions,
                                          size_t region_num,
                                          size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
  // pull request using keys; results are filled into values
  // there are num keys and num values; each value occupies select_size bytes
  // the keys and values buffers must not be reused before the future completes
......
...@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
  return done();
}
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
...@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
  return done();
}
::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::clear() {
  // TODO
  return done();
...@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
  return done();
}
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
  if (pull_context.value_type == Dense) {  // pull dense
    Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
    return pull_dense(dense_region, pull_context.num, pull_context.table);
  } else {  // pull sparse
    uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
    char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
    size_t table_id = pull_context.table;
    size_t num = pull_context.num;
    return pull_sparse_ptr(select_values, table_id, keys, num);
  }
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
  if (push_context.value_type == Dense) {  // push dense
    if (push_context.training_phase == Init) {
      const Region* regions = push_context.push_context.push_dense_values;
      size_t region_num = push_context.num;
      return push_dense_param(regions, region_num, push_context.table);
    } else {
      if (push_context.training_mode == Geo) {  // geo
        float* total_send_data =
            reinterpret_cast<float*>(push_context.dense_values);
        size_t total_send_data_size = push_context.num;
        return push_dense_raw_gradient(push_context.table, total_send_data,
                                       total_send_data_size,
                                       push_context.callback);
      } else {  // async and sync
        const Region* regions = push_context.push_context.push_dense_values;
        size_t region_num = push_context.num;
        return push_dense(regions, region_num, push_context.table);
      }
    }
  } else {  // push sparse
    if (push_context.training_mode == Async) {
      const uint64_t* keys = push_context.push_context.keys;
      const float** update_values = push_context.push_context.push_values;
      size_t table_id = push_context.table;
      size_t num = push_context.num;
      return push_sparse(table_id, keys, update_values, num);
    } else {
      // TODO: geo/sync sparse push not implemented; return a ready future
      return done();
    }
  }
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
                                                 size_t region_num,
                                                 size_t table_id) {
......
...@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
  virtual ::std::future<int32_t> load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
  virtual ::std::future<int32_t> save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
  virtual ::std::future<int32_t> clear() override;
  virtual ::std::future<int32_t> clear(uint32_t table_id) override;
...@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
  virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
                                            size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
  virtual ::std::future<int32_t> push_dense(const Region* regions,
                                            size_t region_num, size_t table_id);
......
...@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
  virtual uint64_t start() { return 0; }
  virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
  virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
  virtual int32_t configure(
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
......
...@@ -67,8 +67,6 @@ int32_t PSServer::configure(
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
  size_t shard_num = env.get_ps_servers().size();
  const auto &downpour_param = _config.downpour_server_param();
......
...@@ -69,11 +69,6 @@ class PSServer {
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {});
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
  virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
  virtual int32_t stop() = 0;
...@@ -94,15 +89,6 @@ class PSServer {
    return &_table_map;
  }
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
 protected:
  virtual int32_t initialize() = 0;
...@@ -111,7 +97,6 @@ class PSServer {
  ServerParameter _config;
  PSEnvironment *_environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
 protected:
  std::shared_ptr<framework::Scope> scope_;
......
...@@ -45,6 +45,17 @@ struct DataConverter {
  std::string deconverter;
};
struct AccessorInfo {
size_t dim;
size_t size;
size_t select_size;
size_t select_dim;
size_t update_size;
size_t update_dim;
size_t mf_size;
size_t fea_dim;
};
class ValueAccessor {
 public:
  ValueAccessor() {}
...@@ -68,6 +79,8 @@ class ValueAccessor {
  }
  virtual int initialize() = 0;
virtual void GetTableInfo(AccessorInfo& info) = 0;
  // value dimension
  virtual size_t dim() = 0;
  // size of each value dimension
...@@ -163,6 +176,7 @@ class ValueAccessor {
  TableAccessorParameter _config;
  std::unordered_map<int, std::shared_ptr<struct DataConverter>>
      _data_coverter_map;
AccessorInfo _accessor_info;
};
REGISTER_PSCORE_REGISTERER(ValueAccessor);
}  // namespace distributed
......
...@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
  return 0;
}
int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num);
}
int32_t CommonDenseTable::Push(TableContext& context) {
  CHECK(context.value_type == Dense);
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    return push_dense(values, context.num);
  }
  return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
  std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
            pull_values);
......
...@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
                        const std::string& name);
  virtual int32_t initialize_value();
  virtual int32_t initialize_optimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
  int32_t pull_dense(float* pull_values, size_t num) override;
  int32_t push_dense_param(const float* values, size_t num) override;
  int32_t push_dense(const float* values, size_t num) override;
......
...@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
  int32_t get_server_index_by_id(int64_t id);
  Node *find_node(int64_t id);
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
  virtual int32_t pull_sparse(float *values,
                              const PullSparseValue &pull_value) {
    return 0;
......
...@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
  return 0;
}
int32_t CommonSparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t CommonSparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    const uint64_t* keys = context.push_context.keys;
    return push_sparse(keys, values, context.num);
  } else {
    const float** values = context.push_context.ptr_values;
    const uint64_t* keys = context.push_context.keys;
    return push_sparse(keys, values, context.num);
  }
}
int32_t CommonSparseTable::pull_sparse(float* pull_values,
                                       const PullSparseValue& pull_value) {
  auto shard_num = task_pool_size_;
......
...@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
  virtual int32_t push_dense(const float* values, size_t num) { return 0; }
  // unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
  virtual int32_t initialize();
  virtual int32_t initialize_shard() { return 0; }
  virtual int32_t initialize_value();
......
...@@ -119,6 +119,9 @@ class BarrierTable : public Table {
  virtual void *get_shard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
  int32_t pull_dense(float *values, size_t num) override { return 0; }
  int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
  return 0;
}
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
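(Review note: GetTableInfo packs the accessor geometry into one struct so callers can query it once instead of calling each virtual getter; a hedged sketch with a hypothetical ValueAccessor* accessor and key count:)
AccessorInfo info;
accessor->GetTableInfo(info);
// e.g. size a select/pull buffer for num_keys features:
size_t buffer_bytes = info.select_size * num_keys;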
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) {
......
...@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
  virtual int initialize();
  virtual ~CtrCommonAccessor() {}
virtual void GetTableInfo(AccessorInfo& info);
  // value dimension
  virtual size_t dim();
  // size of each value dimension
......
...@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
  return 0;
}
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrDoubleAccessor::dim() {
  auto embedx_dim = _config.embedx_dim();
  return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
...@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
  DownpourCtrDoubleAccessor() {}
  virtual ~DownpourCtrDoubleAccessor() {}
  virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
  // value dimension
  virtual size_t dim();
  // size of each value dimension
......
...@@ -58,7 +58,7 @@ struct PullSparseValue {
                  std::vector<int>* offset_shard) const {
    offset_shard->reserve(numel_ / shard_num + 1);
    for (int x = 0; x < numel_; ++x) {
      if (int(feasigns_[x] % shard_num) == shard_id) {
        offset_shard->push_back(x);
      }
    }
......
...@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
  return 0;
}
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrAccessor::dim() {
  auto embedx_dim = _config.embedx_dim();
  return DownpourCtrFeatureValue::dim(embedx_dim);
......
...@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
  virtual ~DownpourCtrAccessor() {}
  virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
  // value dimension
  virtual size_t dim();
  // size of each value dimension
......
...@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
  virtual int32_t save(const std::string& path, const std::string& param) {
    return 0;
  }
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
  virtual int32_t flush() { return 0; }
  virtual int32_t shrink(const std::string& param) { return 0; }
  virtual void clear() { return; }
......
...@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
  return {feasign_size, mf_size};
}
int32_t MemorySparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.ptr_values, context.num);
}
int32_t MemorySparseTable::pull_sparse(float* pull_values,
                                       const PullSparseValue& pull_value) {
  CostTimer timer("pserver_sparse_select_all");
......
...@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
  virtual int32_t push_dense(const float* values, size_t num) { return 0; }
  // unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
  virtual int32_t initialize();
  virtual int32_t initialize_shard() { return 0; }
  virtual int32_t initialize_value();
......
...@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
  return 0;
}
int32_t SSDSparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t SSDSparseTable::pull_sparse(float* pull_values,
                                    const PullSparseValue& pull_value) {
  auto shard_num = task_pool_size_;
......
...@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
  // exchange data
  virtual int32_t update_table();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
  virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
  virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
......
...@@ -32,6 +32,30 @@
namespace paddle {
namespace distributed {
enum ValueType { Sparse = 0, Dense = 1 };
struct PullContext {
const uint64_t *keys;
const PullSparseValue pull_value;
float *values;
char **ptr_values;
};
struct TablePushContext {
const uint64_t *keys;
const float *values;
const float **ptr_values;
};
struct TableContext {
ValueType value_type;
PullContext pull_context;
TablePushContext push_context;
size_t num;
bool use_ptr;
};
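(Review note: a caller-side sketch of the new Table entry points on the dense path; it assumes TableContext can be default-constructed as written here, and table, buf, and dim are hypothetical.)
TableContext ctx;  // assumed default-constructible
ctx.value_type = Dense;
ctx.num = dim;                  // number of floats to pull
ctx.pull_context.values = buf;  // float* destination buffer
ctx.use_ptr = false;
table->Pull(ctx);  // e.g. CommonDenseTable dispatches to pull_dense(buf, dim)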
class Table {
 public:
  Table() {}
...@@ -39,6 +63,8 @@ class Table {
  virtual int32_t initialize(const TableParameter &config,
                             const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
  virtual int32_t pull_dense(float *values, size_t num) = 0;
  virtual int32_t push_dense(const float *values, size_t num) = 0;
  // for push global_step
......
...@@ -20,6 +20,16 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; }
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
// value dimension
size_t CommMergeAccessor::dim() { return 0; }
......
...@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
  CommMergeAccessor() {}
  virtual ~CommMergeAccessor() {}
  virtual int initialize();
virtual void GetTableInfo(AccessorInfo &info);
  // value dimension
  virtual size_t dim();
  // size of each value dimension
......
...@@ -48,6 +48,8 @@ class TensorTable : public Table {
  TensorTable() {}
  virtual ~TensorTable() {}
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
  int32_t pull_dense(float *values, size_t num) override { return 0; }
  int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
void FleetWrapper::Stop() { StopServer(); }
void FleetWrapper::Load(WrapperContext& context) {
  // WrapperContext::table_id is unsigned, so cast before the signed
  // "all tables" (negative id) checks below.
  auto table_id = static_cast<int>(context.table_id);
  if (table_id >= 0 && context.meta != "") {
    LoadSparseOnServer(context.path, context.meta, context.table_id);
    return;
  }
  if (table_id < 0) {  // load all
    LoadModel(context.path, context.mode);
  } else {  // load one table
    LoadModelOneTable(table_id, context.path, context.mode);
  }
  return;
}
void FleetWrapper::Save(WrapperContext& context) {
  auto table_id = static_cast<int>(context.table_id);  // see note in Load()
  if (table_id < 0) {
    SaveModel(context.path, context.mode);
  } else {
    SaveModelOneTable(table_id, context.path, context.mode);
  }
  return;
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
                                          int connect_timeout_ms,
                                          int max_retry) {
......
...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h" #include "paddle/fluid/framework/io/shell.h"
...@@ -54,7 +55,7 @@ using framework::Variable; ...@@ -54,7 +55,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>; using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper { class FleetWrapper : public PSWrapper {
 public:
  virtual ~FleetWrapper() {}
  FleetWrapper() {
...@@ -68,7 +69,13 @@ class FleetWrapper {
    // pserver request max retry
    client2client_max_retry_ = 3;
  }
virtual int32_t Initialize(InitContext& context) { return 0; }
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
  // set client to client communication config
  void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
                              int max_retry);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
class Scope;
class SelectedRows;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
class PSCore;
using framework::LoDTensor;
using framework::Scope;
using phi::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
struct WrapperContext {
uint32_t table_id;
const std::string path;
const int mode;
const std::string meta;
};
struct InitContext {
const std::vector<int> dev_ids; // for gpu
};
class PSWrapper {
public:
virtual ~PSWrapper() {}
PSWrapper() {}
// init server
virtual int32_t Initialize(InitContext& context) = 0;
virtual void Stop() = 0;
virtual void Load(WrapperContext& context) = 0;
virtual void Save(WrapperContext& context) = 0;
};
} // end namespace distributed
} // end namespace paddle
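(Review note: a hedged sketch of the new wrapper interface; the path and table id are hypothetical. Since path/mode/meta are const members, the context is aggregate-initialized.)
std::shared_ptr<PSWrapper> ps = std::make_shared<FleetWrapper>();
WrapperContext ctx{/*table_id=*/0, /*path=*/"/tmp/model", /*mode=*/0,
                   /*meta=*/""};
ps->Load(ctx);  // empty meta and a non-negative id -> LoadModelOneTable
ps->Stop();     // FleetWrapper::Stop() calls StopServer()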
...@@ -2032,7 +2032,15 @@ static std::string GenerateSingleOpBase(
  const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n";
  std::string grad_attrs_str =
      paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
if (fwd_op_type == "cast") {
// swtich in out dtype
const char* CAST_GRAD =
" auto temp_type = %s[\"in_dtype\"];\n"
" %s[\"in_dtype\"] = %s[\"out_dtype\"];\n"
" %s[\"out_dtype\"] = temp_type;\n";
grad_attrs_str += paddle::string::Sprintf(CAST_GRAD, attrs_name, attrs_name,
attrs_name, attrs_name);
}
  // Handle dynamic grad attributes
  grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
  generated_grad_function_body += grad_attrs_str;
......
...@@ -93,7 +93,7 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
    // Create new tensor->impl and fill it with 1.0
    if (t.defined()) {
      // Fill 1.0
      buffer_[slot_id][rank] = paddle::experimental::ones_like(t, t.dtype());
    }
  }
}
......
This diff has been collapsed.
This diff has been collapsed.
...@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace pybind {
...@@ -62,10 +63,10 @@ int TensorDtype2NumpyDtype(phi::DataType dtype) {
      return pybind11::detail::npy_api::NPY_INT32_;
    case phi::DataType::INT64:
      return pybind11::detail::npy_api::NPY_INT64_;
case phi::DataType::FLOAT16:
return pybind11::detail::NPY_FLOAT16_;
    case phi::DataType::BFLOAT16:
      return pybind11::detail::NPY_UINT16_;
case phi::DataType::FLOAT16:
return pybind11::detail::NPY_FLOAT16_;
    case phi::DataType::FLOAT32:
      return pybind11::detail::npy_api::NPY_FLOAT_;
    case phi::DataType::FLOAT64:
......
...@@ -877,6 +877,77 @@ void PadInferMeta(const MetaTensor& input,
  out->set_dtype(input.dtype());
}
void Pad3dInferMeta(const MetaTensor& x,
const ScalarArray& paddings_scalar_array,
const std::string& mode,
float value,
const std::string& data_format,
MetaTensor* out,
MetaConfig config) {
auto x_dim = x.dims();
PADDLE_ENFORCE_EQ(x_dim.size(),
5,
errors::InvalidArgument(
"The size of Input(X)'s dimension should be equal to "
"5, but received %d. ",
x_dim.size()));
std::vector<int64_t> out_dims(x_dim.size());
out_dims[0] = x_dim[0];
if (paddings_scalar_array.FromTensor()) {
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
paddings_scalar_array.GetData().size(),
6,
errors::InvalidArgument("Shape of Input(Paddings) should be equal to "
"[6], but received [%d].",
paddings_scalar_array.GetData().size()));
}
out_dims[1] = x_dim[1];
out_dims[2] = x_dim[2];
out_dims[3] = x_dim[3];
} else {
auto paddings = paddings_scalar_array.GetData();
PADDLE_ENFORCE_EQ(
paddings.size(),
6,
errors::InvalidArgument(
"Size of paddings should be equal to 6, but received %d.",
static_cast<int>(paddings.size())));
if (data_format == "NCDHW") {
out_dims[1] = x_dim[1]; // channel
out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0))
? x_dim[2]
: (x_dim[2] + paddings[4] + paddings[5]); // depth
out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0))
? x_dim[3]
: (x_dim[3] + paddings[2] + paddings[3]); // height
out_dims[4] = ((!config.is_runtime) && (x_dim[4] < 0))
? x_dim[4]
: (x_dim[4] + paddings[0] + paddings[1]); // width
} else { // NDHWC
out_dims[4] = x_dim[4]; // channel
out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0))
? x_dim[1]
: (x_dim[1] + paddings[4] + paddings[5]); // depth
out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0))
? x_dim[2]
: (x_dim[2] + paddings[2] + paddings[3]); // height
out_dims[3] = ((!config.is_runtime) && (x_dim[3] < 0))
? x_dim[3]
: (x_dim[3] + paddings[0] + paddings[1]); // width
}
}
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype(x.dtype());
out->share_lod(x);
}
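(Review note: a worked shape check for the non-tensor paddings branch, with illustrative values: an NCDHW input of dims [2, 3, 4, 5, 6] and paddings {1, 1, 2, 2, 3, 3}, ordered {left, right, top, bottom, front, back}, gives depth 4 + 3 + 3 = 10, height 5 + 2 + 2 = 9, and width 6 + 1 + 1 = 8, so out->dims() is [2, 3, 10, 9, 8].)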
void PixelShuffleInferMeta(const MetaTensor& x,
                           int upscale_factor,
                           const std::string& data_format,
......
...@@ -147,6 +147,14 @@ void PadInferMeta(const MetaTensor& input,
                  MetaTensor* out,
                  MetaConfig config = MetaConfig());
void Pad3dInferMeta(const MetaTensor& x,
const ScalarArray& paddings,
const std::string& mode,
float value,
const std::string& data_format,
MetaTensor* out,
MetaConfig config = MetaConfig());
void PixelShuffleInferMeta(const MetaTensor& x,
                           int upscale_factor,
                           const std::string& data_format,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pad3d_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
void ConstPad3DGradNCDHW(T* d_in_data,
const T* d_out_data,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = out_d - pad_front;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
in_h >= in_height || in_w >= in_width)) {
d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] =
d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
}
}
template <typename T>
void ConstPad3DGradNDHWC(T* d_in_data,
const T* d_out_data,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = out_d - pad_front;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
const int out_index =
(out_d * out_height * out_width + out_h * out_width + out_w) * channels;
if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
in_h >= in_height || in_w >= in_width)) {
const int in_index =
(in_d * in_height * in_width + in_h * in_width + in_w) * channels;
for (int c = 0; c < channels; ++c) {
d_in_data[in_index + c] = d_out_data[out_index + c];
}
}
}
template <typename T>
void ReflectPad3DGradNCDHW(T* d_in_data,
const T* d_out_data,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = out_d - pad_front;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_d = std::max(in_d, -in_d); // reflect by 0
in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth
in_h = std::max(in_h, -in_h); // reflect by 0
in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height
in_w = std::max(in_w, -in_w); // reflect by 0
in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width
d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
}
template <typename T>
void ReflectPad3DGradNDHWC(T* d_in_data,
const T* d_out_data,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = out_d - pad_front;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_d = std::max(in_d, -in_d);
in_d = std::min(in_d, 2 * in_depth - in_d - 2);
in_h = std::max(in_h, -in_h);
in_h = std::min(in_h, 2 * in_height - in_h - 2);
in_w = std::max(in_w, -in_w);
in_w = std::min(in_w, 2 * in_width - in_w - 2);
const int out_index =
(out_d * out_height * out_width + out_h * out_width + out_w) * channels;
const int in_index =
(in_d * in_height * in_width + in_h * in_width + in_w) * channels;
for (int c = 0; c < channels; ++c) {
d_in_data[in_index + c] += d_out_data[out_index + c];
}
}
template <typename T>
void ReplicatePad3DGradNCDHW(T* d_in_data,
const T* d_out_data,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
}
template <typename T>
void ReplicatePad3DGradNDHWC(T* d_in_data,
const T* d_out_data,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
const int out_index =
(out_d * out_height * out_width + out_h * out_width + out_w) * channels;
const int in_index =
(in_d * in_height * in_width + in_h * in_width + in_w) * channels;
for (int c = 0; c < channels; ++c) {
d_in_data[in_index + c] += d_out_data[out_index + c];
}
}
template <typename T>
void CircularPad3DGradNCDHW(T* d_in_data,
const T* d_out_data,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
}
template <typename T>
void CircularPad3DGradNDHWC(T* d_in_data,
const T* d_out_data,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const int out_d,
const int out_h,
const int out_w) {
int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
const int out_index =
(out_d * out_height * out_width + out_h * out_width + out_w) * channels;
const int in_index =
(in_d * in_height * in_width + in_h * in_width + in_w) * channels;
for (int c = 0; c < channels; ++c) {
d_in_data[in_index + c] += d_out_data[out_index + c];
}
}
template <typename T>
void Pad3DGradNCDHW(T* d_in_data,
const int num,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const T* d_out_data,
void (*pad_func)(T*,
const T*,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int)) {
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
for (int out_d = 0; out_d < out_depth; ++out_d) {
for (int out_h = 0; out_h < out_height; ++out_h) {
for (int out_w = 0; out_w < out_width; ++out_w) {
pad_func(d_in_data,
d_out_data,
in_depth,
in_height,
in_width,
out_depth,
out_height,
out_width,
pad_front,
pad_top,
pad_left,
out_d,
out_h,
out_w);
}
}
}
d_in_data += in_depth * in_height * in_width;
d_out_data += out_depth * out_height * out_width;
}
}
}
template <typename T>
void Pad3DGradNDHWC(T* d_in_data,
const int num,
const int channels,
const int in_depth,
const int in_height,
const int in_width,
const int out_depth,
const int out_height,
const int out_width,
const int pad_front,
const int pad_top,
const int pad_left,
const T* d_out_data,
void (*pad_func)(T*,
const T*,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int)) {
for (int n = 0; n < num; ++n) {
for (int out_d = 0; out_d < out_depth; ++out_d) {
for (int out_h = 0; out_h < out_height; ++out_h) {
for (int out_w = 0; out_w < out_width; ++out_w) {
pad_func(d_in_data,
d_out_data,
channels,
in_depth,
in_height,
in_width,
out_depth,
out_height,
out_width,
pad_front,
pad_top,
pad_left,
out_d,
out_h,
out_w);
}
}
}
d_in_data += in_depth * in_height * in_width * channels;
d_out_data += out_depth * out_height * out_width * channels;
}
}
template <typename T, typename Context>
void Pad3dGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& paddings,
const std::string& mode,
float pad_value,
const std::string& data_format,
DenseTensor* x_grad) {
std::vector<int64_t> pads = paddings.GetData();
auto* d_out = &out_grad;
auto* d_in = x_grad;
auto d_in_dims = d_in->dims();
auto d_out_dims = d_out->dims();
const T* d_out_data = d_out->data<T>();
T* d_in_data = dev_ctx.template Alloc<T>(d_in);
phi::funcs::SetConstant<Context, T>()(dev_ctx, d_in, static_cast<T>(0));
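  // `paddings` is ordered {left, right, top, bottom, front, back}; only the
  // leading pad of each axis is needed to locate the source coordinate.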
const int pad_left = pads[0];
const int pad_top = pads[2];
const int pad_front = pads[4];
const int num = d_in_dims[0];
if (data_format == "NCDHW") {
const int channels = d_in_dims[1];
const int in_depth = d_in_dims[2];
const int in_height = d_in_dims[3];
const int in_width = d_in_dims[4];
const int out_depth = d_out_dims[2];
const int out_height = d_out_dims[3];
const int out_width = d_out_dims[4];
std::map<std::string,
void (*)(T*,
const T*,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int)>
func_map;
func_map["reflect"] = ReflectPad3DGradNCDHW;
func_map["replicate"] = ReplicatePad3DGradNCDHW;
func_map["circular"] = CircularPad3DGradNCDHW;
func_map["constant"] = ConstPad3DGradNCDHW;
Pad3DGradNCDHW(d_in_data,
num,
channels,
in_depth,
in_height,
in_width,
out_depth,
out_height,
out_width,
pad_front,
pad_top,
pad_left,
d_out_data,
func_map[mode]);
} else {
const int channels = d_in_dims[4];
const int in_depth = d_in_dims[1];
const int in_height = d_in_dims[2];
const int in_width = d_in_dims[3];
const int out_depth = d_out_dims[1];
const int out_height = d_out_dims[2];
const int out_width = d_out_dims[3];
std::map<std::string,
void (*)(T*,
const T*,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int,
const int)>
func_map;
func_map["reflect"] = ReflectPad3DGradNDHWC;
func_map["replicate"] = ReplicatePad3DGradNDHWC;
func_map["circular"] = CircularPad3DGradNDHWC;
func_map["constant"] = ConstPad3DGradNDHWC;
Pad3DGradNDHWC(d_in_data,
num,
channels,
in_depth,
in_height,
in_width,
out_depth,
out_height,
out_width,
pad_front,
pad_top,
pad_left,
d_out_data,
func_map[mode]);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
pad3d_grad, CPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {}
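
A minimal sketch of exercising this kernel from Python, assuming the Paddle 2.x
functional API (paddle.nn.functional.pad); the shapes and values below are
illustrative only and not part of this commit:

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([1, 2, 3, 4, 5])  # NCDHW input
    x.stop_gradient = False
    # pad order for a 5-D input: [left, right, top, bottom, front, back]
    y = F.pad(x, [1, 1, 1, 1, 1, 1], mode='circular', data_format='NCDHW')
    y.sum().backward()  # the gradient scatter runs through Pad3dGradKernel
    print(x.grad.shape)  # same shape as x: [1, 2, 3, 4, 5]

With mode='circular', each cell of x.grad ends up counting how many output
cells were circularly copied from that input cell.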
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Pad3dGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& paddings,
const std::string& mode,
float pad_value,
const std::string& data_format,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Pad3dKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& paddings,
const std::string& mode,
float pad_value,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
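// Maps the legacy fluid pad3d op to the phi kernel signature; a Paddings
// input tensor, when present, takes precedence over the paddings attribute.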
KernelSignature Pad3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Paddings")) {
return KernelSignature(
"pad3d", {"X"}, {"Paddings", "mode", "value", "data_format"}, {"Out"});
}
return KernelSignature(
"pad3d", {"X"}, {"paddings", "mode", "value", "data_format"}, {"Out"});
}
KernelSignature Pad3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Paddings")) {
return KernelSignature("pad3d_grad",
{"X", GradVarName("Out")},
{"Paddings", "mode", "value", "data_format"},
{GradVarName("X")});
}
return KernelSignature("pad3d_grad",
{"X", GradVarName("Out")},
{"paddings", "mode", "value", "data_format"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pad3d, phi::Pad3dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pad3d_grad, phi::Pad3dGradOpArgumentMapping);
@@ -612,7 +612,7 @@ def grad(outputs,
     if no_grad_vars is None:
         no_grad_vars = []
-    elif isinstance(no_grad_vars, core.VarBase):
+    elif isinstance(no_grad_vars, (core.VarBase, core.eager.Tensor)):
         no_grad_vars = [no_grad_vars]
     elif isinstance(no_grad_vars, core.eager.Tensor):
         no_grad_vars = [no_grad_vars]
@@ -718,13 +718,13 @@ def to_variable(value, name=None, zero_copy=None, dtype=None):
         y.shape  # [3L, 2L]
     """
-    support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable,
-                    core.Tensor, core.LoDTensor)
+    support_type = (list, tuple, np.ndarray, core.eager.Tensor, core.VarBase,
+                    framework.Variable, core.Tensor, core.LoDTensor)
     if not isinstance(value, support_type):
         raise TypeError(
             "The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
             % (support_type, type(value)))
-    if isinstance(value, (core.VarBase, framework.Variable)):
+    if isinstance(value, (core.eager.Tensor, core.VarBase, framework.Variable)):
         return value
     elif isinstance(value, (core.Tensor, core.LoDTensor)):
         return core.VarBase(value)
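
A small sketch of what the widened support_type check above enables
(hypothetical snippet; the core.eager.Tensor branch is only reached when
eager mode is switched on):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        v = fluid.dygraph.to_variable(np.ones([3, 2], dtype='float32'))
        # Tensor-like inputs are passed through unchanged; before this
        # change a core.eager.Tensor here raised TypeError instead.
        assert fluid.dygraph.to_variable(v) is v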
@@ -28,6 +28,7 @@ from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
 from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
 import paddle.utils.deprecated as deprecated
+from paddle import _C_ops


 class TensorHookRemoveHelper(object):
@@ -782,7 +783,7 @@ def monkey_patch_varbase():
     @framework.dygraph_only
     def clone(self):
-        return _C_ops_.assign(self)
+        return _C_ops.assign(self)

     @framework.dygraph_only
     def value(self):
@@ -316,7 +316,8 @@ def _dygraph_not_support_(func):
 def _dygraph_only_(func):
     def __impl__(*args, **kwargs):
-        assert in_dygraph_mode(
+        assert (
+            in_dygraph_mode() or _in_eager_mode()
         ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
         return func(*args, **kwargs)
@@ -886,6 +886,7 @@ class TestDistributeFpnProposals(LayerTest):
                 refer_level=4,
                 refer_scale=224,
                 rois_num=rois_num_dy)
+            print(type(multi_rois_dy))
             output_dy = multi_rois_dy + [restore_ind_dy] + rois_num_per_level_dy
             output_dy_np = []
             for output in output_dy:
@@ -973,4 +974,5 @@ class TestBoxDecoderAndAssign(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -50,3 +50,7 @@ class TestExponentialFamilyException(unittest.TestCase):
     def test_entropy_exception(self):
         with self.assertRaises(NotImplementedError):
             paddle.distribution.ExponentialFamily.entropy(self.dist)
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -112,3 +112,7 @@ class TestKLExpfamilyExpFamily(unittest.TestCase):
             kl._kl_expfamily_expfamily(self.p, self.q),
             rtol=config.RTOL.get(config.DEFAULT_DTYPE),
             atol=config.ATOL.get(config.DEFAULT_DTYPE))
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -20,6 +20,7 @@ import sys
 sys.path.append("../")
 from op_test import OpTest
+import paddle
 from paddle import fluid
@@ -115,4 +116,5 @@ class TestSequenceConcatOpError(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -39,6 +39,7 @@ class TensorTypeTest(unittest.TestCase):
         tensorx = paddle.tensor.logic.Tensor(inx)
         typex_str = str(type(tensorx))
         expectx = "<class 'paddle.Tensor'>"
+
         self.assertEqual((typex_str == expectx), True)
@@ -1202,4 +1202,5 @@ class TestMultiTensorAdam(unittest.TestCase):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -451,4 +451,5 @@ class TestLayerTo(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -18,6 +18,7 @@ import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 from paddle.fluid import ParamAttr, initializer
+import paddle


 class TestCreateParameterError(unittest.TestCase):
@@ -50,4 +51,5 @@ class TestCreateParameterError(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -20,6 +20,7 @@ import numpy as np
 from op_test import OpTest
 from test_softmax_op import stable_softmax
 import paddle.fluid as fluid
+import paddle


 def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None):
@@ -229,4 +230,5 @@ class BadInputTestCTCAlignr(unittest.TestCase):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -211,4 +211,5 @@ class TestDiffOpPreAppendAxis(TestDiffOp):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -190,4 +190,5 @@ class TestDygraphRemoveWeightNorm(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -209,4 +209,5 @@ class TestExponentialAPI(unittest.TestCase):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -189,3 +189,8 @@ class TestElementwiseFmin2Op(OpTest):
         """test_check_grad_ingore_y"""
         self.check_grad(
             ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
@@ -1025,4 +1025,5 @@ class TestDiracInitializer3(TestDiracInitializer1):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -163,4 +163,5 @@ class TestMultiplyError(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -88,4 +88,5 @@ class TestWhenTrainWithNoGrad(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -210,6 +210,9 @@ class TestIscloseOpFloat64(TestIscloseOp):
         self.atol = np.array([0]).astype("float64")
         self.equal_nan = False

+    def test_check_output(self):
+        self.check_output()
+

 class TestIscloseOpLargeDimInput(TestIscloseOp):
     def set_args(self):
@@ -222,4 +225,5 @@ class TestIscloseOpLargeDimInput(TestIscloseOp):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
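
Most of the test edits above share one pattern; a sketch of it (assuming
OpTest-based suites still run in static-graph mode, while Paddle 2.x enables
dynamic mode by default):

    import unittest
    import paddle

    if __name__ == '__main__':
        paddle.enable_static()  # switch back before running static-mode tests
        unittest.main()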
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.