提交 1a7399b4 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_embedding_to_phi

...@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id, ...@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch, std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) { const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch; VLOG(1) << "BrpcPsClient::save path " << epoch;
...@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id, ...@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
if (save_context.table_id < 0) {
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
}
std::future<int32_t> BrpcPsClient::clear() { std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
} }
...@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id, ...@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
} }
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training);
}
}
}
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
bool is_training = push_context.is_training;
if (push_context.training_mode == Geo) { // for geo
// TODO(zhaocaibei)
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num);
}
}
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id, std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values, std::vector<float> *values,
std::vector<uint64_t> *keys, std::vector<uint64_t> *keys,
......
...@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient { ...@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> load(uint32_t table_id, const std::string &epoch, std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
std::future<int32_t> save(const std::string &epoch, std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch, std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override; std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override; std::future<int32_t> clear(uint32_t table_id) override;
...@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient { ...@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const uint64_t *keys, const uint64_t *keys,
size_t num, bool is_training); size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> print_table_stat(uint32_t table_id); virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type); virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
......
...@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer { ...@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
virtual int32_t port(); int32_t port();
private: private:
virtual int32_t initialize(); virtual int32_t initialize();
......
...@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer { ...@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
virtual int32_t port(); int32_t port();
std::condition_variable *export_cv() { return &cv_; } std::condition_variable *export_cv() { return &cv_; }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure { ...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
}; };
struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
class PSClient { class PSClient {
public: public:
PSClient() {} PSClient() {}
...@@ -86,6 +122,9 @@ class PSClient { ...@@ -86,6 +122,9 @@ class PSClient {
// 指定table数据load // 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// 全量table数据save value_accessor根据mode,可能有不同的save条件 // 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch, virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
...@@ -93,6 +132,8 @@ class PSClient { ...@@ -93,6 +132,8 @@ class PSClient {
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// 清空table数据 // 清空table数据
virtual std::future<int32_t> clear() = 0; virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0; virtual std::future<int32_t> clear(uint32_t table_id) = 0;
...@@ -107,6 +148,8 @@ class PSClient { ...@@ -107,6 +148,8 @@ class PSClient {
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留 size_t table_id) = 0; // 保留
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// firstly push dense param for parameter server // firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold // this is neccessary because dense weight initialized in trainer on cold
// start // start
...@@ -117,6 +160,9 @@ class PSClient { ...@@ -117,6 +160,9 @@ class PSClient {
virtual std::future<int32_t> push_dense(const Region *regions, virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) = 0; size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
// 使用keys进行pull请求,结果填充values // 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间 // keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用 // future结束前keys和values缓冲区不能再次使用
......
...@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() { ...@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch, ::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
...@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() { ...@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::clear() { ::std::future<int32_t> PsLocalClient::clear() {
// TODO // TODO
return done(); return done();
...@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() { ...@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num);
}
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions, ::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
......
...@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient { ...@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> load(uint32_t table_id, virtual ::std::future<int32_t> load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch, virtual ::std::future<int32_t> save(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id, virtual ::std::future<int32_t> save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override; virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override; virtual ::std::future<int32_t> clear(uint32_t table_id) override;
...@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient { ...@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num, virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id); size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
virtual ::std::future<int32_t> push_dense(const Region* regions, virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id); size_t region_num, size_t table_id);
......
...@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer { ...@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual uint64_t start() { return 0; } virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; } virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure( virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) { const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
......
...@@ -67,8 +67,6 @@ int32_t PSServer::configure( ...@@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config = config.server_param(); _config = config.server_param();
_rank = server_rank; _rank = server_rank;
_environment = &env; _environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size(); size_t shard_num = env.get_ps_servers().size();
const auto &downpour_param = _config.downpour_server_param(); const auto &downpour_param = _config.downpour_server_param();
......
...@@ -69,11 +69,6 @@ class PSServer { ...@@ -69,11 +69,6 @@ class PSServer {
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}); const std::vector<framework::ProgramDesc> &server_sub_program = {});
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
virtual uint64_t start(const std::string &ip, uint32_t port) = 0; virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0; virtual int32_t stop() = 0;
...@@ -94,15 +89,6 @@ class PSServer { ...@@ -94,15 +89,6 @@ class PSServer {
return &_table_map; return &_table_map;
} }
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t initialize() = 0;
...@@ -111,7 +97,6 @@ class PSServer { ...@@ -111,7 +97,6 @@ class PSServer {
ServerParameter _config; ServerParameter _config;
PSEnvironment *_environment; PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map; std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected: protected:
std::shared_ptr<framework::Scope> scope_; std::shared_ptr<framework::Scope> scope_;
......
...@@ -45,6 +45,17 @@ struct DataConverter { ...@@ -45,6 +45,17 @@ struct DataConverter {
std::string deconverter; std::string deconverter;
}; };
struct AccessorInfo {
size_t dim;
size_t size;
size_t select_size;
size_t select_dim;
size_t update_size;
size_t update_dim;
size_t mf_size;
size_t fea_dim;
};
class ValueAccessor { class ValueAccessor {
public: public:
ValueAccessor() {} ValueAccessor() {}
...@@ -68,6 +79,8 @@ class ValueAccessor { ...@@ -68,6 +79,8 @@ class ValueAccessor {
} }
virtual int initialize() = 0; virtual int initialize() = 0;
virtual void GetTableInfo(AccessorInfo& info) = 0;
// value维度 // value维度
virtual size_t dim() = 0; virtual size_t dim() = 0;
// value各个维度的size // value各个维度的size
...@@ -163,6 +176,7 @@ class ValueAccessor { ...@@ -163,6 +176,7 @@ class ValueAccessor {
TableAccessorParameter _config; TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>> std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map; _data_coverter_map;
AccessorInfo _accessor_info;
}; };
REGISTER_PSCORE_REGISTERER(ValueAccessor); REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed } // namespace distributed
......
...@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) { ...@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
return 0; return 0;
} }
int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num);
}
int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense);
if (context.pull_context.values != nullptr) {
const float* values = context.push_context.values;
return push_dense(values, context.num);
}
return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values); pull_values);
......
...@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable { ...@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
const std::string& name); const std::string& name);
virtual int32_t initialize_value(); virtual int32_t initialize_value();
virtual int32_t initialize_optimizer(); virtual int32_t initialize_optimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override; int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override; int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override; int32_t push_dense(const float* values, size_t num) override;
......
...@@ -454,6 +454,9 @@ class GraphTable : public SparseTable { ...@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
int32_t get_server_index_by_id(int64_t id); int32_t get_server_index_by_id(int64_t id);
Node *find_node(int64_t id); Node *find_node(int64_t id);
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values, virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) { const PullSparseValue &pull_value) {
return 0; return 0;
......
...@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() { ...@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
return 0; return 0;
} }
int32_t CommonSparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t CommonSparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.pull_context.values != nullptr) {
const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
} else {
const float** values = context.push_context.ptr_values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
}
}
int32_t CommonSparseTable::pull_sparse(float* pull_values, int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
......
...@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable { ...@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; } virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t initialize_value();
......
...@@ -119,6 +119,9 @@ class BarrierTable : public Table { ...@@ -119,6 +119,9 @@ class BarrierTable : public Table {
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *get_shard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() { ...@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
return 0; return 0;
} }
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) { size_t CtrCommonAccessor::dim_size(size_t dim) {
......
...@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
virtual int initialize(); virtual int initialize();
virtual ~CtrCommonAccessor() {} virtual ~CtrCommonAccessor() {}
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() { ...@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
return 0; return 0;
} }
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrDoubleAccessor::dim() { size_t DownpourCtrDoubleAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::dim(embedx_dim); return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
...@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor() {} DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -58,7 +58,7 @@ struct PullSparseValue { ...@@ -58,7 +58,7 @@ struct PullSparseValue {
std::vector<int>* offset_shard) const { std::vector<int>* offset_shard) const {
offset_shard->reserve(numel_ / shard_num + 1); offset_shard->reserve(numel_ / shard_num + 1);
for (int x = 0; x < numel_; ++x) { for (int x = 0; x < numel_; ++x) {
if (feasigns_[x] % shard_num == shard_id) { if (int(feasigns_[x] % shard_num) == shard_id) {
offset_shard->push_back(x); offset_shard->push_back(x);
} }
} }
......
...@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() { ...@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
return 0; return 0;
} }
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrAccessor::dim() { size_t DownpourCtrAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim); return DownpourCtrFeatureValue::dim(embedx_dim);
......
...@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual ~DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable { ...@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
virtual int32_t save(const std::string& path, const std::string& param) { virtual int32_t save(const std::string& path, const std::string& param) {
return 0; return 0;
} }
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; } virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; } virtual void clear() { return; }
......
...@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() { ...@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
return {feasign_size, mf_size}; return {feasign_size, mf_size};
} }
int32_t MemorySparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.ptr_values, context.num);
}
int32_t MemorySparseTable::pull_sparse(float* pull_values, int32_t MemorySparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
CostTimer timer("pserver_sparse_select_all"); CostTimer timer("pserver_sparse_select_all");
......
...@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable { ...@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; } virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t initialize_value();
......
...@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() { ...@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
return 0; return 0;
} }
int32_t SSDSparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
}
}
int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t SSDSparseTable::pull_sparse(float* pull_values, int32_t SSDSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
......
...@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable { ...@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
// exchange data // exchange data
virtual int32_t update_table(); virtual int32_t update_table();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
......
...@@ -32,6 +32,30 @@ ...@@ -32,6 +32,30 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
enum ValueType { Sparse = 0, Dense = 1 };
struct PullContext {
const uint64_t *keys;
const PullSparseValue pull_value;
float *values;
char **ptr_values;
};
struct TablePushContext {
const uint64_t *keys;
const float *values;
const float **ptr_values;
};
struct TableContext {
ValueType value_type;
PullContext pull_context;
TablePushContext push_context;
size_t num;
bool use_ptr;
};
class Table { class Table {
public: public:
Table() {} Table() {}
...@@ -39,6 +63,8 @@ class Table { ...@@ -39,6 +63,8 @@ class Table {
virtual int32_t initialize(const TableParameter &config, virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t pull_dense(float *values, size_t num) = 0; virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0; virtual int32_t push_dense(const float *values, size_t num) = 0;
// for push global_step // for push global_step
......
...@@ -20,6 +20,16 @@ namespace distributed { ...@@ -20,6 +20,16 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; } int CommMergeAccessor::initialize() { return 0; }
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
// value 维度 // value 维度
size_t CommMergeAccessor::dim() { return 0; } size_t CommMergeAccessor::dim() { return 0; }
......
...@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor { ...@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {} CommMergeAccessor() {}
virtual ~CommMergeAccessor() {} virtual ~CommMergeAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo &info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -48,6 +48,8 @@ class TensorTable : public Table { ...@@ -48,6 +48,8 @@ class TensorTable : public Table {
TensorTable() {} TensorTable() {}
virtual ~TensorTable() {} virtual ~TensorTable() {}
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false; ...@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL; std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
void FleetWrapper::Stop() { StopServer(); }
void FleetWrapper::Load(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id >= 0 && context.meta != "") {
LoadSparseOnServer(context.path, context.meta, context.table_id);
return;
}
if (table_id < 0) { // laod all
LoadModel(context.path, context.mode);
} else { // load one table
LoadModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::Save(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id < 0) {
SaveModel(context.path, context.mode);
} else {
SaveModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms, int connect_timeout_ms,
int max_retry) { int max_retry) {
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h" #include "paddle/fluid/framework/io/shell.h"
...@@ -54,7 +55,7 @@ using framework::Variable; ...@@ -54,7 +55,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>; using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper { class FleetWrapper : public PSWrapper {
public: public:
virtual ~FleetWrapper() {} virtual ~FleetWrapper() {}
FleetWrapper() { FleetWrapper() {
...@@ -68,7 +69,13 @@ class FleetWrapper { ...@@ -68,7 +69,13 @@ class FleetWrapper {
// pserver request max retry // pserver request max retry
client2client_max_retry_ = 3; client2client_max_retry_ = 3;
} }
virtual int32_t Initialize(InitContext& context) { return 0; }
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
// set client to client communication config // set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry); int max_retry);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ #pragma once
#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#include <atomic>
#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ #include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
class Scope;
class SelectedRows;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
class PSCore;
using framework::LoDTensor;
using framework::Scope;
using phi::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
struct WrapperContext {
uint32_t table_id;
const std::string path;
const int mode;
const std::string meta;
};
struct InitContext {
const std::vector<int> dev_ids; // for gpu
};
class PSWrapper {
public:
virtual ~PSWrapper() {}
PSWrapper() {}
// init server
virtual int32_t Initialize(InitContext& context) = 0;
virtual void Stop() = 0;
virtual void Load(WrapperContext& context) = 0;
virtual void Save(WrapperContext& context) = 0;
};
} // end namespace distributed
} // end namespace paddle
...@@ -2032,7 +2032,15 @@ static std::string GenerateSingleOpBase( ...@@ -2032,7 +2032,15 @@ static std::string GenerateSingleOpBase(
const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n"; const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n";
std::string grad_attrs_str = std::string grad_attrs_str =
paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name); paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
if (fwd_op_type == "cast") {
// swtich in out dtype
const char* CAST_GRAD =
" auto temp_type = %s[\"in_dtype\"];\n"
" %s[\"in_dtype\"] = %s[\"out_dtype\"];\n"
" %s[\"out_dtype\"] = temp_type;\n";
grad_attrs_str += paddle::string::Sprintf(CAST_GRAD, attrs_name, attrs_name,
attrs_name, attrs_name);
}
// Handle dynamic grad attributes // Handle dynamic grad attributes
grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name); grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
generated_grad_function_body += grad_attrs_str; generated_grad_function_body += grad_attrs_str;
......
此差异已折叠。
...@@ -93,7 +93,7 @@ void GradTensorHolder::add(size_t slot_id, size_t rank, ...@@ -93,7 +93,7 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
// Create new tensor->impl and fill it with 1.0 // Create new tensor->impl and fill it with 1.0
if (t.defined()) { if (t.defined()) {
// Fill 1.0 // Fill 1.0
buffer_[slot_id][rank] = paddle::experimental::ones_like(t); buffer_[slot_id][rank] = paddle::experimental::ones_like(t, t.dtype());
} }
} }
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <thrust/host_vector.h>
#include "heter_comm.h" #include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
...@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> { ...@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int sample_size, int len); int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size); NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info(); void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num, void move_neighbor_sample_result_to_source_gpu(
int sample_size, int *h_left, int gpu_id, int gpu_num, int *h_left, int *h_right,
int *h_right, int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
int64_t *src_sample_res, void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
int *actual_sample_size); int *h_left, int *h_right,
int *actual_sample_size,
int *total_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph); int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param); int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() { virtual int32_t end_graph_sampling() {
......
...@@ -13,10 +13,23 @@ ...@@ -13,10 +13,23 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" //#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
constexpr int WARP_SIZE = 32;
/* /*
comment 0 comment 0
this kernel just serves as an example of how to sample nodes' neighbors. this kernel just serves as an example of how to sample nodes' neighbors.
...@@ -29,20 +42,79 @@ sample_size; ...@@ -29,20 +42,79 @@ sample_size;
*/ */
__global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index, struct MaxFunctor {
int* actual_size, int sample_size;
int64_t* sample_result, int sample_size, HOSTDEVICE explicit inline MaxFunctor(int sample_size) {
int len) { this->sample_size = sample_size;
const size_t i = blockIdx.x * blockDim.x + threadIdx.x; }
if (i < len) { HOSTDEVICE inline int operator()(int x) const {
if (x > sample_size) {
return sample_size;
}
return x;
}
};
struct DegreeFunctor {
GpuPsCommGraph graph;
HOSTDEVICE explicit inline DegreeFunctor(GpuPsCommGraph graph) {
this->graph = graph;
}
HOSTDEVICE inline int operator()(int i) const {
return graph.node_list[i].neighbor_size;
}
};
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample(const uint64_t rand_seed, GpuPsCommGraph graph,
int sample_size, int* index, int len,
int64_t* sample_result, int* output_idx,
int* output_offset) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
curandState rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
while (i < last_idx) {
auto node_index = index[i]; auto node_index = index[i];
actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size int degree = graph.node_list[node_index].neighbor_size;
? graph.node_list[node_index].neighbor_size const int offset = graph.node_list[node_index].neighbor_offset;
: sample_size; int output_start = output_offset[i];
int offset = graph.node_list[node_index].neighbor_offset;
for (int j = 0; j < actual_size[i]; j++) { if (degree <= sample_size) {
sample_result[sample_size * i + j] = graph.neighbor_list[offset + j]; // Just copy
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
sample_result[output_start + j] = graph.neighbor_list[offset + j];
}
} else {
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
output_idx[output_start + j] = j;
}
__syncwarp();
for (int j = sample_size + threadIdx.x; j < degree; j += WARP_SIZE) {
const int num = curand(&rng) % (j + 1);
if (num < sample_size) {
atomicMax(
reinterpret_cast<unsigned int*>(output_idx + output_start + num),
static_cast<unsigned int>(j));
}
}
__syncwarp();
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
const int perm_idx = output_idx[output_start + j] + offset;
sample_result[output_start + j] = graph.neighbor_list[perm_idx];
}
} }
i += BLOCK_WARPS;
} }
} }
...@@ -79,7 +151,7 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) { ...@@ -79,7 +151,7 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
gpu i triggers a neighbor_sample task, gpu i triggers a neighbor_sample task,
when this task is done, when this task is done,
this function is called to move the sample result on other gpu back this function is called to move the sample result on other gpu back
to gup i and aggragate the result. to gpu i and aggragate the result.
the sample_result is saved on src_sample_res and the actual sample size for the sample_result is saved on src_sample_res and the actual sample size for
each node is saved on actual_sample_size. each node is saved on actual_sample_size.
the number of actual sample_result for the number of actual sample_result for
...@@ -96,10 +168,50 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) { ...@@ -96,10 +168,50 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
that's what fill_dvals does. that's what fill_dvals does.
*/ */
void GpuPsGraphTable::move_neighbor_sample_size_to_source_gpu(
int gpu_id, int gpu_num, int* h_left, int* h_right, int* actual_sample_size,
int* total_sample_size) {
// This function copyed actual_sample_size to source_gpu,
// and calculate total_sample_size of each gpu sample number.
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto shard_len = h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
total_sample_size[i] = 0;
continue;
}
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
auto shard_len = h_right[i] - h_left[i] + 1;
thrust::device_vector<int> t_actual_sample_size(shard_len);
thrust::copy(actual_sample_size + h_left[i],
actual_sample_size + h_left[i] + shard_len,
t_actual_sample_size.begin());
total_sample_size[i] = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
}
}
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right, int gpu_id, int gpu_num, int* h_left, int* h_right, int64_t* src_sample_res,
int64_t* src_sample_res, int* actual_sample_size) { thrust::host_vector<int>& total_sample_size) {
/*
if total_sample_size is [4, 5, 1, 6],
then cumsum_total_sample_size is [0, 4, 9, 10];
*/
thrust::host_vector<int> cumsum_total_sample_size(gpu_num, 0);
thrust::exclusive_scan(total_sample_size.begin(), total_sample_size.end(),
cumsum_total_sample_size.begin(), 0);
for (int i = 0; i < gpu_num; i++) { for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) { if (h_left[i] == -1 || h_right[i] == -1) {
continue; continue;
...@@ -109,14 +221,10 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( ...@@ -109,14 +221,10 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
// auto& node = path_[gpu_id][i].nodes_[cur_step]; // auto& node = path_[gpu_id][i].nodes_[cur_step];
auto& node = path_[gpu_id][i].nodes_.front(); auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync( cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size), reinterpret_cast<char*>(src_sample_res + cumsum_total_sample_size[i]),
node.val_storage + sizeof(int64_t) * shard_len, node.val_storage + sizeof(int64_t) * shard_len,
node.val_bytes_len - sizeof(int64_t) * shard_len, cudaMemcpyDefault, sizeof(int64_t) * total_sample_size[i], cudaMemcpyDefault,
node.out_stream); node.out_stream);
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
} }
for (int i = 0; i < gpu_num; ++i) { for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) { if (h_left[i] == -1 || h_right[i] == -1) {
...@@ -131,17 +239,35 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( ...@@ -131,17 +239,35 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
TODO: TODO:
how to optimize it to eliminate the for loop how to optimize it to eliminate the for loop
*/ */
__global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals, __global__ void fill_dvalues_actual_sample_size(int* d_shard_actual_sample_size,
int* d_shard_actual_sample_size, int* d_actual_sample_size,
int* d_actual_sample_size, int* idx, int* idx, int len) {
int sample_size, int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x; const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) { if (i < len) {
d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i]; d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
// d_vals[idx[i]] = d_shard_vals[i]; }
for (int j = 0; j < sample_size; j++) { }
d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void fill_dvalues_sample_result(int64_t* d_shard_vals,
int64_t* d_vals,
int* d_actual_sample_size, int* idx,
int* offset, int* d_offset,
int len) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
while (i < last_idx) {
const int sample_size = d_actual_sample_size[idx[i]];
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
d_vals[offset[idx[i]] + j] = d_shard_vals[d_offset[i] + j];
} }
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
i += BLOCK_WARPS;
} }
} }
...@@ -255,14 +381,12 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -255,14 +381,12 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
h_left = [0,5],h_right = [4,8] h_left = [0,5],h_right = [4,8]
*/ */
NeighborSampleResult* result = new NeighborSampleResult(sample_size, len); NeighborSampleResult* result = new NeighborSampleResult(sample_size, len);
if (len == 0) { if (len == 0) {
return result; return result;
} }
cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
int total_gpu = resource_->total_gpu(); int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_id); int dev_id = resource_->dev_id(gpu_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id); platform::CUDAPlace place = platform::CUDAPlace(dev_id);
...@@ -287,11 +411,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -287,11 +411,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr()); int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
...@@ -331,6 +450,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -331,6 +450,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
of alloc_mem_i, actual_sample_size_of_x equals ((int of alloc_mem_i, actual_sample_size_of_x equals ((int
*)alloc_mem_i)[shard_len + x] *)alloc_mem_i)[shard_len + x]
*/ */
create_storage(gpu_id, i, shard_len * sizeof(int64_t), create_storage(gpu_id, i, shard_len * sizeof(int64_t),
shard_len * (1 + sample_size) * sizeof(int64_t)); shard_len * (1 + sample_size) * sizeof(int64_t));
} }
...@@ -351,6 +471,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -351,6 +471,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
h_right[i] - h_left[i] + 1, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id)); resource_->remote_stream(i, gpu_id));
} }
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
continue; continue;
...@@ -364,10 +485,42 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -364,10 +485,42 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int* res_array = reinterpret_cast<int*>(node.val_storage); int* res_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = res_array + shard_len; int* actual_size_array = res_array + shard_len;
int64_t* sample_array = (int64_t*)(res_array + shard_len * 2); int64_t* sample_array = (int64_t*)(res_array + shard_len * 2);
neighbor_sample_example<<<grid_size, block_size_, 0,
resource_->remote_stream(i, gpu_id)>>>( // 1. get actual_size_array.
graph, res_array, actual_size_array, sample_array, sample_size, // 2. get sum of actual_size.
shard_len); // 3. get offset ptr
thrust::device_vector<int> t_res_array(shard_len);
thrust::copy(res_array, res_array + shard_len, t_res_array.begin());
thrust::device_vector<int> t_actual_size_array(shard_len);
thrust::transform(t_res_array.begin(), t_res_array.end(),
t_actual_size_array.begin(), DegreeFunctor(graph));
if (sample_size >= 0) {
thrust::transform(t_actual_size_array.begin(), t_actual_size_array.end(),
t_actual_size_array.begin(), MaxFunctor(sample_size));
}
thrust::copy(t_actual_size_array.begin(), t_actual_size_array.end(),
actual_size_array);
int total_sample_sum =
thrust::reduce(t_actual_size_array.begin(), t_actual_size_array.end());
thrust::device_vector<int> output_idx(total_sample_sum);
thrust::device_vector<int> output_offset(shard_len);
thrust::exclusive_scan(t_actual_size_array.begin(),
t_actual_size_array.end(), output_offset.begin(), 0);
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block_(WARP_SIZE, BLOCK_WARPS);
const dim3 grid_((shard_len + TILE_SIZE - 1) / TILE_SIZE);
neighbor_sample<
BLOCK_WARPS,
TILE_SIZE><<<grid_, block_, 0, resource_->remote_stream(i, gpu_id)>>>(
0, graph, sample_size, res_array, shard_len, sample_array,
thrust::raw_pointer_cast(output_idx.data()),
thrust::raw_pointer_cast(output_offset.data()));
} }
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
...@@ -378,13 +531,56 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -378,13 +531,56 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
tables_[i]->rwlock_->UNLock(); tables_[i]->rwlock_->UNLock();
} }
// walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr); // walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);
fill_dvalues<<<grid_size, block_size_, 0, stream>>>( auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size, int* d_shard_actual_sample_size_ptr =
d_idx_ptr, sample_size, len); reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
// Store total sample number of each gpu.
thrust::host_vector<int> d_shard_total_sample_size(total_gpu, 0);
move_neighbor_sample_size_to_source_gpu(
gpu_id, total_gpu, h_left, h_right, d_shard_actual_sample_size_ptr,
thrust::raw_pointer_cast(d_shard_total_sample_size.data()));
int allocate_sample_num = 0;
for (int i = 0; i < total_gpu; ++i) {
allocate_sample_num += d_shard_total_sample_size[i];
}
auto d_shard_vals =
memory::Alloc(place, allocate_sample_num * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, h_left, h_right,
d_shard_vals_ptr,
d_shard_total_sample_size);
cudaMalloc((void**)&result->val, allocate_sample_num * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
cudaMalloc((void**)&result->offset, len * sizeof(int));
int64_t* val = result->val;
int* actual_sample_size = result->actual_sample_size;
int* offset = result->offset;
fill_dvalues_actual_sample_size<<<grid_size, block_size_, 0, stream>>>(
d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, len);
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size, actual_sample_size + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), offset, 0);
int* d_offset;
cudaMalloc(&d_offset, len * sizeof(int));
thrust::copy(d_shard_actual_sample_size_ptr,
d_shard_actual_sample_size_ptr + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), d_offset, 0);
constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE;
constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
const dim3 block__(WARP_SIZE, BLOCK_WARPS_);
const dim3 grid__((len + TILE_SIZE_ - 1) / TILE_SIZE_);
fill_dvalues_sample_result<BLOCK_WARPS_,
TILE_SIZE_><<<grid__, block__, 0, stream>>>(
d_shard_vals_ptr, val, actual_sample_size, d_idx_ptr, offset, d_offset,
len);
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
...@@ -393,6 +589,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -393,6 +589,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
} }
destroy_storage(gpu_id, i); destroy_storage(gpu_id, i);
} }
cudaFree(d_offset);
return result; return result;
} }
......
...@@ -94,19 +94,44 @@ TEST(TEST_FLEET, graph_comm) { ...@@ -94,19 +94,44 @@ TEST(TEST_FLEET, graph_comm) {
0 --index--->0 0 --index--->0
7 --index-->2 7 --index-->2
*/ */
int64_t cpu_key[3] = {7, 0, 6}; int64_t cpu_key[3] = {7, 0, 6};
void *key; void *key;
cudaMalloc((void **)&key, 3 * sizeof(int64_t)); cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice); cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3); auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
res = new int64_t[9]; res = new int64_t[7];
cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost); cudaMemcpy(res, neighbor_sample_res->val, 56, cudaMemcpyDeviceToHost);
int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23}; int *actual_sample_size = new int[3];
for (int i = 0; i < 9; i++) { cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, 12,
if (expected_sample_val[i] != -1) { cudaMemcpyDeviceToHost); // 3, 1, 3
ASSERT_EQ(res[i], expected_sample_val[i]); int *cumsum_sample_size = new int[3];
cudaMemcpy(cumsum_sample_size, neighbor_sample_res->offset, 12,
cudaMemcpyDeviceToHost); // 0, 3, 4
std::vector<std::vector<int64_t>> neighbors_;
std::vector<int64_t> neighbors_7 = {28, 29, 30, 31, 32, 33, 34, 35};
std::vector<int64_t> neighbors_0 = {0};
std::vector<int64_t> neighbors_6 = {21, 22, 23, 24, 25, 26, 27};
neighbors_.push_back(neighbors_7);
neighbors_.push_back(neighbors_0);
neighbors_.push_back(neighbors_6);
for (int i = 0; i < 3; i++) {
for (int j = cumsum_sample_size[i];
j < cumsum_sample_size[i] + actual_sample_size[i]; j++) {
bool flag = false;
for (int k = 0; k < neighbors_[i].size(); k++) {
if (res[j] == neighbors_[i][k]) {
flag = true;
break;
}
}
ASSERT_EQ(flag, true);
} }
} }
delete[] res; delete[] res;
delete[] actual_sample_size;
delete[] cumsum_sample_size;
delete neighbor_sample_res; delete neighbor_sample_res;
} }
...@@ -25,14 +25,14 @@ std::set<std::string> ignored_ops = { ...@@ -25,14 +25,14 @@ std::set<std::string> ignored_ops = {
"sum", "sum",
"clip", "clip",
"clip_by_norm", "clip_by_norm",
"square",
"reduce_sum", "reduce_sum",
"sqrt", "sqrt",
"elementwise_max", "elementwise_max",
"elementwise_div", "elementwise_div",
"elementwise_mul", "elementwise_mul",
"scale", // adamax "scale", // adamax
"assign", // adamw "assign", // adamw
"squared_l2_norm" // gradient_clip_norm
}; };
const bool startswith(const std::string& str, const std::string& pre) { const bool startswith(const std::string& str, const std::string& pre) {
...@@ -62,6 +62,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { ...@@ -62,6 +62,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
new_op.SetAttr("with_lr_sched", false); new_op.SetAttr("with_lr_sched", false);
std::set<std::string> set_ops{}; std::set<std::string> set_ops{};
// save the weight decay tensor_name and weight_decay_value for Lamb
std::vector<std::string> weight_decay_vars{};
std::vector<float> weight_decay_values{};
// use map store <op_type, op_ptr> ? // use map store <op_type, op_ptr> ?
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (!node->IsOp()) { if (!node->IsOp()) {
...@@ -75,6 +79,15 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { ...@@ -75,6 +79,15 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
auto op_role = static_cast<OpRole>(op_role_); auto op_role = static_cast<OpRole>(op_role_);
if (op_role == OpRole::kOptimize) { if (op_role == OpRole::kOptimize) {
// save weight decay value from every lamb optimizer op
if (op_type == "lamb" && op->HasAttr("weight_decay")) {
auto weight_decay_value =
BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
auto params = op->Output("ParamOut");
weight_decay_vars.push_back(params[0]);
weight_decay_values.push_back(weight_decay_value);
}
if (set_ops.count(op_type)) { if (set_ops.count(op_type)) {
continue; continue;
} }
...@@ -270,7 +283,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { ...@@ -270,7 +283,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
// seems with_lr_sched is always true // seems with_lr_sched is always true
new_op.SetAttr("with_lr_sched", true); new_op.SetAttr("with_lr_sched", true);
// setup weight deacy // setup weight decay for Lamb
new_op.SetAttr("weight_decay_vars", weight_decay_vars);
new_op.SetAttr("weight_decay_values", weight_decay_values);
// weight_decay/coeff is "scale" attr of scale_op // weight_decay/coeff is "scale" attr of scale_op
if (set_ops.count("scale") && set_ops.count("sum")) { if (set_ops.count("scale") && set_ops.count("sum")) {
if (set_ops.count("sign")) { if (set_ops.count("sign")) {
......
...@@ -30,7 +30,8 @@ void TransferCastOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -30,7 +30,8 @@ void TransferCastOpPass::ApplyImpl(ir::Graph* graph) const {
auto ipu_backend = platform::ipu::IpuBackend::GetInstance(); auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
auto enable_fp16 = ipu_backend->GetIpuStrategy()->enable_fp16; auto enable_fp16 = ipu_backend->GetIpuStrategy()->enable_fp16;
if (enable_fp16) { auto transfer_cast_op = ipu_backend->GetIpuStrategy()->transfer_cast_op;
if (enable_fp16 && transfer_cast_op) {
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "popart_cast") { if (node->IsOp() && node->Op()->Type() == "popart_cast") {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr("to")) == if (BOOST_GET_CONST(std::string, node->Op()->GetAttr("to")) ==
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
USE_OP_ITSELF(batch_norm); USE_OP_ITSELF(batch_norm);
USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN); USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
USE_OP(conv2d_transpose); USE_OP_ITSELF(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN); USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
......
...@@ -79,18 +79,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, ...@@ -79,18 +79,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} }
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_ipu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
...@@ -390,6 +378,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, ...@@ -390,6 +378,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
"Copying from %s to %s is not supported.", src_place, dst_place)); "Copying from %s to %s is not supported.", src_place, dst_place));
} }
#endif #endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copying from %s to %s is not supported.", src_place, dst_place));
}
#endif
} }
template <typename TENSOR> template <typename TENSOR>
...@@ -447,27 +458,15 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -447,27 +458,15 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} }
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { /* custom_device -> cpu*/ platform::is_cpu_place(dst_place)) { /* custom_device -> cpu*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
} } // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) { /* cpu -> custom_device*/ platform::is_custom_place(dst_place)) { /* cpu -> custom_device*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
} } // NOLINT
else if (platform::is_custom_place(src_place) && // NOLINT else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_custom_place( platform::is_custom_place(
dst_place)) { /* custom_device -> custom_device*/ dst_place)) { /* custom_device -> custom_device*/
...@@ -483,11 +482,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -483,11 +482,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
else if (platform::is_xpu_place(src_place) && // NOLINT else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} } // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) { platform::is_xpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} } // NOLINT
else if (platform::is_xpu_place(src_place) && // NOLINT else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) { platform::is_xpu_place(dst_place)) {
if (src_ptr == dst_ptr) { if (src_ptr == dst_ptr) {
...@@ -502,7 +501,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -502,7 +501,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place); auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place);
xpu_ctx->Wait(); xpu_ctx->Wait();
} }
} } // NOLINT
else { // NOLINT else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place)); "Copy from %s to %s is not supported.", src_place, dst_place));
...@@ -601,6 +600,29 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -601,6 +600,29 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
"Copy from %s to %s is not supported.", src_place, dst_place)); "Copy from %s to %s is not supported.", src_place, dst_place));
} }
#endif #endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
} }
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
......
...@@ -1109,8 +1109,9 @@ void Reducer::FinalizeBackward() { ...@@ -1109,8 +1109,9 @@ void Reducer::FinalizeBackward() {
if (find_unused_vars_each_step_) { if (find_unused_vars_each_step_) {
// TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector // TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
defined(PADDLE_WITH_CNCL)
ProcessUnusedDenseVars(); ProcessUnusedDenseVars();
#endif #endif
// Initialize local used vars // Initialize local used vars
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP_ITSELF(conv2d); USE_OP_ITSELF(conv2d);
USE_OP(conv2d_transpose); USE_OP_ITSELF(conv2d_transpose);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -40,6 +40,13 @@ class FeedVariableVisitor : public boost::static_visitor<void> { ...@@ -40,6 +40,13 @@ class FeedVariableVisitor : public boost::static_visitor<void> {
out_var_->GetMutable<framework::LoDTensor>(); out_var_->GetMutable<framework::LoDTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) { if (platform::is_same_place(in_tensor.place(), place_)) {
out_tensor->ShareDataWith(in_tensor); out_tensor->ShareDataWith(in_tensor);
#ifdef PADDLE_WITH_IPU
} else if (platform::is_ipu_place(place_)) {
// For ipu, both in_tensor and out_tensor are allocated on cpu,
// PopART will copy tensor from host automatically,
// no TensorCopy() is required here.
out_tensor->ShareDataWith(in_tensor);
#endif
} else { } else {
platform::DeviceContext *context = platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_); platform::DeviceContextPool::Instance().Get(place_);
......
...@@ -19,14 +19,16 @@ namespace operators { ...@@ -19,14 +19,16 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class GemmConvXPUKernel : public framework::OpKernel<T> { class GemmConvXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor *input = context.Input<Tensor>("Input");
// The filter will be reshaped in the calculations, // The filter will be reshaped in the calculations,
// so here use an assignment operation, // so here use an assignment operation,
// that avoids modifying the variable in the Scope. // that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output"); Tensor *output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
...@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> { ...@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]); const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]); const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]); const int f = static_cast<int>(filter.dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<float, float, float, int16_t>( const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
dev_ctx.x_context(), input->data<float>(), filter.data<float>(), const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize, XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS, r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]", platform::errors::External("XPU conv kernel return wrong value[%d %s]",
...@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> { ...@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> { class GemmConvGradXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor *input = context.Input<Tensor>("Input");
const Tensor* output_grad = const Tensor *output_grad =
context.Input<Tensor>(framework::GradVarName("Output")); context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = Tensor *input_grad =
context.Output<Tensor>(framework::GradVarName("Input")); context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = Tensor *filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter")); context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations, // The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation, // so here use an assignment operation,
...@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> { ...@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]); const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]); const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]); const int f = static_cast<int>(filter.dims()[0]);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
const XPUT *output_grad_data =
reinterpret_cast<const XPUT *>(output_grad->data<T>());
XPUT *input_grad_data = nullptr;
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
} }
XPUT *filter_grad_data = nullptr;
if (filter_grad) { if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
} }
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<float, float, float, int16_t>( int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(), dev_ctx.x_context(), input_data, filter_data, output_grad_data,
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr, input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
nullptr, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS, r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]", platform::errors::External("XPU conv kernel return wrong value[%d %s]",
...@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> { ...@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
depthwise_conv2d, conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
REGISTER_OP_XPU_KERNEL( paddle::platform::float16>);
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
conv2d_grad, conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
depthwise_conv2d_grad, depthwise_conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif #endif
...@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h" #include "paddle/fluid/operators/conv_transpose_op.h"
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h" #include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -29,165 +33,6 @@ namespace operators { ...@@ -29,165 +33,6 @@ namespace operators {
using DataLayout = framework::DataLayout; using DataLayout = framework::DataLayout;
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ConvTranspose");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
std::vector<int> output_padding =
ctx->Attrs().Get<std::vector<int>>("output_padding");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_format");
const DataLayout data_layout =
ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
: framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument(
"Input of Op(conv_transpose) should be 4-D or "
"5-D Tensor. But received: %u-D Tensor, "
"the shape of input is [%s]",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"The input's dimension size and filter's dimension size of "
"Op (conv_transpose) should be equal. But received: the shape of "
"input is [%s], the dimension size of input is [%d], the shape "
"of filter is [%s], the dimension size of filter is [%d]. ",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i], 0,
platform::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
platform::errors::InvalidArgument(
"The input's dimension size minus Attr(stride)'s size must "
"be euqal to 2 for Op(conv_transpose). But received: [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
if (output_size.size())
PADDLE_ENFORCE_EQ(
output_size.size(), strides.size(),
platform::errors::InvalidArgument(
"The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
"should be the same."));
if (output_padding.size())
PADDLE_ENFORCE_EQ(
output_padding.size(), strides.size(),
platform::errors::InvalidArgument(
"The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
"should be the same."));
const int64_t C =
(data_layout != DataLayout::kNHWC ? in_dims[1]
: in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ(
C, filter_dims[0],
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter channels "
"for Op(conv_transpose). But received: the input's channels is "
"[%d], the shape of input is [%s], the filter's channels is [%d], "
"the shape of filter is [%s]. The data_format is %s."
"The error may come from wrong data_format setting.",
C, in_dims, filter_dims[0], filter_dims, data_layout_str));
framework::DDim in_data_dims;
if (data_layout != DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (data_layout != DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
auto infer_shape = (ctx->IsRuntime() || in_dims[i + offset] > 0)
? (in_dims[i + offset] - 1) * strides[i] -
paddings[2 * i] - paddings[2 * i + 1] +
filter_extent
: -1;
if (output_size.size()) {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
output_size[i], infer_shape,
platform::errors::InvalidArgument(
"output_size of Op(ConvTransposeOp) should not be "
"less than the infered output size. But received output_size = "
"[%s], whose dim %d is less than the infered output size [%s]",
phi::make_ddim(output_size).to_str(), i, infer_shape));
PADDLE_ENFORCE_LT(
output_size[i], infer_shape + strides[i],
platform::errors::InvalidArgument(
"output_size of Op(ConvTransposeOp) should be less "
"than infered size + stride. But received output_size = [%s], "
"whose dim %d is not less than the infered output size (%d) + "
"stride (%d) = %d",
phi::make_ddim(output_size).to_str(), i, infer_shape,
strides[i], infer_shape + strides[i]));
}
output_shape.push_back(output_size[i]);
} else if (output_padding.size()) {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
output_padding[i], 0,
platform::errors::InvalidArgument(
"output_padding of Op(ConvTransposeOp) should not be "
"less than the 0. But received output_padding = "
"[%s], whose dim %d is less than 0",
phi::make_ddim(output_padding).to_str(), i));
PADDLE_ENFORCE_LT(
output_padding[i], std::max(strides[i], dilations[i]),
platform::errors::InvalidArgument(
"output_padding of Op(ConvTransposeOp) should be less "
"than either stride or dilation. But received output_size = "
"[%s], "
"whose dim %d is not less than either stride (%d) or "
"dilation (%d)",
phi::make_ddim(output_size).to_str(), i, strides[i],
dilations[i]));
}
output_shape.push_back((infer_shape + output_padding[i]));
} else {
output_shape.push_back(infer_shape);
}
}
if (data_layout == DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
}
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
...@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
} }
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const { const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and // Only input require reshaping, weights and
...@@ -493,17 +338,6 @@ Example: ...@@ -493,17 +338,6 @@ Example:
)DOC"); )DOC");
} }
void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = bool use_cudnn =
...@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
void ConvTransposeOpDoubleGrad::InferShape(
framework::InferShapeContext* ctx) const {
auto x_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("Filter");
auto do_dims = ctx->GetInputDim("DOutput");
if (ctx->HasOutput("DDOutput") &&
(ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
ctx->SetOutputDim("DDOutput", do_dims);
}
if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
ctx->SetOutputDim("DFilter", w_dims);
}
if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
ctx->SetOutputDim("DInput", x_dims);
}
}
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = bool use_cudnn =
...@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( ...@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
namespace ops = paddle::operators; namespace ops = paddle::operators;
// conv2d_transpose // conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose, Conv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose_grad,
Conv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(
conv2d_transpose_grad_grad, Conv2dTranposeDoubleGradInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeDoubleGradInferMeta));
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp, REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker, ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>, ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>); ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OPERATOR( Conv2dTranposeInferShapeFunctor);
conv2d_transpose_grad, ops::ConvTransposeOpGrad, REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad,
ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>, ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>); ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad); Conv2dTranposeGradInferShapeFunctor);
REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad,
REGISTER_OP_CPU_KERNEL( Conv2dTranposeDoubleGradInferShapeFunctor);
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
// conv3d_transpose // conv3d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose, Conv3dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose_grad,
Conv3dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp, REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
ops::Conv3DTransposeOpMaker, ops::Conv3DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>, ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>); ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad); Conv3dTranposeInferShapeFunctor);
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad,
REGISTER_OP_CPU_KERNEL( Conv3dTranposeGradInferShapeFunctor);
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
// depthwise conv2d_transpose // depthwise conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose,
DepthWiseConv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose_grad,
DepthWiseConv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp, REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker, ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>, ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>); ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad); DepthWiseConv2dTranposeInferShapeFunctor);
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad,
REGISTER_OP_CPU_KERNEL( DepthWiseConv2dTranposeGradInferShapeFunctor);
depthwise_conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
depthwise_conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_VERSION(conv_transpose) REGISTER_OP_VERSION(conv_transpose)
.AddCheckpoint( .AddCheckpoint(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/phi/kernels/gpu/depthwise_conv.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
PADDLE_ENFORCE_EQ(
groups, filter.dims()[0],
platform::errors::InvalidArgument(
"groups should be error to the 1st dimension of filter. But "
"received groups is %d and filter dimension[0] is %d",
groups, filter.dims()[0]));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
for (auto v : dilations) {
PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
"dilations should be 1 in depthwise conv. "
"But received dilations is %d",
v));
}
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
output->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, output, static_cast<T>(0));
math::DepthwiseConvInputGradFunctor<phi::GPUContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output, filter, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, output, data_layout);
}
};
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
Tensor filter = *context.Input<Tensor>("Filter");
if (!input_grad && !filter_grad) return;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
if (input_grad) {
math::DepthwiseConvFunctor<phi::GPUContext, T> depthwiseConv;
depthwiseConv(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, filter, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, input_grad, data_layout);
}
if (filter_grad) {
phi::funcs::SetConstant<DeviceContext, T> set_zero;
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::DepthwiseConvFilterGradFunctor<phi::GPUContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, filter_grad, data_layout);
}
}
};
} // namespace operators
} // namespace paddle
// conv2d
REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
ops::GemmConvTransposeKernel<CUDA, float>,
ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
// conv3d
REGISTER_OP_CUDA_KERNEL(conv3d_transpose,
ops::GemmConvTransposeKernel<CUDA, float>,
ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
// depthwise conv2d
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose,
ops::DepthwiseConvTransposeKernel<CUDA, float>,
ops::DepthwiseConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad,
ops::DepthwiseConvTransposeGradKernel<CUDA, float>,
ops::DepthwiseConvTransposeGradKernel<CUDA, double>);
...@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h" #include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext; using NPUDeviceContext = platform::NPUDeviceContext;
template <typename T> template <typename T>
...@@ -55,8 +59,8 @@ class Conv2DTransposeNPUKernel : public framework::OpKernel<T> { ...@@ -55,8 +59,8 @@ class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size()); filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims); std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm, phi::UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize); in_data_dims, stride, ksize);
// construct NPU attr // construct NPU attr
std::vector<int> strides(4, 1); std::vector<int> strides(4, 1);
...@@ -137,8 +141,8 @@ class Conv2DTransposeGradNPUKernel : public framework::OpKernel<T> { ...@@ -137,8 +141,8 @@ class Conv2DTransposeGradNPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size()); phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims); std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
std::vector<int> strides_vec(4, 1); std::vector<int> strides_vec(4, 1);
std::vector<int> dilations_vec(4, 1); std::vector<int> dilations_vec(4, 1);
......
...@@ -8,15 +8,22 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -8,15 +8,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h" #include "paddle/fluid/operators/conv_transpose_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
// target_len == 2 || target_len == 4 // target_len == 2 || target_len == 4
inline std::vector<int> vector_extend(const std::vector<int>& src, inline std::vector<int> vector_extend(const std::vector<int>& src,
int target_len) { int target_len) {
...@@ -61,8 +68,8 @@ class Conv2DTransposeXPUKernel : public framework::OpKernel<T> { ...@@ -61,8 +68,8 @@ class Conv2DTransposeXPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
phi::slice_ddim(filter.dims(), 2, filter.dims().size()); phi::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims); std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int img_yc = static_cast<int>(input->dims()[1]); const int img_yc = static_cast<int>(input->dims()[1]);
...@@ -135,8 +142,8 @@ class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> { ...@@ -135,8 +142,8 @@ class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
phi::slice_ddim(filter.dims(), 2, filter.dims().size()); phi::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims); std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int img_yc = static_cast<int>(input->dims()[1]); const int img_yc = static_cast<int>(input->dims()[1]);
......
...@@ -243,8 +243,6 @@ class ConcatFunctor<platform::MLUDeviceContext, T> { ...@@ -243,8 +243,6 @@ class ConcatFunctor<platform::MLUDeviceContext, T> {
const int axis_t = axis; const int axis_t = axis;
const int ins_size_t = ins_size; const int ins_size_t = ins_size;
auto place = context.GetPlace();
output->mutable_data<T>(place);
// mlu should do sth // mlu should do sth
// init ins tensors // init ins tensors
...@@ -295,7 +293,6 @@ class SplitFunctor<platform::MLUDeviceContext, T> { ...@@ -295,7 +293,6 @@ class SplitFunctor<platform::MLUDeviceContext, T> {
std::vector<cnnlTensorDescriptor_t> desc_vector; std::vector<cnnlTensorDescriptor_t> desc_vector;
for (size_t i = 0; i < out_size; i++) { for (size_t i = 0; i < out_size; i++) {
(*outputs)[i]->Resize(outs_dims[i]); (*outputs)[i]->Resize(outs_dims[i]);
(*outputs)[i]->mutable_data<T>(context.GetPlace());
output_descs.emplace_back( output_descs.emplace_back(
MLUCnnlTensorDesc(*(*outputs)[i], CNNL_LAYOUT_ARRAY, MLUCnnlTensorDesc(*(*outputs)[i], CNNL_LAYOUT_ARRAY,
ToCnnlDataType((*outputs)[i]->dtype()))); ToCnnlDataType((*outputs)[i]->dtype())));
......
此差异已折叠。
此差异已折叠。
...@@ -117,7 +117,7 @@ endif() ...@@ -117,7 +117,7 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# seperate init from device_context to avoid cycle dependencies # seperate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context custom_kernel) cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool)
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
......
...@@ -13,7 +13,7 @@ IF(WITH_IPU) ...@@ -13,7 +13,7 @@ IF(WITH_IPU)
"ipu_device.cc" "ipu_device.cc"
) )
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper) cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper popdist)
cc_library(ipu_info SRCS ${IPU_INFO_SRC} DEPS popart-only enforce) cc_library(ipu_info SRCS ${IPU_INFO_SRC} DEPS popart-only enforce)
add_library(paddle_ipu SHARED ${PADDLE_IPU_SRC}) add_library(paddle_ipu SHARED ${PADDLE_IPU_SRC})
add_dependencies(paddle_ipu ipu_backend) add_dependencies(paddle_ipu ipu_backend)
......
...@@ -24,6 +24,8 @@ static constexpr const char *sIpuIndexAttr = "ipu_index"; ...@@ -24,6 +24,8 @@ static constexpr const char *sIpuIndexAttr = "ipu_index";
static constexpr const char *sIpuStageAttr = "ipu_stage"; static constexpr const char *sIpuStageAttr = "ipu_stage";
static constexpr const char *sMatmulSerializeFactor = "serialize_factor"; static constexpr const char *sMatmulSerializeFactor = "serialize_factor";
static constexpr const char *sMatmulSerializeMode = "serialize_mode"; static constexpr const char *sMatmulSerializeMode = "serialize_mode";
static constexpr const char *sAvailMemAttribute = "__available_memory";
static constexpr const char *sOpNamescope = "op_namescope";
static constexpr const char *sOpIdentifyIdAttr = "op_identify_id"; static constexpr const char *sOpIdentifyIdAttr = "op_identify_id";
static constexpr const char *sDebugInfoId = "__debug_info_id"; static constexpr const char *sDebugInfoId = "__debug_info_id";
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
gather_srcs(infrt_src SRCS
tensor_map.cc
)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册