提交 c72cf5fa 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_temporal_shift_to_phi

...@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS ...@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS
-DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH} -DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH}
-DWITH_STATIC=OFF -DWITH_STATIC=OFF
-DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${PADDLE2ONNX_INSTALL_DIR}/${LIBDIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS} ${EXTERNAL_OPTIONAL_ARGS}
......
...@@ -4,7 +4,7 @@ if(WITH_PYTHON) ...@@ -4,7 +4,7 @@ if(WITH_PYTHON)
endif() endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto) proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog) set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else() else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
......
...@@ -67,8 +67,7 @@ bool MessageBus::IsInit() const { return is_init_; } ...@@ -67,8 +67,7 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() { MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource."; VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
server_.Join(); server_.Join();
#endif #endif
...@@ -87,8 +86,7 @@ bool MessageBus::Send(int64_t dst_rank, ...@@ -87,8 +86,7 @@ bool MessageBus::Send(int64_t dst_rank,
IsInit(), true, IsInit(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized.")); "Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times int retry_time = 0; // message bus will retry sending for 10 times
while (retry_time < 10) { while (retry_time < 10) {
++retry_time; ++retry_time;
...@@ -173,8 +171,7 @@ void MessageBus::ListenPort() { ...@@ -173,8 +171,7 @@ void MessageBus::ListenPort() {
LOG(INFO) << "No need listen to port since training on single card."; LOG(INFO) << "No need listen to port since training on single card.";
return; return;
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message // function keep listen the port and handle the message
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0, server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
...@@ -203,8 +200,7 @@ void MessageBus::ListenPort() { ...@@ -203,8 +200,7 @@ void MessageBus::ListenPort() {
#endif #endif
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
bool MessageBus::SendInterRank(int64_t dst_rank, bool MessageBus::SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
const auto& dst_addr = GetAddr(dst_rank); const auto& dst_addr = GetAddr(dst_rank);
......
...@@ -20,8 +20,7 @@ ...@@ -20,8 +20,7 @@
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h" #include "brpc/channel.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
...@@ -64,8 +63,7 @@ class MessageBus final { ...@@ -64,8 +63,7 @@ class MessageBus final {
const std::string& GetAddr(int64_t rank) const; const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src) // send the message inter rank (dst is different rank with src)
bool SendInterRank(int64_t dst_rank, bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message); const InterceptorMessage& interceptor_message);
...@@ -81,8 +79,7 @@ class MessageBus final { ...@@ -81,8 +79,7 @@ class MessageBus final {
// the ip needs to be listened // the ip needs to be listened
std::string addr_; std::string addr_;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
MessageServiceImpl message_service_; MessageServiceImpl message_service_;
// brpc server // brpc server
brpc::Server server_; brpc::Server server_;
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#pragma once #pragma once
#include "brpc/server.h" #include "brpc/server.h"
......
...@@ -115,6 +115,7 @@ message TableParameter { ...@@ -115,6 +115,7 @@ message TableParameter {
optional CommonAccessorParameter common = 6; optional CommonAccessorParameter common = 6;
optional TableType type = 7; optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ]; optional bool compress_in_save = 8 [ default = false ];
optional GraphParameter graph_parameter = 9;
} }
message TableAccessorParameter { message TableAccessorParameter {
...@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule ...@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
optional double ada_epsilon = 5 [ default = 1e-08 ]; optional double ada_epsilon = 5 [ default = 1e-08 ];
repeated float weight_bounds = 6; repeated float weight_bounds = 6;
} }
// Settings for a graph table. TableParameter carries one of these as its
// optional `graph_parameter` field (tag 9).
// NOTE(review): the per-field notes below are inferred from the field names
// and defaults only; confirm each against the code that consumes this message.
message GraphParameter {
// presumably the size of the table's worker/task pool (default 24) -- confirm
optional int32 task_pool_size = 1 [ default = 24 ];
// presumably toggles a GPU parameter-server mode; off by default -- confirm
optional bool gpups_mode = 2 [ default = false ];
// name of a graph sampler class, defaulting to "CompleteGraphSampler";
// presumably only relevant when gpups_mode is set -- confirm
optional string gpups_graph_sample_class = 3
[ default = "CompleteGraphSampler" ];
// free-form argument string for the sampler above; format not visible here
optional string gpups_graph_sample_args = 4 [ default = "" ];
// cache switch, on by default; what is cached is not visible here -- TODO confirm
optional bool use_cache = 5 [ default = true ];
// presumably a fraction (default 0.3) sizing the cache -- confirm base quantity
optional float cache_ratio = 6 [ default = 0.3 ];
// presumably a cache time-to-live (default 5); unit not stated -- confirm
optional int32 cache_ttl = 7 [ default = 5 ];
// per-feature name/dtype/shape lists; see message GraphFeature
optional GraphFeature graph_feature = 8;
optional string table_name = 9 [ default = "" ];
optional string table_type = 10 [ default = "" ];
// presumably the shard count used in gpups mode (default 127) -- confirm
optional int32 gpups_mode_shard_num = 11 [ default = 127 ];
// presumably the number of GPUs the table serves (default 1) -- confirm
optional int32 gpu_num = 12 [ default = 1 ];
}
// Feature description used by GraphParameter.graph_feature (tag 8).
// NOTE(review): the three repeated fields look like parallel lists with one
// entry per feature (name[i], dtype[i], shape[i]) -- confirm with the consumer.
message GraphFeature {
// feature names
repeated string name = 1;
// feature data types as strings; the accepted spellings are not visible here
repeated string dtype = 2;
// presumably one integer size per feature -- confirm
repeated int32 shape = 3;
}
\ No newline at end of file
...@@ -44,7 +44,7 @@ void GraphPsService_Stub::service( ...@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
} }
} }
int GraphBrpcClient::get_server_index_by_id(uint64_t id) { int GraphBrpcClient::get_server_index_by_id(int64_t id) {
int shard_num = get_shard_num(); int shard_num = get_shard_num();
int shard_per_server = shard_num % server_size == 0 int shard_per_server = shard_num % server_size == 0
? shard_num / server_size ? shard_num / server_size
...@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) { ...@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
} }
std::future<int32_t> GraphBrpcClient::get_node_feat( std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids, const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) { std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server; std::vector<int> request2server;
...@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]); int server_index = get_server_index_by_id(node_ids[query_idx]);
...@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
std::string joint_feature_name = std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t'); paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx) closure->request(request_idx)
...@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) { ...@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return fut; return fut;
} }
std::future<int32_t> GraphBrpcClient::add_graph_node( std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list, uint32_t table_id, std::vector<int64_t> &node_id_list,
std::vector<bool> &is_weighted_list) { std::vector<bool> &is_weighted_list) {
std::vector<std::vector<uint64_t>> request_bucket; std::vector<std::vector<int64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket; std::vector<std::vector<bool>> is_weighted_bucket;
bool add_weight = is_weighted_list.size() > 0; bool add_weight = is_weighted_list.size() > 0;
std::vector<int> server_index_arr; std::vector<int> server_index_arr;
...@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
if (index_mapping[server_index] == -1) { if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size(); index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index); server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>()); request_bucket.push_back(std::vector<int64_t>());
if (add_weight) is_weighted_bucket.push_back(std::vector<bool>()); if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
} }
request_bucket[index_mapping[server_index]].push_back( request_bucket[index_mapping[server_index]].push_back(
...@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
size_t node_num = request_bucket[request_idx].size(); size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
if (add_weight) { if (add_weight) {
bool weighted[is_weighted_bucket[request_idx].size() + 1]; bool weighted[is_weighted_bucket[request_idx].size() + 1];
for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++) for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
...@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return fut; return fut;
} }
std::future<int32_t> GraphBrpcClient::remove_graph_node( std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list) { uint32_t table_id, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<uint64_t>> request_bucket; std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr; std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1); std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) { for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
...@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
if (index_mapping[server_index] == -1) { if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size(); index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index); server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>()); request_bucket.push_back(std::vector<int64_t>());
} }
request_bucket[index_mapping[server_index]].push_back( request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]); node_id_list[query_idx]);
...@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index)); getServiceStub(get_cmd_channel(server_index));
...@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
} }
// char* &buffer,int &actual_size // char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<uint64_t, float>>> &res, // std::vector<std::vector<std::pair<int64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res, std::vector<std::vector<int64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight, std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) { int server_index) {
if (server_index != -1) { if (server_index != -1) {
...@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[node_idx].emplace_back( res[node_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start)); *(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size; start += GraphNode::id_size;
if (need_weight) { if (need_weight) {
res_weight[node_idx].emplace_back( res_weight[node_idx].emplace_back(
...@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->set_table_id(table_id); closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id); closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)node_ids.data(), closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size()); sizeof(int64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
; ;
...@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
server2request[server_index] = request2server.size(); server2request[server_index] = request2server.size();
request2server.push_back(server_index); request2server.push_back(server_index);
} }
// res.push_back(std::vector<std::pair<uint64_t, float>>()); // res.push_back(std::vector<std::pair<int64_t, float>>());
res.push_back({}); res.push_back({});
if (need_weight) { if (need_weight) {
res_weight.push_back({}); res_weight.push_back({});
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]); int server_index = get_server_index_by_id(node_ids[query_idx]);
...@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[query_idx].emplace_back( res[query_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start)); *(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size; start += GraphNode::id_size;
if (need_weight) { if (need_weight) {
res_weight[query_idx].emplace_back( res_weight[query_idx].emplace_back(
...@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
...@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
} }
std::future<int32_t> GraphBrpcClient::random_sample_nodes( std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size, uint32_t table_id, int server_index, int sample_size,
std::vector<uint64_t> &ids) { std::vector<int64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0; int ret = 0;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = (DownpourBrpcClosure *)done;
...@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes( ...@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0; int index = 0;
while (index < bytes_size) { while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index)); ids.push_back(*(int64_t *)(buffer + index));
index += GraphNode::id_size; index += GraphNode::id_size;
} }
delete[] buffer; delete[] buffer;
...@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list( ...@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
} }
std::future<int32_t> GraphBrpcClient::set_node_feat( std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids, const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) { const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server; std::vector<int> request2server;
...@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets( std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num); request_call_num);
...@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
std::string joint_feature_name = std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t'); paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx) closure->request(request_idx)
......
...@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {} virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them // given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors( virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
std::vector<std::vector<uint64_t>>& res, std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight, std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1); int server_index = -1);
...@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id, virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index, int server_index,
int sample_size, int sample_size,
std::vector<uint64_t>& ids); std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat( virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids, const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res); std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat( virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids, const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features); const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id); virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node( virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list, uint32_t table_id, std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list); std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id, virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit, size_t size_limit,
...@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id, virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path); std::string path);
virtual std::future<int32_t> remove_graph_node( virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list); uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize(); virtual int32_t initialize();
int get_shard_num() { return shard_num; } int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(uint64_t id); int get_server_index_by_id(int64_t id);
void set_local_channel(int index) { void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index); this->local_channel = get_cmd_channel(index);
} }
......
...@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table, ...@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list; std::vector<bool> is_weighted_list;
if (request.params_size() == 2) { if (request.params_size() == 2) {
size_t weight_list_size = request.params(1).size() / sizeof(bool); size_t weight_list_size = request.params(1).size() / sizeof(bool);
...@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ...@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"graph_get_node_feat request requires at least 1 argument"); "graph_get_node_feat request requires at least 1 argument");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(node_ids); ((GraphTable *)table)->remove_graph_node(node_ids);
return 0; return 0;
...@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments"); "graph_random_sample_neighbors request requires at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str()); bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num); std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
...@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
int32_t GraphBrpcService::graph_random_sample_nodes( int32_t GraphBrpcService::graph_random_sample_nodes(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
size_t size = *(uint64_t *)(request.params(0).c_str()); size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) == if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
...@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 2 arguments"); "graph_get_node_feat request requires at least 2 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names = std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t"); paddle::string::split_string<std::string>(request.params(1), "\t");
...@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
"at least 3 arguments"); "at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t), size_t node_num = request.params(0).size() / sizeof(int64_t),
size_of_size_t = sizeof(size_t); size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str()); bool need_weight = *(int64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable // std::vector<int64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size); // *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server; std::vector<int> request2server;
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
std::vector<uint64_t> local_id; std::vector<int64_t> local_id;
std::vector<int> local_query_idx; std::vector<int> local_query_idx;
size_t rank = get_rank(); size_t rank = get_rank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
...@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::shared_ptr<char>> local_buffers; std::vector<std::shared_ptr<char>> local_buffers;
std::vector<int> local_actual_sizes; std::vector<int> local_actual_sizes;
std::vector<size_t> seq; std::vector<size_t> seq;
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index =
...@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
...@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, ...@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments"); "graph_set_node_feat request requires at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names = std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t"); paddle::string::split_string<std::string>(request.params(1), "\t");
......
...@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name, ...@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
} }
} }
void add_graph_node(std::vector<uint64_t> node_ids, void add_graph_node(std::vector<int64_t> node_ids,
std::vector<bool> weight_list) {} std::vector<bool> weight_list) {}
void remove_graph_node(std::vector<uint64_t> node_ids) {} void remove_graph_node(std::vector<int64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str, int shard_num, void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types, std::vector<std::string> node_types,
std::vector<std::string> edge_types) { std::vector<std::string> edge_types) {
...@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) { ...@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
} }
void GraphPyClient::add_graph_node(std::string name, void GraphPyClient::add_graph_node(std::string name,
std::vector<uint64_t>& node_ids, std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list) { std::vector<bool>& weight_list) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
...@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name, ...@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
} }
void GraphPyClient::remove_graph_node(std::string name, void GraphPyClient::remove_graph_node(std::string name,
std::vector<uint64_t>& node_ids) { std::vector<int64_t>& node_ids) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->remove_graph_node(table_id, node_ids); auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
...@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { ...@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
} }
} }
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name, GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids, std::vector<int64_t> node_ids,
int sample_size, bool return_weight, int sample_size, bool return_weight,
bool return_edges) { bool return_edges) {
// std::vector<std::vector<std::pair<uint64_t, float>>> v; std::vector<std::vector<int64_t>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<float>> v1; std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
...@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name, ...@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
// res.first[1]: slice index // res.first[1]: slice index
// res.first[2]: src nodes // res.first[2]: src nodes
// res.second: edges weight // res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res; std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res.first.push_back({}); res.first.push_back({});
res.first.push_back({}); res.first.push_back({});
if (return_edges) res.first.push_back({}); if (return_edges) res.first.push_back({});
...@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name, ...@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
status.wait(); status.wait();
} }
} }
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name, std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index, int server_index,
int sample_size) { int sample_size) {
std::vector<uint64_t> v; std::vector<int64_t> v;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
...@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name, ...@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
// (name, dtype, ndarray) // (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat( std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names) { std::vector<std::string> feature_names) {
std::vector<std::vector<std::string>> v( std::vector<std::vector<std::string>> v(
feature_names.size(), std::vector<std::string>(node_ids.size())); feature_names.size(), std::vector<std::string>(node_ids.size()));
...@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat( ...@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
} }
void GraphPyClient::set_node_feat( void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names, std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) { const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) { if (this->table_id_map.count(node_type)) {
......
...@@ -70,18 +70,34 @@ class GraphPyService { ...@@ -70,18 +70,34 @@ class GraphPyService {
::paddle::distributed::TableAccessorParameter* accessor_proto = ::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor(); sparse_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto = // ::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto->mutable_common(); // sparse_table_proto->mutable_common();
::paddle::distributed::GraphParameter* graph_proto =
sparse_table_proto->mutable_graph_parameter();
::paddle::distributed::GraphFeature* graph_feature =
graph_proto->mutable_graph_feature();
graph_proto->set_task_pool_size(24);
graph_proto->set_table_name(table_name);
graph_proto->set_table_type(table_type);
graph_proto->set_use_cache(false);
// Set GraphTable Parameter // Set GraphTable Parameter
common_proto->set_table_name(table_name); // common_proto->set_table_name(table_name);
common_proto->set_name(table_type); // common_proto->set_name(table_type);
// for (size_t i = 0; i < feat_name.size(); i++) {
// common_proto->add_params(feat_dtype[i]);
// common_proto->add_dims(feat_shape[i]);
// common_proto->add_attributes(feat_name[i]);
// }
for (size_t i = 0; i < feat_name.size(); i++) { for (size_t i = 0; i < feat_name.size(); i++) {
common_proto->add_params(feat_dtype[i]); graph_feature->add_dtype(feat_dtype[i]);
common_proto->add_dims(feat_shape[i]); graph_feature->add_shape(feat_shape[i]);
common_proto->add_attributes(feat_name[i]); graph_feature->add_name(feat_name[i]);
} }
accessor_proto->set_accessor_class("CommMergeAccessor"); accessor_proto->set_accessor_class("CommMergeAccessor");
} }
...@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService { ...@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
void load_edge_file(std::string name, std::string filepath, bool reverse); void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath); void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name); void clear_nodes(std::string name);
void add_graph_node(std::string name, std::vector<uint64_t>& node_ids, void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list); std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids); void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
int get_client_id() { return client_id; } int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; } void set_client_id(int client_id) { this->client_id = client_id; }
void start_client(); void start_client();
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids, batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
int sample_size, bool return_weight, int sample_size, bool return_weight,
bool return_edges); bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index, std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
int sample_size); int sample_size);
std::vector<std::vector<std::string>> get_node_feat( std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names); std::vector<std::string> feature_names);
void use_neighbors_sample_cache(std::string name, size_t total_size_limit, void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
size_t ttl); size_t ttl);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids, void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names, std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features); const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index, std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
......
...@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro ...@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table) cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp) target_link_libraries(table -fopenmp)
...@@ -38,10 +38,14 @@ ...@@ -38,10 +38,14 @@
#include <vector> #include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h" #include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class GraphShard { class GraphShard {
...@@ -51,37 +55,37 @@ class GraphShard { ...@@ -51,37 +55,37 @@ class GraphShard {
~GraphShard(); ~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; } std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step); std::vector<Node *> get_batch(int start, int end, int step);
std::vector<uint64_t> get_ids_by_range(int start, int end) { std::vector<int64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res; std::vector<int64_t> res;
for (int i = start; i < end && i < (int)bucket.size(); i++) { for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id()); res.push_back(bucket[i]->get_id());
} }
return res; return res;
} }
GraphNode *add_graph_node(uint64_t id); GraphNode *add_graph_node(int64_t id);
GraphNode *add_graph_node(Node *node); GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id); FeatureNode *add_feature_node(int64_t id);
Node *find_node(uint64_t id); Node *find_node(int64_t id);
void delete_node(uint64_t id); void delete_node(int64_t id);
void clear(); void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight); void add_neighbor(int64_t id, int64_t dst_id, float weight);
std::unordered_map<uint64_t, int> &get_node_location() { std::unordered_map<int64_t, int> &get_node_location() {
return node_location; return node_location;
} }
private: private:
std::unordered_map<uint64_t, int> node_location; std::unordered_map<int64_t, int> node_location;
std::vector<Node *> bucket; std::vector<Node *> bucket;
}; };
enum LRUResponse { ok = 0, blocked = 1, err = 2 }; enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey { struct SampleKey {
uint64_t node_key; int64_t node_key;
size_t sample_size; size_t sample_size;
bool is_weighted; bool is_weighted;
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted) SampleKey(int64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key), : node_key(_node_key),
sample_size(_sample_size), sample_size(_sample_size),
is_weighted(_is_weighted) {} is_weighted(_is_weighted) {}
...@@ -300,7 +304,7 @@ class ScaledLRU { ...@@ -300,7 +304,7 @@ class ScaledLRU {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count; node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
} }
if (node_size <= size_t(1.1 * size_limit) + 1) return 0; if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) { if (pthread_rwlock_wrlock(&rwlock) == 0) {
// VLOG(0)<"in shrink\n"; // VLOG(0)<"in shrink\n";
global_count = 0; global_count = 0;
...@@ -308,9 +312,9 @@ class ScaledLRU { ...@@ -308,9 +312,9 @@ class ScaledLRU {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count; global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
} }
// VLOG(0)<<"global_count "<<global_count<<"\n"; // VLOG(0)<<"global_count "<<global_count<<"\n";
if (global_count > size_limit) { if ((size_t)global_count > size_limit) {
size_t remove = global_count - size_limit; size_t remove = global_count - size_limit;
for (int i = 0; i < lru_pool.size(); i++) { for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0; lru_pool[i].total_diff = 0;
lru_pool[i].remove_count += lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) / 1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
...@@ -352,9 +356,69 @@ class ScaledLRU { ...@@ -352,9 +356,69 @@ class ScaledLRU {
friend class RandomSampleLRU<K, V>; friend class RandomSampleLRU<K, V>;
}; };
#ifdef PADDLE_WITH_HETERPS
// Life-cycle state of a GraphSampler:
//   waiting     - initial state set by the GraphSampler constructor; the only
//                 state from which start_graph_sampling() will launch a task.
//   running     - set by the worker task right before run_graph_sampling().
//   terminating - set by end_graph_sampling() before it waits for the task.
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
// Forward declaration: GraphSampler::init() takes a GraphTable pointer, while
// GraphTable itself is defined further below.
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
class GraphTable : public SparseTable { class GraphTable : public SparseTable {
public: public:
GraphTable() { use_cache = false; } GraphTable() {
use_cache = false;
shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable(); virtual ~GraphTable();
virtual int32_t pull_graph_list(int start, int size, virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer, std::unique_ptr<char[]> &buffer,
...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable { ...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
int step); int step);
virtual int32_t random_sample_neighbors( virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size, int64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes, bool need_weight); std::vector<int> &actual_sizes, bool need_weight);
...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable { ...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
int &actual_sizes); int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges( virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res); std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize(); virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param); int32_t load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path); int32_t load_graph_split_config(const std::string &path);
...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable { ...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type); int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list, int32_t add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list); std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list); int32_t remove_graph_node(std::vector<int64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id); int32_t get_server_index_by_id(int64_t id);
Node *find_node(uint64_t id); Node *find_node(int64_t id);
virtual int32_t pull_sparse(float *values, virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) { const PullSparseValue &pull_value) {
...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable { ...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index); virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
virtual uint32_t get_thread_pool_index(uint64_t node_id); _shard_idx = shard_idx;
/*
_shard_num is not used by graph_table itself; the following assignment exists
only to stay compatible with the base table class.
*/
_shard_num = server_num;
this->server_num = server_num;
return 0;
}
virtual uint32_t get_thread_pool_index_by_shard_index(int64_t shard_index);
virtual uint32_t get_thread_pool_index(int64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str); virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids, virtual int32_t get_node_feat(const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res); std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat( virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res); const std::vector<std::vector<std::string>> &res);
...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable { ...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
} }
return 0; return 0;
} }
#ifdef PADDLE_WITH_HETERPS
// Starts the background task of the attached graph sampler.
// Returns the sampler's result, or -1 when no sampler is attached
// (graph_sampler is a null shared_ptr until one is assigned).
virtual int32_t start_graph_sampling() {
  if (!this->graph_sampler) {
    return -1;
  }
  return this->graph_sampler->start_graph_sampling();
}
// Asks the attached graph sampler to stop and forwards its result.
// Returns -1 when no sampler is attached (graph_sampler is a null shared_ptr
// until one is assigned).
virtual int32_t end_graph_sampling() {
  if (!this->graph_sampler) {
    return -1;
  }
  return this->graph_sampler->end_graph_sampling();
}
// Forwards `callback` to the attached graph sampler.
// Returns 0 on success, or -1 when no sampler is attached (graph_sampler is a
// null shared_ptr until one is assigned).
virtual int32_t set_graph_sample_callback(
    std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
        callback) {
  if (!graph_sampler) {
    return -1;
  }
  graph_sampler->set_graph_sample_callback(callback);
  return 0;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
#endif
protected: protected:
std::vector<GraphShard *> shards, extra_shards; std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24; int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3; const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name; std::vector<std::string> feat_name;
...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable { ...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool; std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru; std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes; std::unordered_set<int64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index; std::unordered_map<int64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes; bool use_cache, use_duplicate_nodes;
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
size_t gpu_num;
int node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<int64_t, std::vector<int64_t>>>
sample_neighbors_map;
}; };
#endif
} // namespace distributed } // namespace distributed
}; // namespace paddle }; // namespace paddle
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace paddle {
...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer { ...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer {
seed_ = static_cast<unsigned int>(std::stoi(attrs[1])); seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]); mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]); std_ = std::stof(attrs[3]);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; std::uniform_real_distribution<float> dist_(
}; std::numeric_limits<float>::min(), 1.0);
float a_normal_cdf = normal_cdf((-2.0 - mean_) / std_);
float b_normal_cdf = normal_cdf((2.0 - mean_) / std_);
std::uniform_real_distribution<float> dist_(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_); random_engine_ = framework::GetCPURandomEngine(seed_);
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// Helper macros that expand to a sequence of `friend class X;` declarations.
//
// Usage, inside a class body:
//   REGISTER_GRAPH_FRIEND_CLASS(2, Foo, Bar)
// expands to
//   friend class Foo; friend class Bar;
//
// The first argument must be the literal count (1..11) of the class names
// that follow; it is token-pasted to select DECLARE_<n>_FRIEND_CLASS.
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
// Terminal case, exactly one name. Deliberately non-variadic: the recursive
// step always reaches it with a single argument, and invoking a variadic
// macro with an empty pack is ill-formed before C++20 (-pedantic warning).
#define DECLARE_1_FRIEND_CLASS(a) DECLARE_GRAPH_FRIEND_CLASS(a)
// Recursive cases: emit the first name, then delegate the rest to n-1.
#define DECLARE_2_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
// Entry point: `n` selects the expansion above that matches the name count.
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
  DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void GraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
} }
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void WeightedGraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
weight_arr.push_back(weight); weight_arr.push_back(weight);
} }
......
...@@ -24,19 +24,20 @@ class GraphEdgeBlob { ...@@ -24,19 +24,20 @@ class GraphEdgeBlob {
GraphEdgeBlob() {} GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {} virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); } size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; } int64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; } virtual float get_weight(int idx) { return 1; }
std::vector<int64_t>& export_id_array() { return id_arr; }
protected: protected:
std::vector<uint64_t> id_arr; std::vector<int64_t> id_arr;
}; };
class WeightedGraphEdgeBlob : public GraphEdgeBlob { class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public: public:
WeightedGraphEdgeBlob() {} WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {} virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; } virtual float get_weight(int idx) { return weight_arr[idx]; }
protected: protected:
......
...@@ -48,6 +48,7 @@ class Node { ...@@ -48,6 +48,7 @@ class Node {
virtual void set_feature(int idx, std::string str) {} virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {} virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; } virtual int get_feature_size() { return 0; }
virtual size_t get_neighbor_size() { return 0; }
protected: protected:
uint64_t id; uint64_t id;
...@@ -70,6 +71,7 @@ class GraphNode : public Node { ...@@ -70,6 +71,7 @@ class GraphNode : public Node {
} }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
virtual size_t get_neighbor_size() { return edges->size(); }
protected: protected:
Sampler *sampler; Sampler *sampler;
......
...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable); ...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(Table, SSDSparseTable); REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif #endif
REGISTER_PSCORE_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
......
...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv ...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table) cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
...@@ -236,7 +236,7 @@ void RunGraphSplit() { ...@@ -236,7 +236,7 @@ void RunGraphSplit() {
sleep(2); sleep(2);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -250,16 +250,16 @@ void RunGraphSplit() { ...@@ -250,16 +250,16 @@ void RunGraphSplit() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
_vs.clear(); _vs.clear();
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 97), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(3, _vs[0].size()); ASSERT_EQ(3, _vs[0].size());
std::remove(edge_file_name); std::remove(edge_file_name);
......
...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed; ...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
void testSampleNodes( void testSampleNodes(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<uint64_t> ids; std::vector<int64_t> ids;
auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids); auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids);
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {37, 59}; std::unordered_set<int64_t> s1 = {37, 59};
pull_status.wait(); pull_status.wait();
for (auto id : ids) s.insert(id); for (auto id : ids) s.insert(id);
ASSERT_EQ(true, s.size() == s1.size()); ASSERT_EQ(true, s.size() == s1.size());
...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() { ...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor( void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors( auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor( ...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
vs.clear(); vs.clear();
vs1.clear(); vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
s1 = {111, 48, 247}; s1 = {111, 48, 247};
for (auto g : vs[0]) { for (auto g : vs[0]) {
...@@ -147,30 +147,30 @@ void testAddNode( ...@@ -147,30 +147,30 @@ void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0); worker_ptr_->clear_nodes(0);
int total_num = 270000; int total_num = 270000;
uint64_t id; int64_t id;
std::unordered_set<uint64_t> id_set; std::unordered_set<int64_t> id_set;
for (int i = 0; i < total_num; i++) { for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end()) while (id_set.find(id = rand()) != id_set.end())
; ;
id_set.insert(id); id_set.insert(id);
} }
std::vector<uint64_t> id_list(id_set.begin(), id_set.end()); std::vector<int64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list; std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list); auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait(); status.wait();
std::vector<uint64_t> ids[2]; std::vector<int64_t> ids[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto sample_status = auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x); for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size()); ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) { for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true); ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
} }
std::vector<uint64_t> remove_ids; std::vector<int64_t> remove_ids;
for (auto p : id_set_check) { for (auto p : id_set_check) {
if (remove_ids.size() == 0) if (remove_ids.size() == 0)
remove_ids.push_back(p); remove_ids.push_back(p);
...@@ -187,7 +187,7 @@ void testAddNode( ...@@ -187,7 +187,7 @@ void testAddNode(
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x); for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size()); ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) { for (auto x : id_set_check1) {
...@@ -196,14 +196,14 @@ void testAddNode( ...@@ -196,14 +196,14 @@ void testAddNode(
} }
void testBatchSampleNeighboor( void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96}; std::vector<std::int64_t> v = {37, 96};
auto pull_status = auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false); worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() { ...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() { ...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_); testSampleNodes(worker_ptr_);
sleep(5); sleep(5);
testSingleSampleNeighboor(worker_ptr_); testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g = paddle::distributed::GraphTable* g =
...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() { ...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
while (round--) { while (round--) {
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false); 0, std::vector<int64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait(); pull_status.wait();
for (int i = 0; i < ttl; i++) { for (int i = 0; i < ttl; i++) {
std::vector<std::vector<uint64_t>> vs1; std::vector<std::vector<int64_t>> vs1;
std::vector<std::vector<float>> vs2; std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false); 0, std::vector<int64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(_vs[0].size(), vs1[0].size()); ASSERT_EQ(_vs[0].size(), vs1[0].size());
...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() { ...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
// Test Pull by step // Test Pull by step
std::unordered_set<uint64_t> count_item_nodes; std::unordered_set<int64_t> count_item_nodes;
// pull by step 2 // pull by step 2
for (int test_step = 1; test_step < 4; test_step++) { for (int test_step = 1; test_step < 4; test_step++) {
count_item_nodes.clear(); count_item_nodes.clear();
...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() { ...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12); ASSERT_EQ(count_item_nodes.size(), 12);
} }
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res; std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors( res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false); std::string("user2item"), std::vector<int64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3); ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids; std::vector<int64_t> node_ids;
node_ids.push_back(96); node_ids.push_back(96);
node_ids.push_back(37); node_ids.push_back(37);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4, res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false); true, false);
ASSERT_EQ(res.first[1].size(), 1); ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6); std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
(nodes_ids[0] == 37 && nodes_ids[1] == 59)); (nodes_ids[0] == 37 && nodes_ids[1] == 59));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
// Edge fixture: twelve tab-separated records that prepare_file() writes to
// `edge_file_name` before the table loads it.
// Each record has three fields — presumably source id, destination id and
// edge weight; confirm against the GraphTable edge loader.
std::vector<std::string> edges = {
    "37\t45\t0.34",
    "37\t145\t0.31",
    "37\t112\t0.21",
    "96\t48\t1.4",
    "96\t247\t0.31",
    "96\t111\t1.21",
    "59\t45\t0.34",
    "59\t145\t0.31",
    "59\t122\t0.21",
    "97\t48\t0.34",
    "97\t247\t0.31",
    "97\t111\t0.21",
};
// Author's note, kept verbatim: "odd id:96 48 122 112".
// NOTE(review): the same ids are quoted in testGraphSample next to the
// four-node expectation — confirm what "odd" refers to.
// Mutable char array (not const char*) because prepare_file() takes char[].
char edge_file_name[] = "edges.txt";
// Node fixture: sixteen tab-separated records, four tagged "user" and twelve
// tagged "item". The first field is the tag and the second the node id; the
// remaining fields look like "<name> <value...>" feature entries — confirm
// against the GraphTable node loader.
std::vector<std::string> nodes = {
    // user records
    "user\t37\ta 0.34\tb 13 14\tc hello\td abc",
    "user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd",
    "user\t59\ta 0.11\tb 11 14",
    "user\t97\ta 0.11\tb 12 11",
    // item records
    "item\t45\ta 0.21",
    "item\t145\ta 0.21",
    "item\t112\ta 0.21",
    "item\t48\ta 0.21",
    "item\t247\ta 0.21",
    "item\t111\ta 0.21",
    "item\t46\ta 0.21",
    "item\t146\ta 0.21",
    "item\t122\ta 0.21",
    "item\t49\ta 0.21",
    "item\t248\ta 0.21",
    "item\t113\ta 0.21",
};
// File name for the node fixture; a char array so it can be passed as char[].
char node_file_name[] = "nodes.txt";
// Writes each string in `data` to the file `file_name`, one per line.
//
// The file is opened with std::ofstream's default mode, so existing content is
// replaced. A failed open is not reported: the writes are then no-ops, as
// before.
//
// Lines are iterated by const reference (no per-element copy) and terminated
// with '\n' rather than std::endl, so the stream is flushed once when it is
// closed by its destructor instead of once per line.
void prepare_file(char file_name[], std::vector<std::string> data) {
  std::ofstream ofile(file_name);
  for (const auto &line : data) {
    ofile << line << '\n';
  }
}
// Loads the edge fixture into two GraphTable instances configured for GPU-PS
// mode (127 shards, 2 GPUs), runs one graph-sampling pass on each and checks
// the per-GPU results delivered through the sample callback.
// The body is compiled only when PADDLE_WITH_HETERPS is defined.
void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
  const std::string edge_path(edge_file_name);
  const std::string edge_load_param("e>");

  // ---- Table 1: no sampler class or sampler args set explicitly. ----
  ::paddle::distributed::GraphParameter first_cfg;
  first_cfg.set_gpups_mode(true);
  first_cfg.set_gpups_mode_shard_num(127);
  first_cfg.set_gpu_num(2);
  distributed::GraphTable graph_table, graph_table1;
  graph_table.initialize(first_cfg);
  prepare_file(edge_file_name, edges);
  graph_table.load(edge_path, edge_load_param);

  // The callback copies the sampled graphs out and fulfils the promise, so the
  // future below blocks until a result has been delivered.
  // NOTE(review): a second callback invocation would call set_value() on an
  // already-satisfied promise — confirm the sampler fires only once before
  // end_graph_sampling().
  std::vector<paddle::framework::GpuPsCommGraph> first_graphs;
  std::promise<int> first_done;
  std::future<int> first_ready = first_done.get_future();
  graph_table.set_graph_sample_callback(
      [&first_graphs,
       &first_done](std::vector<paddle::framework::GpuPsCommGraph> &sampled) {
        first_graphs = sampled;
        first_done.set_value(0);
      });
  graph_table.start_graph_sampling();
  first_ready.get();
  graph_table.end_graph_sampling();

  ASSERT_EQ(2, first_graphs.size());
  // Ids expected in the second graph, per the author's note: 37 59 97
  for (int idx = 0; idx < static_cast<int>(first_graphs[1].node_size); ++idx) {
    std::cout << first_graphs[1].node_list[idx].node_id << std::endl;
  }
  ASSERT_EQ(3, first_graphs[1].node_size);

  // ---- Table 2: BasicBfsGraphSampler with args "5,5,1,1". ----
  ::paddle::distributed::GraphParameter bfs_cfg;
  bfs_cfg.set_gpups_mode(true);
  bfs_cfg.set_gpups_mode_shard_num(127);
  bfs_cfg.set_gpu_num(2);
  bfs_cfg.set_gpups_graph_sample_class("BasicBfsGraphSampler");
  bfs_cfg.set_gpups_graph_sample_args("5,5,1,1");
  graph_table1.initialize(bfs_cfg);
  graph_table1.load(edge_path, edge_load_param);

  std::vector<paddle::framework::GpuPsCommGraph> bfs_graphs;
  std::promise<int> bfs_done;
  std::future<int> bfs_ready = bfs_done.get_future();
  graph_table1.set_graph_sample_callback(
      [&bfs_graphs,
       &bfs_done](std::vector<paddle::framework::GpuPsCommGraph> &sampled) {
        bfs_graphs = sampled;
        bfs_done.set_value(0);
      });
  graph_table1.start_graph_sampling();
  bfs_ready.get();
  graph_table1.end_graph_sampling();

  ASSERT_EQ(2, bfs_graphs.size());
  // Author's note, kept verbatim: "odd id:96 48 122 112"
  for (int idx = 0; idx < static_cast<int>(bfs_graphs[0].node_size); ++idx) {
    std::cout << bfs_graphs[0].node_list[idx].node_id << std::endl;
  }
  ASSERT_EQ(4, bfs_graphs[0].node_size);
#endif
}
// gtest registration: runs testGraphSample() as test case testGraphSample.Run.
TEST(testGraphSample, Run) { testGraphSample(); }
...@@ -370,7 +370,7 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -370,7 +370,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
if (grad_tensors[i].is_initialized()) { if (grad_tensors[i].is_initialized()) {
// Deep copy // Deep copy
paddle::experimental::Tensor tmp_tensor; paddle::experimental::Tensor tmp_tensor;
tmp_tensor.copy_(grad_tensors[i], true); tmp_tensor.copy_(grad_tensors[i], grad_tensors[i].inner_place(), true);
node_input_buffers_dict[grad_node]->add(input_info.first, node_input_buffers_dict[grad_node]->add(input_info.first,
input_info.second, tmp_tensor); input_info.second, tmp_tensor);
} else { } else {
......
...@@ -128,6 +128,6 @@ TEST(Generated, ElementwiseAdd) { ...@@ -128,6 +128,6 @@ TEST(Generated, ElementwiseAdd) {
} // namespace egr } // namespace egr
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(matmul_v2); USE_OP_ITSELF(matmul_v2);
...@@ -255,6 +255,6 @@ TEST(Hook_intermidiate, Matmul_v2) { ...@@ -255,6 +255,6 @@ TEST(Hook_intermidiate, Matmul_v2) {
} }
} // namespace egr } // namespace egr
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(matmul_v2); USE_OP_ITSELF(matmul_v2);
...@@ -10,8 +10,9 @@ IF(WITH_GPU) ...@@ -10,8 +10,9 @@ IF(WITH_GPU)
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS}) nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm) nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps) nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps)
ENDIF() ENDIF()
IF(WITH_ROCM) IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
// One entry of GpuPsCommGraph::node_list: a node id plus the location of that
// node's neighbors inside the graph's flat neighbor_list array.
struct GpuPsGraphNode {
  int64_t node_id;
  // The neighbors occupy the half-open index range
  // [neighbor_offset, neighbor_offset + neighbor_size) of the
  // int64_t *neighbor_list.
  int neighbor_size;
  int neighbor_offset;
};
// A graph in flattened form: `node_list` holds one GpuPsGraphNode per node and
// `neighbor_list` holds all nodes' neighbor ids back to back.
//
// The struct only stores the pointers; it has no destructor and never frees
// them, so whoever allocated the arrays stays responsible for releasing them.
struct GpuPsCommGraph {
  int64_t *neighbor_list;
  GpuPsGraphNode *node_list;
  // Number of elements in neighbor_list and in node_list, respectively.
  int neighbor_size, node_size;

  // Creates an empty graph: null arrays, zero sizes.
  // nullptr is used instead of NULL so this header does not depend on another
  // header having defined the NULL macro.
  GpuPsCommGraph()
      : neighbor_list(nullptr),
        node_list(nullptr),
        neighbor_size(0),
        node_size(0) {}

  // Wraps existing arrays. The pointers are stored as given (no copy).
  GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
                 int neighbor_size_, int node_size_)
      : neighbor_list(neighbor_list_),
        node_list(node_list_),
        neighbor_size(neighbor_size_),
        node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
 \   |\         |\
 17  8 9        1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,8]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].node_id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
// Holds the output buffers of a neighbor-sampling query.
//
// Both pointers start out null and are expected to be assigned later by the
// code that fills in the result; the destructor releases whichever of them is
// non-null with cudaFree, so they must be CUDA allocations.
// NOTE(review): presumably `val` holds the sampled neighbor ids and
// `actual_sample_size` the per-key counts — confirm against the sampling
// kernel.
//
// WARNING: the implicitly generated copy operations copy the raw pointers, so
// copying an instance leads to a double cudaFree. Pass it by pointer or
// reference.
struct NeighborSampleResult {
  int64_t *val;
  int *actual_sample_size;
  int sample_size;
  int key_size;

  // Records the requested sizes; no memory is allocated here.
  NeighborSampleResult(int _sample_size, int _key_size)
      : val(nullptr),
        actual_sample_size(nullptr),
        sample_size(_sample_size),
        key_size(_key_size) {}

  ~NeighborSampleResult() {
    if (val != nullptr) cudaFree(val);
    if (actual_sample_size != nullptr) cudaFree(actual_sample_size);
  }
};
// Holds the output buffer of a node query.
//
// `val` starts out null and is expected to be assigned later by the code that
// fills in the result; the destructor releases it with cudaFree when non-null,
// so it must be a CUDA allocation.
// NOTE(review): presumably `actual_sample_size` is the number of ids stored in
// `val` — confirm against graph_node_sample.
//
// WARNING: the implicitly generated copy operations copy the raw pointer, so
// copying an instance leads to a double cudaFree. Pass it by pointer or
// reference.
struct NodeQueryResult {
  int64_t *val;
  int actual_sample_size;

  // Starts empty: no buffer, zero results.
  NodeQueryResult() : val(nullptr), actual_sample_size(0) {}

  ~NodeQueryResult() {
    if (val != nullptr) cudaFree(val);
  }
};
}
};
#endif
...@@ -14,114 +14,25 @@ ...@@ -14,114 +14,25 @@
#pragma once #pragma once
#include "heter_comm.h" #include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
class GpuPsGraphTable : public HeterComm<int64_t, int, int> { class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public: public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource) GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource)
: HeterComm<int64_t, int, int>(1, resource) { : HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25; load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
cpu_table_status = -1;
}
~GpuPsGraphTable() {
if (cpu_table_status != -1) {
end_graph_sampling();
}
} }
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list); void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size); NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
...@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> { ...@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int *h_right, int *h_right,
int64_t *src_sample_res, int64_t *src_sample_res,
int *actual_sample_size); int *actual_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() {
return cpu_graph_table->end_graph_sampling();
}
private: private:
std::vector<GpuPsCommGraph> gpu_graph_list; std::vector<GpuPsCommGraph> gpu_graph_list;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_;
std::condition_variable cv_;
int cpu_table_status;
}; };
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/* /*
...@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index, ...@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
} }
} }
// Creates and initializes the CPU-side graph table and registers the callback
// that pushes every freshly sampled graph to the GPUs.
//
// Returns the status reported by GraphTable::initialize(); a non-zero status
// is returned immediately and no callback is registered in that case.
int GpuPsGraphTable::init_cpu_table(
    const paddle::distributed::GraphParameter& graph) {
  cpu_graph_table.reset(new paddle::distributed::GraphTable);
  cpu_table_status = cpu_graph_table->initialize(graph);
  if (cpu_table_status != 0) return cpu_table_status;
  std::function<void(std::vector<GpuPsCommGraph>&)> callback =
      [this](std::vector<GpuPsCommGraph>& res) {
        // Rebuild the GPU graph under the write lock.
        pthread_rwlock_wrlock(this->rw_lock.get());
        this->clear_graph_info();
        this->build_graph_from_cpu(res);
        pthread_rwlock_unlock(this->rw_lock.get());
        // load() holds mutex_ from before it starts sampling until it is
        // parked inside cv_.wait(). Passing through mutex_ here (after the
        // rwlock is released, so the two locks are never nested) guarantees
        // the notification cannot fire before the waiter is actually waiting,
        // which would otherwise lose the wake-up and block load() forever.
        // The callback must therefore not run on the thread holding mutex_.
        {
          std::lock_guard<std::mutex> handshake(this->mutex_);
        }
        this->cv_.notify_one();
      };
  cpu_graph_table->set_graph_sample_callback(callback);
  return cpu_table_status;
}
// Loads graph data into the CPU table, starts graph sampling and then blocks
// until cv_ is notified (the callback registered in init_cpu_table notifies it
// after a sampled graph has been built).
//
// Returns the non-zero status from GraphTable::load() on failure, 0 otherwise.
// NOTE(review): the wait has no predicate, so a spurious wake-up would let
// this return before any graph was built -- confirm whether that matters.
int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
  const int load_status = cpu_graph_table->load(path, param);
  if (load_status == 0) {
    // Take the mutex before sampling starts and hold it until wait() parks us.
    std::unique_lock<std::mutex> guard(mutex_);
    cpu_graph_table->start_graph_sampling();
    cv_.wait(guard);
  }
  return load_status;
}
/* /*
comment 1 comment 1
...@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index, ...@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
that's what fill_dvals does. that's what fill_dvals does.
*/ */
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right, int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
int64_t* src_sample_res, int* actual_sample_size) { int64_t* src_sample_res, int* actual_sample_size) {
...@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr()); int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr()); int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr = int* d_shard_actual_sample_size_ptr =
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include <queue> #include <queue>
namespace paddle { namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
// Writes every string in `data` to `file_name`, one per line, replacing any
// previous content of the file. Open/write failures are not reported.
void prepare_file(char file_name[], std::vector<std::string> data) {
  std::ofstream ofile(file_name);
  // Iterate by reference (no per-line string copy) and emit '\n' rather than
  // std::endl so the stream is not flushed after every line.
  for (const auto& line : data) {
    ofile << line << '\n';
  }
  // ofile is flushed and closed by its destructor.
}
char edge_file_name[] = "edges.txt";
// End-to-end check of GPU neighbor sampling: writes a small edge file, loads
// it through the CPU graph table, samples neighbors of three keys on GPU 0 and
// compares the returned ids with the expected ones.
TEST(TEST_FLEET, graph_sample) {
  std::vector<int> dev_ids;
  dev_ids.push_back(0);
  dev_ids.push_back(1);
  dev_ids.push_back(2);
  std::shared_ptr<HeterPsResource> resource =
      std::make_shared<HeterPsResource>(dev_ids);
  resource->enable_p2p();
  GpuPsGraphTable g(resource);

  // Node x gets x + 1 outgoing edges whose destination ids are consecutive
  // across the whole graph, i.e. node x's neighbors are
  //   [x*(x+1)/2, x*(x+1)/2 + 1, ..., x*(x+1)/2 + x].
  // Hence node 0 -> [0], node 6 -> [21..27], node 7 -> [28..35].
  const int node_count = 10;
  std::vector<std::string> edges;
  int64_t node_id = 0;
  for (int ind = 0; ind < node_count; ++ind) {
    for (int k = 0; k <= ind; ++k) {
      edges.push_back(std::to_string(ind) + "\t" + std::to_string(node_id) +
                      "\t1.0");
      node_id++;
    }
  }

  ::paddle::distributed::GraphParameter table_proto;
  table_proto.set_gpups_mode(true);
  table_proto.set_gpups_mode_shard_num(127);
  table_proto.set_gpu_num(3);
  table_proto.set_gpups_graph_sample_class("BasicBfsGraphSampler");
  table_proto.set_gpups_graph_sample_args("5,5,1,1");
  prepare_file(edge_file_name, edges);
  // Both calls report 0 on success; stop early instead of sampling from a
  // table that failed to initialize or load.
  ASSERT_EQ(g.init_cpu_table(table_proto), 0);
  ASSERT_EQ(g.load(std::string(edge_file_name), std::string("e>")), 0);

  // query([7,0,6], sample_size=3) should return
  //   [28,29,30, 0,x,x, 21,22,23]
  // where x marks slots left unfilled because node 0 has a single neighbor.
  const int key_count = 3;
  const int sample_size = 3;
  int64_t cpu_key[key_count] = {7, 0, 6};
  void *key = nullptr;
  ASSERT_EQ(cudaMalloc(&key, sizeof(cpu_key)), cudaSuccess);
  cudaMemcpy(key, cpu_key, sizeof(cpu_key), cudaMemcpyHostToDevice);
  // NOTE(review): argument order assumed to be (gpu_id, keys, sample_size,
  // key count); both trailing values are 3 here, so the call is unaffected.
  auto neighbor_sample_res = g.graph_neighbor_sample(
      0, static_cast<int64_t *>(key), sample_size, key_count);

  std::vector<int64_t> res(key_count * sample_size);
  cudaMemcpy(res.data(), neighbor_sample_res->val,
             res.size() * sizeof(int64_t), cudaMemcpyDeviceToHost);
  // Release everything before asserting: a failing ASSERT_* returns from the
  // test body immediately and would otherwise leak these resources.
  delete neighbor_sample_res;
  cudaFree(key);

  // Sort the slices of keys 7 and 6 before comparing them.
  std::sort(res.begin(), res.begin() + sample_size);
  std::sort(res.begin() + 2 * sample_size, res.end());
  // -1 marks a "don't care" slot.
  const int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
  for (size_t i = 0; i < res.size(); i++) {
    if (expected_sample_val[i] != -1) {
      ASSERT_EQ(res[i], expected_sample_val[i]);
    }
  }
}
...@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS; return var_types[0] == proto::VarType::SELECTED_ROWS;
} }
// Returns true when the first variable type recorded for input `name` is
// LOD_TENSOR_ARRAY. Like the neighbouring Is*Input checks, only index 0 is
// inspected.
bool IsDenseTensorVectorInput(const std::string& name) const override {
  const auto input_types = ctx_.GetInputsVarType(name);
  const bool is_tensor_array =
      input_types[0] == proto::VarType::LOD_TENSOR_ARRAY;
  return is_tensor_array;
}
bool IsDenseTensorOutput(const std::string& name) const override { bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name); auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR; return var_types[0] == proto::VarType::LOD_TENSOR;
...@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dims(); return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims(); return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows.")); "Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
} }
} else { } else {
auto* var = BOOST_GET_CONST(VarDesc*, var_); auto* var = BOOST_GET_CONST(VarDesc*, var_);
...@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dtype(); return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype(); return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows.")); "Currently, only can get dtype from DenseTensor or SelectedRows."));
...@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor {
DataLayout layout() const override { DataLayout layout() const override {
if (is_runtime_) { if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_); auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().layout(); if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else { } else {
// NOTE(chenweihang): do nothing // NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now // Unsupported get layout for VarDesc now
...@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value(); auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: ideally we would enforce `tensor_array->size() == 0UL` here,
// because using a LoDTensorArray inplace is dangerous, but the unittest
// `test_list` relies on this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows.")); "Currently, only can set dims from DenseTensor or SelectedRows."));
...@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value(); auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows.")); "Currently, only can set dtype from DenseTensor or SelectedRows."));
...@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor {
void set_layout(DataLayout layout) override { void set_layout(DataLayout layout) override {
if (is_runtime_) { if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_); auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>(); if (var->IsType<phi::DenseTensor>()) {
phi::DenseTensorUtils::GetMutableMeta( auto* tensor = var->GetMutable<phi::DenseTensor>();
static_cast<phi::DenseTensor*>(tensor)) phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
->layout = layout; } else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set layout for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else { } else {
// NOTE(chenweihang): do nothing // NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now // Unsupported set layout for VarDesc now
...@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor {
void share_meta(const MetaTensor& meta_tensor) override { void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor); share_dims(meta_tensor);
set_dtype(meta_tensor.dtype()); set_dtype(meta_tensor.dtype());
// VarDesc doesn't contains layout, so we cannot share layout set_layout(meta_tensor.layout());
// set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor // special case: share lod of LoDTensor
share_lod(meta_tensor); share_lod(meta_tensor);
} }
...@@ -442,6 +489,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -442,6 +489,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name, infershape_input.size())); attr_name, infershape_input.size()));
} }
} }
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct InferMetaContext.",
attr_names[i]));
}
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr. // Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
......
...@@ -73,8 +73,10 @@ static void ShareVarInfoToCinnLaunch( ...@@ -73,8 +73,10 @@ static void ShareVarInfoToCinnLaunch(
varinfo_maps.at(cinn_launch_op->GetScopeIdx()); varinfo_maps.at(cinn_launch_op->GetScopeIdx());
// collect all MemOptVarInfos of external variables // collect all MemOptVarInfos of external variables
// that would be eager deleted after the cinn_launch subgraph executed, // that were eager deleted after the cinn_launch subgraph executed,
// and store them as attribute of the subgraph // and we will delete them in advance among eager_deletion_ops
// inside cinn_launch subgraph, so store them as attribute of the subgraph
// to pass to the inner eager_deletion_ops.
for (const auto& var_name : vars_to_delete) { for (const auto& var_name : vars_to_delete) {
auto it = src_varinfo_map.find(var_name); auto it = src_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(), PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
...@@ -82,6 +84,8 @@ static void ShareVarInfoToCinnLaunch( ...@@ -82,6 +84,8 @@ static void ShareVarInfoToCinnLaunch(
"MemOptVarInfo of var[%s] not found", var_name)); "MemOptVarInfo of var[%s] not found", var_name));
dst_varinfo_map.emplace(var_name, it->second); dst_varinfo_map.emplace(var_name, it->second);
} }
// skip running of the followed eager_deletion_op
followed_eager_deletion_op->SetSkipRunning(true);
} }
static void TakeVarInfoFromMainGraph( static void TakeVarInfoFromMainGraph(
......
...@@ -31,7 +31,7 @@ USE_OP(slice); ...@@ -31,7 +31,7 @@ USE_OP(slice);
USE_OP(concat); USE_OP(concat);
USE_OP(matmul); USE_OP(matmul);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh); USE_OP_ITSELF(tanh);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
USE_OP(softmax_with_cross_entropy); USE_OP(softmax_with_cross_entropy);
...@@ -47,7 +47,7 @@ USE_OP(square); ...@@ -47,7 +47,7 @@ USE_OP(square);
USE_OP(transpose2_grad); USE_OP(transpose2_grad);
USE_OP(concat_grad); USE_OP(concat_grad);
USE_OP_ITSELF(elementwise_mul_grad); USE_OP_ITSELF(elementwise_mul_grad);
USE_OP(sigmoid_grad); USE_OP_ITSELF(sigmoid_grad);
USE_OP_ITSELF(tanh_grad); USE_OP_ITSELF(tanh_grad);
USE_OP(sum); USE_OP(sum);
USE_OP(slice_grad); USE_OP(slice_grad);
......
...@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto* var = ins_vector[offset]; auto* var = ins_vector[offset];
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
tensor_in = &(var->Get<framework::LoDTensor>()); tensor_in = &(var->Get<framework::LoDTensor>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SelectedRows>()) { } else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>()); tensor_in = &(var->Get<phi::SelectedRows>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var->Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.", "Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} }
// Note: here cannot deal with vector<LoDTensorArray> input
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(4) << "Done inputs"; VLOG(4) << "Done inputs";
...@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]; auto* var = outs_vector[offset];
if (var) { if (var) {
if (var->template IsType<framework::LoDTensor>()) { if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>(); tensor_out = var->template GetMutable<framework::LoDTensor>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) { } else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>(); tensor_out = var->template GetMutable<phi::SelectedRows>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray size is 0, the output
// LoDTensorArray is also 0
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
} } else {
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
}
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(4) << "Done outputs"; VLOG(4) << "Done outputs";
......
...@@ -483,6 +483,10 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -483,6 +483,10 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return ctx_.InputVar(name)->IsType<phi::SelectedRows>(); return ctx_.InputVar(name)->IsType<phi::SelectedRows>();
} }
// Returns true when the runtime variable bound to input `name` holds a
// framework::LoDTensorArray.
bool IsDenseTensorVectorInput(const std::string& name) const override {
  const auto* input_var = ctx_.InputVar(name);
  return input_var->IsType<framework::LoDTensorArray>();
}
bool IsDenseTensorOutput(const std::string& name) const override { bool IsDenseTensorOutput(const std::string& name) const override {
return ctx_.OutputVar(name)->IsType<framework::LoDTensor>(); return ctx_.OutputVar(name)->IsType<framework::LoDTensor>();
} }
......
...@@ -423,7 +423,7 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -423,7 +423,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
} }
if (data_type == framework::proto::VarType::BF16) { if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>( return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place); src_tensor, dst_tensor, place);
#else #else
......
...@@ -289,14 +289,23 @@ void BuildDygraphPhiKernelContext( ...@@ -289,14 +289,23 @@ void BuildDygraphPhiKernelContext(
auto& var = ins_vector[offset]->Var(); auto& var = ins_vector[offset]->Var();
if (var.template IsType<phi::DenseTensor>()) { if (var.template IsType<phi::DenseTensor>()) {
tensor_in = &(var.template Get<phi::DenseTensor>()); tensor_in = &(var.template Get<phi::DenseTensor>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<phi::SelectedRows>()) { } else if (var.template IsType<phi::SelectedRows>()) {
tensor_in = &(var.template Get<phi::SelectedRows>()); tensor_in = &(var.template Get<phi::SelectedRows>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var.template Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.", "Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var.Type()))); framework::ToTypeName(var.Type())));
} }
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
...@@ -326,17 +335,28 @@ void BuildDygraphPhiKernelContext( ...@@ -326,17 +335,28 @@ void BuildDygraphPhiKernelContext(
if (var) { if (var) {
if (var->template IsType<phi::DenseTensor>()) { if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>(); tensor_out = var->template GetMutable<phi::DenseTensor>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) { } else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>(); tensor_out = var->template GetMutable<phi::SelectedRows>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
} } else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
......
...@@ -50,8 +50,7 @@ ...@@ -50,8 +50,7 @@
#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/utils/string/split.h" #include "paddle/utils/string/split.h"
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
...@@ -374,8 +373,7 @@ static void DisablePrepareDataOpt( ...@@ -374,8 +373,7 @@ static void DisablePrepareDataOpt(
} }
bool AnalysisPredictor::PrepareExecutor() { bool AnalysisPredictor::PrepareExecutor() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
if (config_.dist_config().use_dist_model()) { if (config_.dist_config().use_dist_model()) {
VLOG(3) << "use_dist_model is enabled, will init FleetExecutor."; VLOG(3) << "use_dist_model is enabled, will init FleetExecutor.";
return PrepareFleetExecutor(); return PrepareFleetExecutor();
...@@ -393,8 +391,7 @@ bool AnalysisPredictor::PrepareExecutor() { ...@@ -393,8 +391,7 @@ bool AnalysisPredictor::PrepareExecutor() {
return true; return true;
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
bool AnalysisPredictor::PrepareFleetExecutor() { bool AnalysisPredictor::PrepareFleetExecutor() {
VLOG(3) << "AnalysisPredictor::PrepareFleetExecutor()"; VLOG(3) << "AnalysisPredictor::PrepareFleetExecutor()";
if (config_.dist_config().nranks() > 1 && !CommInit()) { if (config_.dist_config().nranks() > 1 && !CommInit()) {
...@@ -1194,8 +1191,7 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() { ...@@ -1194,8 +1191,7 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor( std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
const std::string &name) { const std::string &name) {
framework::Scope *scope; framework::Scope *scope;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
if (config_.dist_config().use_dist_model()) { if (config_.dist_config().use_dist_model()) {
scope = scope_.get(); scope = scope_.get();
} else { } else {
...@@ -1244,8 +1240,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor( ...@@ -1244,8 +1240,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor( std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
const std::string &name) { const std::string &name) {
framework::Scope *scope; framework::Scope *scope;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
if (config_.dist_config().use_dist_model()) { if (config_.dist_config().use_dist_model()) {
scope = scope_.get(); scope = scope_.get();
} else { } else {
...@@ -1292,8 +1287,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor( ...@@ -1292,8 +1287,7 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
} }
bool AnalysisPredictor::ZeroCopyRun() { bool AnalysisPredictor::ZeroCopyRun() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
if (config_.dist_config().use_dist_model()) { if (config_.dist_config().use_dist_model()) {
VLOG(3) << "ZeroCopyRun will use the fleet executor."; VLOG(3) << "ZeroCopyRun will use the fleet executor.";
inference::Timer timer; inference::Timer timer;
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#endif #endif
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
...@@ -395,8 +394,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -395,8 +394,7 @@ class AnalysisPredictor : public PaddlePredictor {
void StatisticShapeRangeInfo(); void StatisticShapeRangeInfo();
void CollectShapeRangeInfo(); void CollectShapeRangeInfo();
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// fleet exe related // fleet exe related
/// ///
...@@ -488,8 +486,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -488,8 +486,7 @@ class AnalysisPredictor : public PaddlePredictor {
std::map<std::string, std::vector<std::vector<int32_t>>> shape_info_; std::map<std::string, std::vector<std::vector<int32_t>>> shape_info_;
static int clone_num_; static int clone_num_;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// fleet executor related // fleet executor related
distributed::FleetExecutorDesc executor_desc_; distributed::FleetExecutorDesc executor_desc_;
std::shared_ptr<distributed::FleetExecutor> fleet_exe_; std::shared_ptr<distributed::FleetExecutor> fleet_exe_;
......
...@@ -14,7 +14,11 @@ ...@@ -14,7 +14,11 @@
# #
cc_library(reset_tensor_array SRCS reset_tensor_array.cc DEPS lod_tensor scope) cc_library(reset_tensor_array SRCS reset_tensor_array.cc DEPS lod_tensor scope)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce) if (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce onnxruntime)
else (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
endif (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc) cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc)
cc_test(zero_copy_tensor_test SRCS zero_copy_tensor_test.cc DEPS paddle_inference_api) cc_test(zero_copy_tensor_test SRCS zero_copy_tensor_test.cc DEPS paddle_inference_api)
...@@ -22,12 +22,22 @@ ...@@ -22,12 +22,22 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h" #include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif
namespace paddle_infer { namespace paddle_infer {
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
void Tensor::Reshape(const std::vector<int> &shape) { void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
shape_.assign(shape.begin(), shape.end());
return;
}
#endif
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
name_.empty(), false, name_.empty(), false,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
...@@ -123,6 +133,11 @@ T *Tensor::data(PlaceType *place, int *size) const { ...@@ -123,6 +133,11 @@ T *Tensor::data(PlaceType *place, int *size) const {
} }
DataType Tensor::type() const { DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
return dtype_;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
auto type = paddle::framework::TransToProtoVarType(tensor->dtype()); auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
if (type == paddle::framework::proto::VarType::FP32) { if (type == paddle::framework::proto::VarType::FP32) {
...@@ -145,6 +160,13 @@ PlaceType Tensor::place() const { return place_; } ...@@ -145,6 +160,13 @@ PlaceType Tensor::place() const { return place_; }
template <typename T> template <typename T>
void Tensor::CopyFromCpu(const T *data) { void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyFromCpu<T>(data);
return;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GE(tensor->numel(), 0, PADDLE_ENFORCE_GE(tensor->numel(), 0,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
...@@ -382,6 +404,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -382,6 +404,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
template <typename T> template <typename T>
void Tensor::CopyToCpu(T *data) const { void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyToCpu<T>(data);
return;
}
#endif
CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr); CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
} }
...@@ -489,12 +518,7 @@ template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place); ...@@ -489,12 +518,7 @@ template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place); template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place); template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
Tensor::Tensor(void *scope) : scope_{scope} { Tensor::Tensor(void *scope) : scope_{scope} {}
PADDLE_ENFORCE_NOT_NULL(scope_,
paddle::platform::errors::PreconditionNotMet(
"The `scope` can not be nullptr. It should be "
"set to the pointer of scope."));
}
template <typename T> template <typename T>
void *Tensor::FindTensor() const { void *Tensor::FindTensor() const {
...@@ -513,6 +537,26 @@ void *Tensor::FindTensor() const { ...@@ -513,6 +537,26 @@ void *Tensor::FindTensor() const {
} }
std::vector<int> Tensor::shape() const { std::vector<int> Tensor::shape() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
std::vector<int> shape;
// input handle
if (idx_ < 0) {
shape.assign(shape_.begin(), shape_.end());
} else { // output handle
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
auto info = value.GetTensorTypeAndShapeInfo();
auto ort_shape = info.GetShape();
shape.assign(ort_shape.begin(), ort_shape.end());
}
return shape;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
tensor_, paddle::platform::errors::PreconditionNotMet( tensor_, paddle::platform::errors::PreconditionNotMet(
...@@ -573,4 +617,99 @@ void Tensor::SetPlace(PlaceType place, int device) { ...@@ -573,4 +617,99 @@ void Tensor::SetPlace(PlaceType place, int device) {
device_ = device; device_ = device;
} }
#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding;
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<float>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int64_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int32_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<uint8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
size * sizeof(float16), shape, shape_len,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}
template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"input tensor [%s] no binding ptr", name_));
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_,
OrtMemTypeDefault);
size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
std::multiplies<size_t>());
auto ort_value = GetOrtVaule(memory_info, const_cast<T *>(data), size,
shape_.data(), shape_.size());
binding->BindInput(name_.c_str(), ort_value);
}
template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
auto info = value.GetTensorTypeAndShapeInfo();
size_t size = info.GetElementCount() * sizeof(T);
if (place_ == PlaceType::kCPU) {
std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
} else {
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data),
paddle::platform::CUDAPlace(device_),
value.GetTensorData<void>(), size, nullptr);
}
}
template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);
template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif
} // namespace paddle_infer } // namespace paddle_infer
...@@ -25,11 +25,7 @@ ...@@ -25,11 +25,7 @@
#include <vector> #include <vector>
#include "paddle/fluid//platform/device/gpu/gpu_types.h" #include "paddle/fluid//platform/device/gpu/gpu_types.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/version.h" #include "paddle/fluid/framework/version.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h" #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
...@@ -45,24 +41,23 @@ ...@@ -45,24 +41,23 @@
namespace paddle { namespace paddle {
framework::proto::VarType::Type ConvertONNXType( paddle_infer::DataType ConvertONNXType(ONNXTensorElementDataType type) {
ONNXTensorElementDataType type) {
switch (type) { switch (type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return framework::proto::VarType::FP32; return paddle_infer::DataType::FLOAT32;
// case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
// return DataType::FP16; return paddle_infer::DataType::FLOAT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return framework::proto::VarType::INT8; return paddle_infer::DataType::INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return framework::proto::VarType::INT32; return paddle_infer::DataType::INT32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return framework::proto::VarType::INT64; return paddle_infer::DataType::INT64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return framework::proto::VarType::UINT8; return paddle_infer::DataType::UINT8;
default: default:
LOG(ERROR) << "unsupported ONNX Tensor Type: " << static_cast<int>(type); LOG(ERROR) << "unsupported ONNX Tensor Type: " << static_cast<int>(type);
return framework::proto::VarType::FP32; return paddle_infer::DataType::FLOAT32;
} }
} }
...@@ -87,13 +82,12 @@ bool ONNXRuntimePredictor::Init() { ...@@ -87,13 +82,12 @@ bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()"; VLOG(3) << "ONNXRuntime Predictor::init()";
// Now ONNXRuntime only suuport CPU // Now ONNXRuntime only suuport CPU
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
if (config_.use_gpu()) { if (config_.use_gpu()) {
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id()); place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
} else { } else {
place_ = paddle::platform::CPUPlace(); place_ = paddle::platform::CPUPlace();
} }
scope_.reset(new paddle::framework::Scope());
sub_scope_ = &scope_->NewScope();
std::string onnx_proto; std::string onnx_proto;
paddle2onnx::Export(config_.prog_file(), config_.params_file(), &onnx_proto, paddle2onnx::Export(config_.prog_file(), config_.params_file(), &onnx_proto,
...@@ -125,13 +119,12 @@ bool ONNXRuntimePredictor::Init() { ...@@ -125,13 +119,12 @@ bool ONNXRuntimePredictor::Init() {
"generated."; "generated.";
} }
session_ = {env_, onnx_proto.data(), onnx_proto.size(), session_options}; session_ = {env_, onnx_proto.data(), onnx_proto.size(), session_options};
binding_ = std::make_shared<Ort::IoBinding>(session_);
auto memory_info = Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(session_, memory_info); Ort::Allocator allocator(session_, memory_info);
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
size_t n_inputs = session_.GetInputCount(); size_t n_inputs = session_.GetInputCount();
for (size_t i = 0; i < n_inputs; ++i) { for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_.GetInputName(i, allocator); auto input_name = session_.GetInputName(i, allocator);
...@@ -141,8 +134,6 @@ bool ONNXRuntimePredictor::Init() { ...@@ -141,8 +134,6 @@ bool ONNXRuntimePredictor::Init() {
ONNXTensorElementDataType data_type = ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType(); type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type}); input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);
allocator.Free(input_name); allocator.Free(input_name);
} }
...@@ -155,11 +146,13 @@ bool ONNXRuntimePredictor::Init() { ...@@ -155,11 +146,13 @@ bool ONNXRuntimePredictor::Init() {
ONNXTensorElementDataType data_type = ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType(); type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type}); output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
auto *ptr = scope_->Var(output_name);
framework::InitializeVariable(ptr, proto_type); Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);
allocator.Free(output_name); allocator.Free(output_name);
} }
return true; return true;
} }
...@@ -216,15 +209,26 @@ std::vector<std::string> ONNXRuntimePredictor::GetOutputNames() { ...@@ -216,15 +209,26 @@ std::vector<std::string> ONNXRuntimePredictor::GetOutputNames() {
return output_names; return output_names;
} }
bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,
bool is_input) {
if (is_input) {
for (auto i : input_desc_)
if (i.name == name) return true;
} else {
for (auto i : output_desc_)
if (i.name == name) return true;
}
return false;
}
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor( std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
const std::string &name) { const std::string &name) {
PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name), PADDLE_ENFORCE_EQ(FindONNXDesc(name, true), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the " "The in variable named %s is not found in the "
"scope of the ONNXPredictor.", "ONNXPredictor.",
name)); name));
std::unique_ptr<ZeroCopyTensor> res( std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
new ZeroCopyTensor(static_cast<void *>(scope_.get())));
res->input_or_output_ = true; res->input_or_output_ = true;
res->SetName(name); res->SetName(name);
if (platform::is_cpu_place(place_)) { if (platform::is_cpu_place(place_)) {
...@@ -233,18 +237,19 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor( ...@@ -233,18 +237,19 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
auto gpu_place = place_; auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
} }
res->SetOrtMark(true);
res->SetOrtBinding(binding_);
return res; return res;
} }
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor( std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
const std::string &name) { const std::string &name) {
PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name), PADDLE_ENFORCE_EQ(FindONNXDesc(name, false), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The out variable named %s is not found in the " "The out variable named %s is not found in the "
"scope of the ONNXPredictor.", "ONNXPredictor.",
name)); name));
std::unique_ptr<ZeroCopyTensor> res( std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
new ZeroCopyTensor(static_cast<void *>(scope_.get())));
res->input_or_output_ = false; res->input_or_output_ = false;
res->SetName(name); res->SetName(name);
if (platform::is_cpu_place(place_)) { if (platform::is_cpu_place(place_)) {
...@@ -253,44 +258,16 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor( ...@@ -253,44 +258,16 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
auto gpu_place = place_; auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
} }
return res; res->SetOrtMark(true);
} res->SetOrtBinding(binding_);
int size = output_desc_.size();
Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc, for (int i = 0; i < size; ++i)
const char *device_name) { if (output_desc_[i].name == name) {
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, res->idx_ = i;
place_.GetDeviceId(), OrtMemTypeDefault); res->dtype_ = ConvertONNXType(output_desc_[i].dtype);
auto *var = scope_->FindVar(desc.name); break;
auto *tensor = var->GetMutable<framework::LoDTensor>();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
return Ort::Value::CreateTensor(memory_info,
static_cast<void *>(tensor->data()), size,
shape.data(), shape.size(), desc.dtype);
}
void ONNXRuntimePredictor::AsTensor(const Ort::Value &value,
const ONNXDesc &desc) {
auto info = value.GetTensorTypeAndShapeInfo();
auto *var = scope_->FindVar(desc.name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim(info.GetShape()));
auto dtype = ConvertONNXType(info.GetElementType());
auto *ptr = tensor->mutable_data(place_, dtype);
if (platform::is_cpu_place(place_)) {
std::memcpy(ptr, const_cast<void *>(value.GetTensorData<void>()),
tensor->numel() * framework::SizeOfType(dtype));
} else {
auto src_place = place_;
auto dst_place = place_;
memory::Copy(dst_place, ptr, src_place,
const_cast<void *>(value.GetTensorData<void>()),
tensor->numel() * framework::SizeOfType(dtype));
} }
return res;
} }
bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs, bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
...@@ -302,31 +279,7 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -302,31 +279,7 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
bool ONNXRuntimePredictor::ZeroCopyRun() { bool ONNXRuntimePredictor::ZeroCopyRun() {
try { try {
Ort::IoBinding binding(session_); session_.Run({}, *(binding_.get()));
std::vector<Ort::Value> inputs;
std::vector<Ort::Value> outputs;
Ort::RunOptions options;
inputs.reserve(input_desc_.size());
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
for (auto desc : input_desc_) {
inputs.push_back(GetOrtValue(desc, device_name));
binding.BindInput(desc.name.c_str(), inputs.back());
}
// TODO(heliqi): Optimization —— move to Init()
for (auto desc : output_desc_) {
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding.BindOutput(desc.name.c_str(), memory_info);
}
session_.Run({}, binding);
outputs = binding.GetOutputValues();
for (size_t i = 0; i < output_desc_.size(); ++i) {
AsTensor(outputs[i], output_desc_[i]);
}
} catch (const std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << e.what(); LOG(ERROR) << e.what();
return false; return false;
...@@ -345,9 +298,9 @@ uint64_t ONNXRuntimePredictor::TryShrinkMemory() { ...@@ -345,9 +298,9 @@ uint64_t ONNXRuntimePredictor::TryShrinkMemory() {
} }
ONNXRuntimePredictor::~ONNXRuntimePredictor() { ONNXRuntimePredictor::~ONNXRuntimePredictor() {
if (sub_scope_) { binding_->ClearBoundInputs();
scope_->DeleteScope(sub_scope_); binding_->ClearBoundOutputs();
}
memory::Release(place_); memory::Release(place_);
} }
......
...@@ -94,9 +94,8 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -94,9 +94,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config /// \param[in] AnalysisConfig config
/// ///
explicit ONNXRuntimePredictor(const AnalysisConfig &config) explicit ONNXRuntimePredictor(const AnalysisConfig &config)
: config_(config) { : config_(config), env_(ORT_LOGGING_LEVEL_WARNING, "onnx") {
predictor_id_ = inference::GetUniqueId(); predictor_id_ = inference::GetUniqueId();
env_ = Ort::Env(ORT_LOGGING_LEVEL_INFO, "onnx");
} }
/// ///
/// \brief Destroy the ONNXRuntime Predictor object /// \brief Destroy the ONNXRuntime Predictor object
...@@ -177,30 +176,17 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -177,30 +176,17 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// ///
std::unique_ptr<PaddlePredictor> Clone() override; std::unique_ptr<PaddlePredictor> Clone() override;
std::shared_ptr<framework::Scope> scope_;
private: private:
/// ///
/// \brief get the Ort Value(input Tensor). /// \brief Whether to find in/out by name.
///
/// \param[in] desc ONNXDesce(name、shape、dtype)
///
/// \param[in] device_name "cpu" or "gpu" of device
///
/// \return get a Ort::Value
///
Ort::Value GetOrtValue(const ONNXDesc &desc, const char *device_name);
///
/// \brief Ort::Value to Paddle::ZeroCopyTensor.
/// ///
/// \param[in] value Ort::Value(output Tensor) /// \param[in] name input or output name
/// ///
/// \param[in] desc a ONNXDesce(name、shape、dtype) /// \param[in] is_input input(true) or output(false)
/// ///
/// \return get a Ort::Value /// \return Whether to find by name
/// ///
void AsTensor(const Ort::Value &value, const ONNXDesc &desc); bool FindONNXDesc(const std::string &name, bool is_input);
private: private:
AnalysisConfig config_; AnalysisConfig config_;
...@@ -208,9 +194,9 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -208,9 +194,9 @@ class ONNXRuntimePredictor : public PaddlePredictor {
// ONNXRuntime // ONNXRuntime
Ort::Env env_; Ort::Env env_;
Ort::Session session_{nullptr}; Ort::Session session_{nullptr};
std::shared_ptr<Ort::IoBinding> binding_;
platform::Place place_; platform::Place place_;
framework::Scope *sub_scope_{nullptr};
std::vector<ONNXDesc> input_desc_; std::vector<ONNXDesc> input_desc_;
std::vector<ONNXDesc> output_desc_; std::vector<ONNXDesc> output_desc_;
int predictor_id_; int predictor_id_;
......
...@@ -18,6 +18,11 @@ ...@@ -18,6 +18,11 @@
#include "paddle_infer_declare.h" // NOLINT #include "paddle_infer_declare.h" // NOLINT
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#endif
namespace paddle_infer { namespace paddle_infer {
/// \brief Experimental. /// \brief Experimental.
...@@ -175,6 +180,23 @@ class PD_INFER_DECL Tensor { ...@@ -175,6 +180,23 @@ class PD_INFER_DECL Tensor {
PlaceType place_; PlaceType place_;
int device_; int device_;
#ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false};
std::vector<int64_t> shape_;
std::weak_ptr<Ort::IoBinding> binding_;
int idx_{-1};
void SetOrtMark(bool is_ort_tensor);
void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
template <typename T>
void ORTCopyFromCpu(const T* data);
template <typename T>
void ORTCopyToCpu(T* data) const;
#endif
friend class paddle_infer::contrib::TensorUtils; friend class paddle_infer::contrib::TensorUtils;
#if defined(PADDLE_WITH_TESTING) && defined(PADDLE_WITH_INFERENCE_API_TEST) #if defined(PADDLE_WITH_TESTING) && defined(PADDLE_WITH_INFERENCE_API_TEST)
friend class paddle_infer::InferApiTesterUtils; friend class paddle_infer::InferApiTesterUtils;
......
...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
......
...@@ -53,6 +53,6 @@ TEST(Relu6OpConverter, main) { test_activation("relu6"); } ...@@ -53,6 +53,6 @@ TEST(Relu6OpConverter, main) { test_activation("relu6"); }
} // namespace paddle } // namespace paddle
USE_OP_ITSELF(relu); USE_OP_ITSELF(relu);
USE_OP(sigmoid); USE_OP_ITSELF(sigmoid);
USE_OP_ITSELF(tanh); USE_OP_ITSELF(tanh);
USE_OP(relu6); USE_OP(relu6);
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/phi/kernels/layer_norm_kernel.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size, cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream); cudaMemcpyHostToDevice, stream);
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm; phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps); variance_d, begin_norm_axis, eps);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
...@@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue( ...@@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue(
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size, cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream); cudaMemcpyHostToDevice, stream);
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm; phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps); variance_d, begin_norm_axis, eps);
} else { } else {
......
...@@ -1492,6 +1492,10 @@ REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor, ...@@ -1492,6 +1492,10 @@ REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor,
REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor, REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor,
TanhShrinkGradFunctor); TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor); REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
REGISTER_ACTIVATION_OP(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
HardSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
LogSigmoidGradFunctor);
/* ========================== sigmoid register ============================= /* ========================== sigmoid register =============================
*/ */
...@@ -1526,30 +1530,6 @@ REGISTER_OPERATOR(sigmoid_triple_grad, ...@@ -1526,30 +1530,6 @@ REGISTER_OPERATOR(sigmoid_triple_grad,
ops::SigmoidTripleGradFunctor<float>::FwdDeps()>, ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
ops::ActivationTripleGradOpInplaceInferer); ops::ActivationTripleGradOpInplaceInferer);
// Register Sigmoid/GradSigmoid Kernels
REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
SigmoidGradFunctor);
// Register DoubleGrad Kernel
REGISTER_OP_CPU_KERNEL(
sigmoid_grad_grad,
ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
ops::SigmoidGradGradFunctor<float>>,
ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
ops::SigmoidGradGradFunctor<double>>,
ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
ops::SigmoidGradGradFunctor<plat::float16>>);
// Register TripleGrad Kernel
REGISTER_OP_CPU_KERNEL(
sigmoid_triple_grad,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<float>>,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<double>>,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* ========================== tanh register ============================= */ /* ========================== tanh register ============================= */
......
...@@ -20,69 +20,6 @@ limitations under the License. */ ...@@ -20,69 +20,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Element-wise logistic sigmoid evaluated on the device.
// The input is promoted to MPType for the arithmetic and the result is cast
// back to T.
template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // sigmoid(x) = 1 / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType negated = -static_cast<MPType>(arg_x);
    const MPType denominator = one + exp(negated);
    return static_cast<T>(one / denominator);
  }
};
// Backward functor for sigmoid. It only needs the forward output (kDepOut),
// not the forward input.
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * out * (1 - out)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T weighted = dout * out;
    return weighted * (one - out);
  }

  // Declares which forward tensors the gradient computation reads.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
// Element-wise log-sigmoid evaluated on the device in MPType precision.
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // Rewritten with shift = max(-x, 0) so that neither exponential receives a
  // large positive argument:
  //   logsigmoid(x) = -(shift + log(exp(-shift) + exp(-x - shift)))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    MPType shift = -x;
    if (x > zero) {
      shift = zero;
    }
    const MPType log_sum = log(exp(-shift) + exp(-x - shift));
    return static_cast<T>(-shift - log_sum);
  }
};
// Backward functor for log-sigmoid. It reads the forward input X (kDepX) and
// does its arithmetic in MPType precision.
template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // dx = dout * exp(-x) / (1 + exp(-x))
  // Rewritten with shift = max(-x, 0) for numerical stability:
  //   dx = dout * exp(-x - shift) / (exp(-shift) + exp(-x - shift))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType dout = static_cast<MPType>(arg_dout);
    MPType shift = -x;
    if (x > zero) {
      shift = zero;
    }
    const MPType numerator = exp(-x - shift);
    const MPType ratio = numerator / (exp(-shift) + numerator);
    return static_cast<T>(dout * ratio);
  }

  // Declares which forward tensors the gradient computation reads.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T> template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> { struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
...@@ -304,49 +241,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -304,49 +241,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// Element-wise hard sigmoid evaluated on the device.
// `slope` and `offset` are filled in through the pointers exposed by
// GetAttrs().
template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  // Maps attribute names to the members that receive their values.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = clip(x * slope + offset, 0, 1)
  __device__ __forceinline__ T operator()(const T x) const {
    const T affine = x * static_cast<T>(slope) + static_cast<T>(offset);
    // Clamp from below first, then from above.
    const T floored = affine > zero ? affine : zero;
    return floored < one ? floored : one;
  }
};
// Backward functor for hard sigmoid. It only needs the forward output
// (kDepOut). `slope` and `offset` are filled in through GetAttrs().
template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  // Maps attribute names to the members that receive their values.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // dx = dout * slope when 0 < out < 1, otherwise 0.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    if (out > zero && out < one) {
      return dout * static_cast<T>(slope);
    }
    return zero;
  }

  // Declares which forward tensors the gradient computation reads.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
template <typename T> template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> { struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
...@@ -580,6 +474,9 @@ USE_PHI_FUNCTOR(CudaSoftShrink) ...@@ -580,6 +474,9 @@ USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink) USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu) USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU) USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
template <typename T> template <typename T>
using CudaELUGradNegativeAlphaFunctor = using CudaELUGradNegativeAlphaFunctor =
...@@ -658,35 +555,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -658,35 +555,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::CELUGradGradFunctor<plat::float16>>); ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* =========================== sigmoid register ============================
*/
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
CudaSigmoidGradFunctor);
REGISTER_OP_CUDA_KERNEL(
sigmoid_grad_grad,
ops::SigmoidDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidGradGradFunctor<float>>,
ops::SigmoidDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidGradGradFunctor<double>>,
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
ops::SigmoidGradGradFunctor<plat::float16>>,
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
ops::SigmoidGradGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
sigmoid_triple_grad,
ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<float>>,
ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<double>>,
ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<plat::float16>>,
ops::SigmoidTripleGradKernel<
plat::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<plat::bfloat16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */ /* =========================== sqrt register ============================= */
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
...@@ -772,8 +640,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -772,8 +640,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */ /* ========================================================================== */
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \ #define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \
CudaLogSigmoidGradFunctor); \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \ CudaSoftShrinkGradFunctor); \
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \ __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
...@@ -788,8 +654,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -788,8 +654,6 @@ REGISTER_OP_CUDA_KERNEL(
CudaTanhShrinkGradFunctor); \ CudaTanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \ __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
CudaHardShrinkGradFunctor); \ CudaHardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \
CudaHardSigmoidGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \ __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \ __macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
CudaHardSwishGradFunctor); CudaHardSwishGradFunctor);
......
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OpDesc; class OpDesc;
...@@ -36,26 +39,6 @@ class AssignOp : public framework::OperatorWithKernel { ...@@ -36,26 +39,6 @@ class AssignOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
// Propagates the shape of input "X" to output "Out".
//  - No input "X": nothing is set.
//  - SELECTED_ROWS / LOD_TENSOR: copy the dims; LOD_TENSOR also shares LoD.
//  - LOD_TENSOR_ARRAY: copy the dims only outside runtime; at runtime the
//    output shape is determined in the kernel.
//  - Any other variable type: nothing is set.
void InferShape(framework::InferShapeContext *ctx) const override {
  if (!ctx->HasInput("X")) {
    return;
  }

  const auto x_var_type = ctx->GetInputsVarType("X")[0];
  const bool is_lod_tensor =
      x_var_type == framework::proto::VarType::LOD_TENSOR;

  if (x_var_type == framework::proto::VarType::SELECTED_ROWS ||
      is_lod_tensor) {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    if (is_lod_tensor) {
      ctx->ShareLoD("X", /*->*/ "Out");
    }
    return;
  }

  // The runtime output shape of a tensor array is determined in kernel.
  if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY &&
      !ctx->IsRuntime()) {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }
}
protected: protected:
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
...@@ -91,24 +74,6 @@ class AssignInferVarType : public framework::VarTypeInference { ...@@ -91,24 +74,6 @@ class AssignInferVarType : public framework::VarTypeInference {
} }
}; };
// Kernel functor for the assign op: dispatches AssignFunctor over the
// variable held by input "X", writing into output "Out".
class AssignKernel {
 public:
  void operator()(const framework::ExecutionContext &ctx) const {
    auto *input_var = ctx.InputVar("X");
    // A missing input is a no-op.
    if (input_var == nullptr) {
      return;
    }

    PADDLE_ENFORCE_EQ(
        ctx.HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of assign_op is not found."));
    auto *output_var = ctx.OutputVar("Out");

    // Look up the device context that matches the place this op runs on.
    auto &dev_ctx =
        *platform::DeviceContextPool::Instance().Get(ctx.GetPlace());
    framework::VisitVarType(*input_var, AssignFunctor(output_var, dev_ctx));
  }
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -147,23 +112,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); ...@@ -147,23 +112,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(assign, AssignInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(assign, ops::AssignOp, REGISTER_OPERATOR(assign, ops::AssignOp,
ops::AssignGradMaker<paddle::framework::OpDesc>, ops::AssignGradMaker<paddle::framework::OpDesc>,
ops::AssignGradMaker<paddle::imperative::OpBase>, ops::AssignGradMaker<paddle::imperative::OpBase>,
ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer, ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer,
ops::AssignInferVarType); ops::AssignInferVarType, AssignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel, plat::bfloat16,
ops::AssignKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel);
#endif
...@@ -29,7 +29,7 @@ limitations under the License. */ ...@@ -29,7 +29,7 @@ limitations under the License. */
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(assign); USE_OP_ITSELF(assign);
USE_OP_DEVICE_KERNEL(assign, NPU); USE_OP_DEVICE_KERNEL(assign, NPU);
template <typename T> template <typename T>
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/diag_functor.h" #include "paddle/phi/kernels/funcs/diag_functor.h"
...@@ -30,7 +31,6 @@ ...@@ -30,7 +31,6 @@
#include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h" #include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
......
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h" #include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
......
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
// only can include the headers in paddle/phi/include dirs // only can include the headers in paddle/phi/include dirs
#include "paddle/phi/kernels/elementwise_grad_kernel.h" #include "paddle/phi/kernels/elementwise_grad_kernel.h"
#include "paddle/phi/kernels/math_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h"
#endif #endif
namespace paddle { namespace paddle {
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/math_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,9 @@ limitations under the License. */ ...@@ -12,7 +12,9 @@ limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h" #include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,27 +24,6 @@ using framework::Tensor; ...@@ -22,27 +24,6 @@ using framework::Tensor;
class ExpandAsV2Op : public framework::OperatorWithKernel { class ExpandAsV2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
// Infers the shape of "Out" for expand_as_v2: the output dims are taken
// directly from the "target_shape" attribute.
//
// Checks performed before the output dim is set:
//  - input "X" and output "Out" exist;
//  - rank(target_shape) >= rank(X);
//  - rank(target_shape) <= MAX_RANK_SUPPORTED.
// Individual dimension values of X are not compared against target_shape here.
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2");
auto x_dims = ctx->GetInputDim("X");
auto target_shape = ctx->Attrs().Get<std::vector<int>>("target_shape");
// The target rank may not be smaller than the input rank.
// NOTE(review): the message uses %u for x_dims.size() and
// target_shape.size(); confirm the formatter accepts those argument types.
PADDLE_ENFORCE_GE(
target_shape.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The rank of target_shape must be greater than or equal "
"to the rank of Input(X). But received Input(X): input "
"rank %u; received target_shape: rank %u.",
x_dims.size(), target_shape.size()));
// Upper bound on the supported rank; MAX_RANK_SUPPORTED is defined outside
// this block.
PADDLE_ENFORCE_LE(target_shape.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of target_shape must be less than or equal "
"to %d. But received: rank %u.",
MAX_RANK_SUPPORTED, target_shape.size()));
// The output shape is exactly the requested target shape.
ctx->SetOutputDim("Out", phi::make_ddim(target_shape));
}
}; };
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -116,9 +97,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X"); ...@@ -116,9 +97,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(expand_as_v2, ExpandAsInferShapeFunctor,
PD_INFER_META(phi::ExpandAsInferMeta));
REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker, REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker,
ops::ExpandAsV2GradOpMaker<paddle::framework::OpDesc>, ops::ExpandAsV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>); ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>,
ExpandAsInferShapeFunctor);
REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp, REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp,
ops::ExpandAsV2GradNoNeedBufVarsInferer); ops::ExpandAsV2GradNoNeedBufVarsInferer);
......
...@@ -25,14 +25,16 @@ limitations under the License. */ ...@@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace memory = paddle::memory; namespace memory = paddle::memory;
USE_OP_ITSELF(dropout); USE_OP_ITSELF(dropout);
USE_OP(layer_norm); USE_OP_ITSELF(layer_norm);
template <typename T> template <typename T>
using CudnnDataType = platform::CudnnDataType<T>; using CudnnDataType = platform::CudnnDataType<T>;
...@@ -136,18 +138,23 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale, ...@@ -136,18 +138,23 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
const platform::CUDADeviceContext &ctx) { const platform::CUDADeviceContext &ctx) {
framework::Scope scope; framework::Scope scope;
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
if (scale.size() > 0) { if (scale.size() > 0) {
auto var_scale = scope.Var("Scale"); auto var_scale = scope.Var("Scale");
auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>(); auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(scale, ctx, tensor_scale); framework::TensorFromVector(scale, ctx, tensor_scale);
tensor_scale->Resize({cols}); tensor_scale->Resize({cols});
scale_opt = *tensor_scale;
} }
paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
if (bias.size() > 0) { if (bias.size() > 0) {
auto var_bias = scope.Var("Bias"); auto var_bias = scope.Var("Bias");
auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>(); auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(bias, ctx, tensor_bias); framework::TensorFromVector(bias, ctx, tensor_bias);
tensor_bias->Resize({cols}); tensor_bias->Resize({cols});
bias_opt = *tensor_bias;
} }
auto var_x = scope.Var("X"); auto var_x = scope.Var("X");
...@@ -157,20 +164,19 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale, ...@@ -157,20 +164,19 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
auto var_y = scope.Var("Y"); auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>(); auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
tensor_y->Resize({rows, cols});
auto var_mean = scope.Var("Mean"); auto var_mean = scope.Var("Mean");
auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>(); auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
tensor_mean->Resize({rows});
auto var_variance = scope.Var("Variance"); auto var_variance = scope.Var("Variance");
auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>(); auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();
tensor_variance->Resize({rows});
framework::AttributeMap attrs; ctx.Wait();
attrs.insert({"epsilon", epsilon}); phi::LayerNormKernel<T>(static_cast<const phi::GPUContext &>(ctx), *tensor_x,
scale_opt, bias_opt, 1e-5, 1, false, tensor_y,
auto op = framework::OpRegistry::CreateOp( tensor_mean, tensor_variance);
"layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
op->Run(scope, place);
framework::TensorToVector(*tensor_y, ctx, y); framework::TensorToVector(*tensor_y, ctx, y);
framework::TensorToVector(*tensor_mean, ctx, means); framework::TensorToVector(*tensor_mean, ctx, means);
framework::TensorToVector(*tensor_variance, ctx, vars); framework::TensorToVector(*tensor_variance, ctx, vars);
......
...@@ -198,7 +198,6 @@ struct TestFusedLayernormResidualDropoutBias { ...@@ -198,7 +198,6 @@ struct TestFusedLayernormResidualDropoutBias {
residual_vec[i * cols + j] + out2[i * cols + j]; residual_vec[i * cols + j] + out2[i * cols + j];
} }
} }
LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means, LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means,
&correct_vars, &correct_layernorm_out, epsilon, rows, cols, &correct_vars, &correct_layernorm_out, epsilon, rows, cols,
*ctx); *ctx);
......
...@@ -17,7 +17,9 @@ limitations under the License. */ ...@@ -17,7 +17,9 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,27 +28,6 @@ class KronOp : public framework::OperatorWithKernel { ...@@ -26,27 +28,6 @@ class KronOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
// Infers the shape of "Out" for the kron op.
// The lower-rank input is treated as if it had leading dimensions of size 1,
// and each output dimension is the product of the aligned input dimensions.
// If either aligned dimension is -1, the output dimension is -1.
void InferShape(framework::InferShapeContext* ctx) const override {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron");
  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kron");

  auto x_shape = ctx->GetInputDim("X");
  auto y_shape = ctx->GetInputDim("Y");
  const auto x_rank = x_shape.size();
  const auto y_rank = y_shape.size();
  const auto out_rank = (x_rank > y_rank) ? x_rank : y_rank;

  // Number of implicit leading 1-sized dimensions for each input.
  const auto x_pad = out_rank - x_rank;
  const auto y_pad = out_rank - y_rank;

  std::vector<int64_t> out_shape;
  out_shape.reserve(out_rank);
  for (int axis = 0; axis < out_rank; ++axis) {
    int64_t x_extent = 1;
    if (axis >= x_pad) {
      x_extent = x_shape.at(axis - x_pad);
    }
    int64_t y_extent = 1;
    if (axis >= y_pad) {
      y_extent = y_shape.at(axis - y_pad);
    }

    if (x_extent == -1 || y_extent == -1) {
      out_shape.push_back(-1);
    } else {
      out_shape.push_back(x_extent * y_extent);
    }
  }

  ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -173,7 +154,10 @@ class KronGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -173,7 +154,10 @@ class KronGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(kron, KronInferShapeFunctor,
PD_INFER_META(phi::KronInferMeta));
REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker, REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker,
ops::KronGradOpMaker<paddle::framework::OpDesc>, ops::KronGradOpMaker<paddle::framework::OpDesc>,
ops::KronGradOpMaker<paddle::imperative::OpBase>); ops::KronGradOpMaker<paddle::imperative::OpBase>,
KronInferShapeFunctor);
REGISTER_OPERATOR(kron_grad, ops::KronGradOp); REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
...@@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel( ...@@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/ */
template <typename T, typename U, typename ScaleT = U, template <typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t> typename MaskType = uint8_t>
void ln_bwd_1024_kernel_driver( void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols, const int cols, float epsilon, const T *x_ptr,
float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr, const U *var_ptr, const T *dout_ptr, T *dx_ptr,
ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr, ScaleT *dscale_ptr, ScaleT *dbias_ptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) { const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0),
T *d_dropout_src_ptr = nullptr) {
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
if (cols == 1024) { if (cols == 1024) {
// step-1: compute dx and reduced part results of dscale and dbias. // step-1: compute dx and reduced part results of dscale and dbias.
...@@ -1334,8 +1336,7 @@ static void LayerNormBackward( ...@@ -1334,8 +1336,7 @@ static void LayerNormBackward(
const U *mean, const U *var, T *d_x, const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size, int64_t batch_size, int64_t feature_size, const phi::GPUContext &dev_ctx) {
const platform::CUDADeviceContext &dev_ctx) {
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
#ifdef __HIPCC__ #ifdef __HIPCC__
const int kMaxBlockDim = 256; const int kMaxBlockDim = 256;
......
...@@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker, ...@@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>); ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp, REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer); ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
// Launches the LayerNormForward kernel on `stream` using raw pointers.
//
// Parameters:
//   stream          - stream the kernel is enqueued on.
//   input           - input data pointer.
//   input_shape     - shape of the input; flattened to 2-D at begin_norm_axis.
//   bias, scale     - pointers forwarded to the kernel.
//   output          - output pointer forwarded to the kernel.
//   mean, variance  - pointers forwarded to the kernel.
//   begin_norm_axis - axis at which input_shape is split by flatten_to_2d.
//   eps             - value forwarded to the kernel.
//
// Throws InvalidArgument (via PADDLE_THROW) when the block-dim switch falls
// through to the default case.
// NOTE(review): pointers are presumably device pointers, since they go
// straight into a kernel launch; confirm against callers.
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
std::vector<int> input_shape,
const T *bias, const T *scale,
T *output, T *mean, T *variance,
int begin_norm_axis, float eps) {
const auto x_dims = phi::make_ddim(input_shape);
// Collapse the shape into a [batch_size, feature_size] matrix.
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
// FIXED_BLOCK_DIM_CASE is defined elsewhere; presumably it expands to the
// case labels and binds kBlockDim for each one - TODO confirm.
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
// Grid: batch_size blocks of kBlockDim threads, no dynamic shared memory.
LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
// Presumably unreachable after PADDLE_THROW.
break;
}
}
/**
 * CUDA implementation of the layer_norm forward operator.
 *
 * Inputs read from the execution context: "X", optional "Scale" and "Bias",
 * and the attributes "epsilon" and "begin_norm_axis". Outputs written: "Y",
 * "Mean" and "Variance". Mean/Variance are stored as
 * U = LayerNormParamType<T>.
 *
 * Scale and Bias must share one data type. That type must either equal the
 * type of X or be U; anything else raises InvalidArgument.
 */
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());

    // Scale/Bias are kept as untyped pointers until their element type has
    // been resolved below; they are cast to T or U at the launch site.
    auto *void_scale_data = (scale == nullptr ? nullptr : scale->data());
    auto *void_bias_data = (bias == nullptr ? nullptr : bias->data());

    framework::proto::VarType::Type x_dtype =
        framework::TransToProtoVarType(x->dtype());
    // Resolve the Scale/Bias element type: Scale wins if present, then Bias,
    // and when neither is given the type of X is used.
    framework::proto::VarType::Type scale_bias_dtype;
    if (void_scale_data != nullptr) {
      scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
      if (void_bias_data != nullptr) {
        PADDLE_ENFORCE_EQ(scale_bias_dtype,
                          framework::TransToProtoVarType(bias->dtype()),
                          platform::errors::InvalidArgument(
                              "The Scale and Bias of layer_norm op "
                              "should have the same data type."));
      }
    } else {
      scale_bias_dtype = (void_bias_data != nullptr
                              ? framework::TransToProtoVarType(bias->dtype())
                              : x_dtype);
    }

    bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
    if (!is_scale_bias_same_dtype_with_x) {
      // The only supported mixed case is Scale/Bias stored as U.
      PADDLE_ENFORCE_EQ(scale_bias_dtype,
                        framework::DataTypeTrait<U>::DataType(),
                        platform::errors::InvalidArgument(
                            "Unsupported data type of Scale and Bias: %s",
                            framework::DataTypeToString(scale_bias_dtype)));
    }

    auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

// Generic launch path: picks the block dimension from feature_size and
// launches LayerNormForward with batch_size blocks.
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)    \
  do {                                                                        \
    switch (GetDesiredBlockDim(feature_size)) {                               \
      FIXED_BLOCK_DIM_CASE(                                                   \
          LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<<     \
              batch_size, kBlockDim, 0, stream>>>(                            \
              x_data, static_cast<const ScaleBiasT *>(void_scale_data),       \
              static_cast<const ScaleBiasT *>(void_bias_data), y_data,        \
              mean_data, var_data, epsilon, feature_size));                   \
      default:                                                                \
        PADDLE_THROW(platform::errors::InvalidArgument(                       \
            "Product from begin_norm_axis to end must be larger than 1"));    \
        break;                                                                \
    }                                                                         \
  } while (0)

#ifdef PADDLE_WITH_CUDA
    // Dedicated kernel for feature_size == 1024; it is only used when both
    // Scale and Bias are provided.
    const bool can_call_1024_kernel =
        (feature_size == 1024 && scale != nullptr && bias != nullptr);
    if (can_call_1024_kernel) {
      const int WARPS_M = 4;
      const int WARPS_N = 1;
      const int THREADS_PER_WARP = 32;
      const int BYTES_PER_LDG = 16;
      const int VecSize = BYTES_PER_LDG / sizeof(T);

      const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
      const int ROWS_PER_CTA = WARPS_M;

      // Integer ceiling division: exact for every batch_size, unlike a
      // round-trip through float which loses precision above 2^24.
      const int grid =
          static_cast<int>((batch_size + ROWS_PER_CTA - 1) / ROWS_PER_CTA);
      if (is_scale_bias_same_dtype_with_x) {
        ln_fwd_1024_kernel<T, U, T, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const T *>(void_scale_data),
            static_cast<const T *>(void_bias_data), mean_data, var_data,
            y_data);
      } else {
        ln_fwd_1024_kernel<T, U, U, VecSize, WARPS_M, WARPS_N,
                           BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
            batch_size, feature_size, epsilon, x_data,
            static_cast<const U *>(void_scale_data),
            static_cast<const U *>(void_bias_data), mean_data, var_data,
            y_data);
      }
    } else {
#endif
      if (is_scale_bias_same_dtype_with_x) {
        PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
      } else {
        PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
      }
#ifdef PADDLE_WITH_CUDA
    }
#endif

#undef PADDLE_LAUNCH_LAYERNORM_FWD
  }
};
/**
 * CUDA implementation of the layer_norm backward operator.
 *
 * Inputs read from the execution context: "X", "Mean", "Variance", optional
 * "Scale" and "Bias", the gradient of "Y", and the attributes "epsilon" and
 * "begin_norm_axis". Outputs (each may be absent): gradients of "X", "Scale"
 * and "Bias".
 *
 * The element type used for Scale and its gradients is taken from Scale if
 * present, otherwise from Bias, otherwise from X. When it equals the type of
 * X the buffers are accessed as T; otherwise as U = LayerNormParamType<T>.
 */
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();

    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    // Allocated once here and consumed by the launch macro below.
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    framework::proto::VarType::Type x_dtype =
        framework::TransToProtoVarType(x->dtype());

    framework::proto::VarType::Type scale_bias_dtype;
    if (scale != nullptr) {
      scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
    } else if (bias != nullptr) {
      // FIXME(zengjinle): do not find a better way to get the right
      // data type of the d_scale and d_bias if scale == nullptr.
      scale_bias_dtype = framework::TransToProtoVarType(bias->dtype());
    } else {
      scale_bias_dtype = x_dtype;
    }

// Fetches the Scale input and the Scale/Bias gradient buffers as ScaleBiasT,
// then forwards everything to LayerNormBackward.
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
    auto *scale_data =                                                     \
        (scale == nullptr ? nullptr : scale->data<ScaleBiasT>());          \
    auto *d_scale_data =                                                   \
        (d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
                                            ctx.GetPlace()));              \
    auto *d_bias_data =                                                    \
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>(   \
                                           ctx.GetPlace()));               \
    LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(                    \
        x_data, d_y_data, scale_data, mean_data, var_data, d_x_data,       \
        d_scale_data, d_bias_data, epsilon, batch_size, feature_size,      \
        ctx.cuda_device_context());                                        \
  } while (0)

    if (scale_bias_dtype == x_dtype) {
      PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
    } else {
      PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
    }

#undef PADDLE_LAUNCH_LAYERNORM_BWD
  }
};
// Explicit instantiation of the direct-launch functor; float is the only
// element type instantiated in this file.
template class LayerNormDirectCUDAFunctor<float>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA kernels for layer_norm and layer_norm_grad. The set of
// element types registered depends on the build configuration.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so only float and float16 are registered.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#elif CUDNN_VERSION_MIN(8, 1, 0)
// This branch additionally registers bfloat16.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
#else
// Fallback: float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#endif
此差异已折叠。
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -18,9 +18,9 @@ limitations under the License. */ ...@@ -18,9 +18,9 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/set_value_op.h" #include "paddle/fluid/operators/set_value_op.h"
#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h" #include "paddle/phi/kernels/funcs/tril_triu_compute.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/triangular_solve_kernel.h" #include "paddle/phi/kernels/triangular_solve_kernel.h"
namespace paddle { namespace paddle {
......
...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/common/data_type.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -139,7 +140,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -139,7 +140,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
layer_norm_p->execute(astream, args); layer_norm_p->execute(astream, args);
astream.wait(); astream.wait();
y->set_layout(DataLayout::kMKLDNN); y->set_layout(phi::DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory)); y->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
......
...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/roi_pool_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -57,7 +58,7 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -57,7 +58,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"%d-dimensional LoDTensor", "%d-dimensional LoDTensor",
rois_dims.size())); rois_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_dims[1], kROISize, rois_dims[1], phi::kROISize,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)" "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...]. But the second dimension of " "given as [[x1, y1, x2, y2], ...]. But the second dimension of "
...@@ -216,16 +217,7 @@ REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, ...@@ -216,16 +217,7 @@ REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
ops::ROIPoolGradMaker<paddle::framework::OpDesc>, ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
ops::ROIPoolGradMaker<paddle::imperative::OpBase>); ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp); REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_VERSION(roi_pool) REGISTER_OP_VERSION(roi_pool)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
此差异已折叠。
此差异已折叠。
...@@ -84,13 +84,8 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -84,13 +84,8 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
Tensor cpu_tensor(tensor->dtype()); Tensor cpu_tensor(tensor->dtype());
cpu_tensor.Resize(tensor->dims()); cpu_tensor.Resize(tensor->dims());
T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace()); T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace());
auto normal_cdf = [](float x) { std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; 1.0);
};
float a_normal_cdf = normal_cdf((-2.0 - mean) / std);
float b_normal_cdf = normal_cdf((2.0 - mean) / std);
std::uniform_real_distribution<float> dist(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
TruncatedNormal<T> truncated_normal(mean, std); TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel(); int64_t size = tensor->numel();
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册