Commit e5fa5d0c authored by: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into all_new_design_exec

......@@ -20,6 +20,13 @@ set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
include(system)
# Note(zhouwei): Ninja Generator will set CMAKE_BUILD_TYPE to Debug
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
project(paddle CXX C)
# enable language CUDA
......@@ -213,12 +220,6 @@ if(NOT PY_VERSION)
endif()
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
# the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined. Default: OFF
if(SANITIZER_TYPE AND NOT "${SANITIZER_TYPE}" MATCHES "^(Address|Leak|Memory|Thread|Undefined)$")
......
......@@ -86,7 +86,7 @@ We provide [English](https://www.paddlepaddle.org.cn/documentation/docs/en/guide
## Communication
- [Github Issues](https://github.com/PaddlePaddle/Paddle/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 778260830 (PaddlePaddle).
- QQ discussion group: 793866180 (PaddlePaddle).
- [Forums](https://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Copyright and License
......
......@@ -83,7 +83,7 @@ PaddlePaddle users can claim **free Tesla V100 online compute resources** to train models
## Communication and Feedback
- You are welcome to submit questions, bug reports, and suggestions via [Github Issues](https://github.com/PaddlePaddle/Paddle/issues)
- QQ group: 778260830 (PaddlePaddle)
- QQ group: 793866180 (PaddlePaddle)
- [Forum](https://ai.baidu.com/forum/topic/list/168): you are welcome to share problems and experiences encountered while using PaddlePaddle, and to help keep the forum a friendly place
## Copyright and License
......
......@@ -205,23 +205,16 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
if(WIN32)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
# match the cl's _ITERATOR_DEBUG_LEVEL
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"-g -G -D_DEBUG\"")
if(MSVC_STATIC_CRT)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /MTd")
else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /MDd")
endif()
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"-DNDEBUG\"")
if(MSVC_STATIC_CRT)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /MT")
else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /MD")
endif()
else()
message(FATAL "Windows only support Release or Debug build now. Please set visual studio build type to Release/Debug, x64 build.")
if(MSVC_STATIC_CRT)
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler /MTd")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler /MT")
foreach(flag_var
CMAKE_CUDA_FLAGS CMAKE_CUDA_FLAGS_DEBUG CMAKE_CUDA_FLAGS_RELEASE
CMAKE_CUDA_FLAGS_MINSIZEREL CMAKE_CUDA_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "-MD")
string(REGEX REPLACE "-MD" "-MT" ${flag_var} "${${flag_var}}")
endif()
endforeach(flag_var)
endif()
endif()
......
......@@ -20,7 +20,8 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY ${GIT_URL}/oneapi-src/oneDNN.git)
SET(MKLDNN_TAG f58682cd8bd0615f41d879f8afc8f1511ab42d24)
SET(MKLDNN_TAG f3999b71d8e4415c1985a0dfb812a3ed77ee21fa)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
......@@ -59,8 +60,8 @@ ExternalProject_Add(
DEPENDS ${MKLDNN_DEPENDS}
PREFIX ${MKLDNN_PREFIX_DIR}
SOURCE_DIR ${MKLDNN_SOURCE_DIR}
BUILD_ALWAYS 1
# UPDATE_COMMAND ""
UPDATE_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
......
......@@ -100,9 +100,9 @@ else()
"${WARPCTC_DOWNLOAD_CMD}"
PREFIX ${WARPCTC_PREFIX_DIR}
SOURCE_DIR ${WARPCTC_SOURCE_DIR}
#UPDATE_COMMAND ""
UPDATE_COMMAND ""
PATCH_COMMAND ""
BUILD_ALWAYS 1
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${WARPCTC_C_FLAGS}
......
......@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_05_19.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
......
......@@ -28,7 +28,12 @@ function(CheckCompilerCXX14Flag)
endfunction()
CheckCompilerCXX14Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
else()
set(CMAKE_CXX_STANDARD 14)
endif()
# safe_set_flag
#
# Set a compile flag only if the compiler supports it
......
......@@ -92,7 +92,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
# including io directory for inference lib paddle_api.h
include_directories("${PADDLE_SOURCE_DIR}/paddle/fluid/framework/io")
if(NOT APPLE)
if(NOT APPLE AND NOT WIN32)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
if(WITH_PSLIB OR WITH_DISTRIBUTE)
......@@ -100,7 +100,7 @@ if(NOT APPLE)
else()
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")
endif()
endif(NOT APPLE)
endif()
set_property(GLOBAL PROPERTY FLUID_MODULES "")
# find all fluid modules is used for paddle fluid static library
......@@ -391,7 +391,7 @@ function(cc_binary TARGET_NAME)
endfunction(cc_binary)
function(cc_test_build TARGET_NAME)
if(WITH_TESTING)
if(WITH_TESTING AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......@@ -409,14 +409,12 @@ function(cc_test_build TARGET_NAME)
if(WITH_ROCM)
target_link_libraries(${TARGET_NAME} ${ROCM_HIPRTC_LIB})
endif()
check_coverage_opt(${TARGET_NAME} ${cc_test_SRCS})
endif()
check_coverage_opt(${TARGET_NAME} ${cc_test_SRCS})
endfunction()
function(cc_test_run TARGET_NAME)
if(WITH_TESTING)
if(WITH_TESTING AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
set(oneValueArgs "")
set(multiValueArgs COMMAND ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......
......@@ -17,16 +17,30 @@ if(NOT WIN32)
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG")
if(WITH_GPU)
set(CMAKE_CUDA_FLAGS_DEBUG "-g")
set(CMAKE_CUDA_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CUDA_FLAGS_MINSIZEREL "-O1 -DNDEBUG")
endif()
else()
set(CMAKE_C_FLAGS_DEBUG "/Zi /DEBUG")
set(CMAKE_C_FLAGS_RELEASE "/O2 /DNDEBUG")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "/O2 /DNDEBUG")
set(CMAKE_C_FLAGS_MINSIZEREL "/Os /DNDEBUG")
set(CMAKE_C_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1")
set(CMAKE_C_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG")
set(CMAKE_C_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "/Zi /DEBUG")
set(CMAKE_CXX_FLAGS_RELEASE "/O2 /DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/O2 /DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "/Os /DNDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1")
set(CMAKE_CXX_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG")
if(WITH_GPU)
set(CMAKE_CUDA_FLAGS_DEBUG "-Xcompiler=\"-MDd -Zi -Ob0 -Od /RTC1\"")
set(CMAKE_CUDA_FLAGS_RELEASE "-Xcompiler=\"-MD -O2 -Ob2\" -DNDEBUG")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "-Xcompiler=\"-MD -Zi -O2 -Ob1\" -DNDEBUG")
set(CMAKE_CUDA_FLAGS_MINSIZEREL "-Xcompiler=\"-MD -O1 -Ob1\" -DNDEBUG")
endif()
# It can specify CUDA compile flags manually,
# its use is to remove /Zi to reduce GPU static library size. But it's dangerous
......@@ -34,10 +48,3 @@ else()
# Now, it's only used in VS2015 + CUDA:[10.0, 10.2]
set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props)
endif()
if(WITH_GPU)
set(CMAKE_CUDA_FLAGS_DEBUG "-g")
set(CMAKE_CUDA_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CUDA_FLAGS_MINSIZEREL "-O1 -DNDEBUG")
endif()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sys/time.h>
#include <iostream>
#include <ostream>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <ThreadPool.h>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/string/split.h"
constexpr int FG = 256 * 1024 * 1024;
constexpr int Q_SIZE = 10000;
constexpr int BUCKET = 10;
constexpr char XEOF[] = "EOF";
using boost::lexical_cast;
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
namespace paddle {
namespace distributed {
class ShardingMerge {
public:
ShardingMerge() {}
~ShardingMerge() {}
void Merge(const std::vector<std::string> &inputs,
const std::vector<int64_t> &feasigns, const std::string &output,
const int embedding_dim) {
pool_.reset(new ::ThreadPool(inputs.size()));
std::vector<std::future<int>> tasks(inputs.size());
std::vector<std::vector<int64_t>> rows;
rows.resize(inputs.size());
auto begin = GetCurrentUS();
for (int x = 0; x < inputs.size(); ++x) {
tasks[x] = pool_->enqueue([this, x, &rows, &inputs, &feasigns]() -> int {
DeserializeRowsFromFile(inputs[x], feasigns[x], &rows[x]);
return 0;
});
}
for (size_t x = 0; x < tasks.size(); ++x) {
tasks[x].wait();
}
int64_t total_rows = 0;
for (auto x = 0; x < rows.size(); x++) {
total_rows += rows[x].size();
}
auto end = GetCurrentUS();
VLOG(0) << "got " << total_rows
<< " feasigin ids from sparse embedding using " << end - begin;
std::vector<int64_t> total_dims = {total_rows,
static_cast<int64_t>(embedding_dim)};
std::vector<std::vector<int>> batch_buckets;
batch_buckets.resize(inputs.size());
for (int x = 0; x < rows.size(); ++x) {
batch_buckets[x] = bucket(rows[x].size(), BUCKET);
}
std::ofstream out(output, std::ios::binary);
begin = GetCurrentUS();
SerializeRowsToStream(out, rows, batch_buckets, total_rows);
end = GetCurrentUS();
VLOG(0) << "write rows to oostrream using " << end - begin;
begin = GetCurrentUS();
SerializePreTensorToStream(out, total_dims);
end = GetCurrentUS();
VLOG(0) << "write pretensor to oostrream using " << end - begin;
begin = GetCurrentUS();
SerializeValueToStream(out, inputs, batch_buckets, embedding_dim);
end = GetCurrentUS();
VLOG(0) << "write values to oostrream using " << end - begin;
}
private:
void SerializeRowsToStream(std::ostream &os,
const std::vector<std::vector<int64_t>> &rows,
const std::vector<std::vector<int>> &batch_buckets,
int64_t total_rows) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{
// the 2nd field, rows information
os.write(reinterpret_cast<const char *>(&total_rows), sizeof(total_rows));
for (int b = 0; b < BUCKET; ++b) {
for (int x = 0; x < batch_buckets.size(); ++x) {
auto begin = batch_buckets[x][b];
auto end = batch_buckets[x][b + 1];
if (end - begin == 0) continue;
os.write(reinterpret_cast<const char *>(rows[x].data() + begin),
sizeof(int64_t) * (end - begin));
}
}
// the 3rd field, the height of SelectedRows
int64_t height = total_rows;
os.write(reinterpret_cast<const char *>(&height), sizeof(height));
}
}
void SerializePreTensorToStream(std::ostream &os,
const std::vector<int64_t> &dims) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
framework::proto::VarType::TensorDesc desc;
desc.set_data_type(framework::proto::VarType::FP32);
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
}
void SerializeValueToVec(std::ifstream &in, const int batch,
const int embedding_dim, std::vector<float> *out) {
auto queue =
std::make_shared<framework::BlockingQueue<std::vector<std::string>>>();
auto read = [batch, &in, &queue]() {
std::string line;
std::vector<std::string> columns;
std::vector<std::string> values_str;
int count = 0;
while (std::getline(in, line)) {
++count;
columns = string::Split(line, '\t');
if (columns.size() != 5) {
VLOG(0) << "unexpected line: " << line << ", skip it";
continue;
}
values_str = string::Split(columns[4], ',');
queue->Push(values_str);
if (count >= batch) {
break;
}
}
queue->Push({});
};
auto write = [embedding_dim, &out, &queue]() {
std::vector<std::string> values_str;
while (true) {
queue->Pop(&values_str);
if (values_str.size() == 0) {
break;
}
for (int x = 0; x < embedding_dim; ++x) {
float v = 0.0;
try {
v = lexical_cast<float>(values_str[x]);
} catch (boost::bad_lexical_cast &e) {
VLOG(0) << " get unexpected line: " << line;
}
out->push_back(v);
}
}
};
std::thread p_read(read);
std::thread p_write(write);
p_read.join();
p_write.join();
}
void SerializeVecToStream(std::ostream &out,
const std::vector<float> &value) {
out.write(reinterpret_cast<const char *>(value.data()),
static_cast<std::streamsize>(sizeof(float) * value.size()));
}
void SerializeValueToStream(
std::ostream &out, const std::vector<std::string> &ins,
const std::vector<std::vector<int>> &batch_buckets,
const int embedding_dim) {
std::vector<std::shared_ptr<std::ifstream>> in_streams;
for (int x = 0; x < ins.size(); ++x) {
in_streams.emplace_back(std::make_shared<std::ifstream>(ins[x]));
}
std::vector<std::future<int>> tasks(ins.size());
for (int b = 0; b < BUCKET; ++b) {
std::vector<std::vector<float>> values;
values.resize(tasks.size());
auto begin = GetCurrentUS();
for (int x = 0; x < tasks.size(); ++x) {
auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
values[x].clear();
values[x].reserve(batch * embedding_dim);
}
for (int x = 0; x < tasks.size(); ++x) {
tasks[x] =
pool_->enqueue([this, b, x, &out, &in_streams, &batch_buckets,
&values, embedding_dim]() -> int {
auto batch = batch_buckets[x][b + 1] - batch_buckets[x][b];
if (batch == 0) return 0;
SerializeValueToVec(*(in_streams[x].get()), batch, embedding_dim,
&values[x]);
return 0;
});
}
for (size_t x = 0; x < tasks.size(); ++x) {
tasks[x].wait();
}
auto end = GetCurrentUS();
auto begin1 = GetCurrentUS();
for (size_t x = 0; x < tasks.size(); ++x) {
SerializeVecToStream(out, values[x]);
}
auto end1 = GetCurrentUS();
VLOG(0) << "serialize buckets " << b << " read using " << end - begin
<< ", to oostream using " << end1 - begin1;
}
}
void DeserializeRowsFromFile(const std::string &input_file,
const int64_t feasigns,
std::vector<int64_t> *rows) {
std::string line;
std::vector<std::string> columns;
std::ifstream file(input_file);
rows->reserve(feasigns);
while (std::getline(file, line)) {
columns = string::Split(line, '\t');
if (columns.size() != 5) {
VLOG(0) << "unexpected line: " << line << ", skip it";
continue;
}
rows->push_back(std::stoull(columns[0]));
}
VLOG(0) << "parse " << rows->size() << " embedding rows from "
<< input_file;
}
private:
std::unique_ptr<::ThreadPool> pool_;
};
} // namespace distributed
} // namespace paddle
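For orientation, here is a hedged caller-side sketch of the ShardingMerge API added above. The file names and row counts are invented, and the include for ShardingMerge is omitted because the new file's path is not visible in this diff; each input line is expected to carry five tab-separated columns, with the feasign id in column 0 and the comma-separated embedding values in column 4:

```cpp
#include <string>
#include <vector>
// include for ShardingMerge omitted: the new file's path is not shown here

void MergeShards() {
  // Hypothetical driver: merge three sparse-embedding shard dumps into one
  // SelectedRows-style binary file (paths and counts are made-up values).
  paddle::distributed::ShardingMerge merger;
  std::vector<std::string> inputs = {"emb.block0.txt", "emb.block1.txt",
                                     "emb.block2.txt"};
  std::vector<int64_t> feasigns = {10000, 10000, 10000};  // rows per file
  merger.Merge(inputs, feasigns, /*output=*/"embedding.merged",
               /*embedding_dim=*/64);
}
```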
......@@ -14,6 +14,8 @@
#pragma once
#include <sys/time.h>
#include <functional>
#include <memory>
#include <string>
......@@ -83,5 +85,11 @@ std::string to_string(const std::vector<T>& vec) {
}
return ss.str();
}
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
}
} // namespace distributed
} // namespace paddle
......@@ -80,11 +80,11 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
0) {
++fail_num;
} else {
auto &res_io_buffer =
......@@ -144,6 +144,163 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
return fut;
}
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list,
std::vector<bool> &is_weighted_list) {
std::vector<std::vector<uint64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket;
bool add_weight = is_weighted_list.size() > 0;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_id_list[query_idx]);
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
if (add_weight)
is_weighted_bucket[index_mapping[server_index]].push_back(
query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
: false);
}
size_t request_call_num = request_bucket.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [&, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
0) {
++fail_num;
}
}
ret = fail_num == request_call_num ? -1 : 0;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = server_index_arr[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
if (add_weight) {
bool weighted[is_weighted_bucket[request_idx].size() + 1];
for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
weighted[j] = is_weighted_bucket[request_idx][j];
closure->request(request_idx)
->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size());
}
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list) {
std::vector<std::vector<uint64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_id_list[query_idx]);
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
}
size_t request_call_num = request_bucket.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [&, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
++fail_num;
}
}
ret = fail_num == request_call_num ? -1 : 0;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = server_index_arr[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
......@@ -174,8 +331,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
......@@ -254,13 +411,14 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
char *buffer = new char[bytes_size];
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index));
index += GraphNode::id_size;
}
delete[] buffer;
}
closure->set_promise_value(ret);
});
......@@ -292,7 +450,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
......@@ -301,6 +459,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
index += node.get_size(false);
res.push_back(node);
}
delete[] buffer;
}
closure->set_promise_value(ret);
});
......
......@@ -78,6 +78,13 @@ class GraphBrpcClient : public BrpcPsClient {
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
......
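To make the new client surface concrete, here is a hypothetical caller-side sketch of the three methods declared above. Each returns a std::future<int32_t> that resolves to 0 on success, mirroring the closures in graph_brpc_client.cc: clear_nodes fails if any server fails, while add_graph_node and remove_graph_node only report failure when every per-server request fails.

```cpp
#include <cstdint>
#include <vector>
// GraphBrpcClient is declared in graph_brpc_client.h above

void Demo(paddle::distributed::GraphBrpcClient* client) {
  uint32_t table_id = 0;                              // made-up values
  std::vector<uint64_t> node_ids = {37, 96, 122};
  std::vector<bool> is_weighted = {true, false, true};

  auto add_fut = client->add_graph_node(table_id, node_ids, is_weighted);
  if (add_fut.get() != 0) {
    // -1: every per-server request failed
  }

  auto rm_fut = client->remove_graph_node(table_id, node_ids);
  rm_fut.wait();

  auto clear_fut = client->clear_nodes(table_id);
  clear_fut.wait();  // 0 only if every server acknowledged PS_GRAPH_CLEAR
}
```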
......@@ -24,6 +24,14 @@
namespace paddle {
namespace distributed {
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t GraphBrpcServer::initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
......@@ -71,6 +79,58 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
return 0;
}
int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
((GraphTable *)table)->clear_nodes();
return 0;
}
int32_t GraphBrpcService::add_graph_node(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list;
if (request.params_size() == 2) {
size_t weight_list_size = request.params(1).size() / sizeof(bool);
bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
is_weighted_list = std::vector<bool>(is_weighted_buffer,
is_weighted_buffer + weight_list_size);
}
((GraphTable *)table)->add_graph_node(node_ids, is_weighted_list);
return 0;
}
int32_t GraphBrpcService::remove_graph_node(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 1 argument");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(node_ids);
return 0;
}
int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() {
......@@ -92,21 +152,17 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
&GraphBrpcService::graph_get_node_feat;
_service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
_service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
&GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node;
// shard initialization: the shard info of server_list can only be obtained from env after the server starts
initialize_shard_info();
return 0;
}
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t GraphBrpcService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
......
......@@ -86,6 +86,13 @@ class GraphBrpcService : public PsBaseService {
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
......
......@@ -44,6 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
}
}
void add_graph_node(std::vector<uint64_t> node_ids,
std::vector<bool> weight_list) {}
void remove_graph_node(std::vector<uint64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
......@@ -247,6 +250,34 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
}
}
void GraphPyClient::clear_nodes(std::string name) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->clear_nodes(table_id);
status.wait();
}
}
void GraphPyClient::add_graph_node(std::string name,
std::vector<uint64_t>& node_ids,
std::vector<bool>& weight_list) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->add_graph_node(table_id, node_ids, weight_list);
status.wait();
}
}
void GraphPyClient::remove_graph_node(std::string name,
std::vector<uint64_t>& node_ids) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
status.wait();
}
}
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
// 'n' means load nodes and 'node_type' follows
std::string params = "n" + name;
......
......@@ -141,6 +141,10 @@ class GraphPyClient : public GraphPyService {
void finalize_worker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
void add_graph_node(std::string name, std::vector<uint64_t>& node_ids,
std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids);
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
......
......@@ -52,6 +52,9 @@ enum PsCmdID {
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
}
message PsRequestMessage {
......
......@@ -35,6 +35,77 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list) {
size_t node_size = id_list.size();
std::vector<std::vector<std::pair<uint64_t, bool>>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
continue;
}
batch[get_thread_pool_index(id_list[i])].push_back(
{id_list[i], i < is_weight_list.size() ? is_weight_list[i] : false});
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p.first % this->shard_num - this->shard_start;
this->shards[index].add_graph_node(p.first)->build_edges(p.second);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
size_t node_size = id_list.size();
std::vector<std::vector<uint64_t>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) continue;
batch[get_thread_pool_index(id_list[i])].push_back(id_list[i]);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p % this->shard_num - this->shard_start;
this->shards[index].delete_node(p);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
void GraphShard::clear() {
for (size_t i = 0; i < bucket.size(); i++) {
delete bucket[i];
}
bucket.clear();
node_location.clear();
}
GraphShard::~GraphShard() { clear(); }
void GraphShard::delete_node(uint64_t id) {
auto iter = node_location.find(id);
if (iter == node_location.end()) return;
int pos = iter->second;
delete bucket[pos];
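// swap-and-pop: move the last node into the freed slot so the bucket stays dense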
if (pos != (int)bucket.size() - 1) {
bucket[pos] = bucket.back();
node_location[bucket.back()->get_id()] = pos;
}
node_location.erase(id);
bucket.pop_back();
}
GraphNode *GraphShard::add_graph_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
......@@ -79,11 +150,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
int start = 0, end, index = 0, total_size = 0;
res.clear();
std::vector<std::future<std::vector<uint64_t>>> tasks;
// std::string temp = "";
// for(int i = 0;i < shards.size();i++)
// temp+= std::to_string((int)shards[i].get_size()) + " ";
// VLOG(0)<<"range distribution "<<temp;
for (int i = 0; i < shards.size() && index < ranges.size(); i++) {
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
end = total_size + shards[i].get_size();
start = total_size;
while (start < end && index < ranges.size()) {
......@@ -97,7 +164,6 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
start = second;
first -= total_size;
second -= total_size;
// VLOG(0)<<" FIND RANGE "<<i<<" "<<first<<" "<<second;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, first, second, i]() -> std::vector<uint64_t> {
return shards[i].get_ids_by_range(first, second);
......@@ -106,7 +172,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
}
total_size += shards[i].get_size();
}
for (int i = 0; i < tasks.size(); i++) {
for (size_t i = 0; i < tasks.size(); i++) {
auto vec = tasks[i].get();
for (auto &id : vec) {
res.push_back(id);
......@@ -219,7 +285,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
......@@ -238,10 +304,29 @@ Node *GraphTable::find_node(uint64_t id) {
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_table % task_pool_size_;
}
uint32_t GraphTable::get_thread_pool_index_by_shard_index(
uint64_t shard_index) {
return shard_index % shard_num_per_table % task_pool_size_;
}
int32_t GraphTable::clear_nodes() {
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < shards.size(); i++) {
tasks.push_back(
_shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue(
[this, i]() -> int {
this->shards[i].clear();
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
int32_t GraphTable::random_sample_nodes(int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
bool need_feature = false;
int total_size = 0;
for (int i = 0; i < shards.size(); i++) {
total_size += shards[i].get_size();
......@@ -281,7 +366,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
}
std::vector<std::pair<int, int>> first_half, second_half;
int start_index = rand() % total_size;
for (int i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
first_half.push_back({ranges_pos[i] + start_index,
ranges_pos[i] + ranges_len[i] + start_index});
......@@ -386,7 +471,6 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
if (this->feat_id_map.count(fields[0])) {
int32_t id = this->feat_id_map[fields[0]];
std::string dtype = this->feat_dtype[id];
int32_t shape = this->feat_shape[id];
std::vector<std::string> values(fields.begin() + 1, fields.end());
if (dtype == "feasign") {
return std::make_pair<int32_t, std::string>(
......
......@@ -36,11 +36,12 @@ class GraphShard {
size_t get_size();
GraphShard() {}
GraphShard(int shard_num) { this->shard_num = shard_num; }
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
std::vector<uint64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res;
for (int i = start; i < end && i < bucket.size(); i++) {
for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
......@@ -48,6 +49,8 @@ class GraphShard {
GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
void clear();
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
std::unordered_map<uint64_t, int> get_node_location() {
return node_location;
......@@ -85,6 +88,11 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list);
Node *find_node(uint64_t id);
virtual int32_t pull_sparse(float *values,
......@@ -97,6 +105,7 @@ class GraphTable : public SparseTable {
return 0;
}
virtual int32_t clear_nodes();
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
......@@ -105,6 +114,7 @@ class GraphTable : public SparseTable {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
......@@ -128,4 +138,5 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
};
} // namespace distributed
} // namespace paddle
......@@ -134,10 +134,23 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
}
}
int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
const int mode) {
int64_t save_num = 0;
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total) {
// save meta
std::stringstream stream;
stream << "param=" << common.table_name() << "\n";
stream << "shard_id=" << shard_idx << "\n";
stream << "row_names=" << paddle::string::join_strings(common.params(), ',')
<< "\n";
stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',')
<< "\n";
stream << "count=" << total << "\n";
os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
}
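For illustration, the meta file written by SaveMetaToText has the following shape (all values below are hypothetical):

```text
param=embedding_0.w_0
shard_id=0
row_names=Param,Moment1,Moment2,Beta1Pow,Beta2Pow
row_dims=64,64,64,1,1
count=10000
```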
int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool, const int mode) {
int64_t save_num = 0;
for (auto& table : block->values_) {
for (auto& value : table) {
if (mode == SaveMode::delta && !value.second->need_save_) {
......@@ -334,16 +347,24 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
int32_t CommonSparseTable::load(const std::string& path,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
VLOG(3) << "sparse table load with " << path << " with meta " << param;
LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
&shard_values_);
rwlock_->UNLock();
auto end = GetCurrentUS();
auto varname = _config.common().table_name();
VLOG(0) << "load " << varname << " with value: " << path
<< " , meta: " << param
<< " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
return 0;
}
int32_t CommonSparseTable::save(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
int mode = std::stoi(param);
VLOG(3) << "sparse table save: " << dirname << " mode: " << mode;
......@@ -356,36 +377,33 @@ int32_t CommonSparseTable::save(const std::string& dirname,
VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
std::vector<std::string> params(_config.common().params().begin(),
_config.common().params().end());
std::string shard_var_pre =
string::Sprintf("%s.block%d", varname, _shard_idx);
std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
std::unique_ptr<std::ofstream> value_out(new std::ofstream(value_));
std::unique_ptr<std::ofstream> vs(new std::ofstream(value_));
int64_t total_ins = 0;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// save values
total_ins += SaveToText(value_out.get(), shard_values_[shard_id], mode);
auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id],
_shards_task_pool[shard_id], mode);
total_ins += shard_save_num;
}
value_out->close();
vs->close();
// save meta
std::stringstream stream;
stream << "param=" << _config.common().table_name() << "\n";
stream << "shard_id=" << _shard_idx << "\n";
stream << "row_names="
<< paddle::string::join_strings(_config.common().params(), ',')
<< "\n";
stream << "row_dims="
<< paddle::string::join_strings(_config.common().dims(), ',') << "\n";
stream << "count=" << total_ins << "\n";
std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
std::unique_ptr<std::ofstream> meta_out(new std::ofstream(meta_));
meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size());
meta_out->close();
VLOG(3) << "save " << varname << " in dir: " << var_store << " done";
std::unique_ptr<std::ofstream> ms(new std::ofstream(meta_));
SaveMetaToText(ms.get(), _config.common(), _shard_idx, total_ins);
ms->close();
auto end = GetCurrentUS();
rwlock_->UNLock();
VLOG(0) << "save " << varname << " with path: " << value_
<< " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
return 0;
}
......@@ -403,8 +421,6 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
}
int32_t CommonSparseTable::pour() {
rwlock_->RDLock();
std::vector<float> values;
std::vector<uint64_t> keys;
......@@ -421,14 +437,11 @@ int32_t CommonSparseTable::pour() {
_push_sparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear();
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
rwlock_->RDLock();
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
......@@ -464,7 +477,6 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
......@@ -507,7 +519,6 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
const float* values, size_t num) {
rwlock_->RDLock();
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -531,7 +542,6 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
......@@ -569,7 +579,6 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
const float** values, size_t num) {
rwlock_->RDLock();
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -596,14 +605,11 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
const float* values, size_t num) {
rwlock_->RDLock();
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -635,14 +641,12 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::flush() { return 0; }
int32_t CommonSparseTable::shrink(const std::string& param) {
rwlock_->WRLock();
int threshold = std::stoi(param);
VLOG(3) << "sparse table shrink: " << threshold;
......@@ -651,7 +655,6 @@ int32_t CommonSparseTable::shrink(const std::string& param) {
VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink";
shard_values_[shard_id]->Shrink(threshold);
}
rwlock_->UNLock();
return 0;
}
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_edge.h"
#include <cstring>
namespace paddle {
namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
}
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
weight_arr.push_back(weight);
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {
class GraphEdgeBlob {
public:
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }
protected:
std::vector<uint64_t> id_arr;
};
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; }
protected:
std::vector<float> weight_arr;
};
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_node.h"
#include <cstring>
namespace paddle {
namespace distributed {
GraphNode::~GraphNode() {
if (sampler != nullptr) {
delete sampler;
sampler = nullptr;
}
if (edges != nullptr) {
delete edges;
edges = nullptr;
}
}
int Node::weight_size = sizeof(float);
int Node::id_size = sizeof(uint64_t);
int Node::int_size = sizeof(int);
int Node::get_size(bool need_feature) { return id_size + int_size; }
void Node::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
memcpy(buffer, &feat_num, sizeof(int));
}
void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); }
int FeatureNode::get_size(bool need_feature) {
int size = id_size + int_size; // id, feat_num
if (need_feature) {
size += feature.size() * int_size;
for (const std::string& fea : feature) {
size += fea.size();
}
}
return size;
}
void GraphNode::build_edges(bool is_weighted) {
if (edges == nullptr) {
if (is_weighted == true) {
edges = new WeightedGraphEdgeBlob();
} else {
edges = new GraphEdgeBlob();
}
}
}
void GraphNode::build_sampler(std::string sample_type) {
if (sample_type == "random") {
sampler = new RandomSampler();
} else if (sample_type == "weighted") {
sampler = new WeightedSampler();
}
sampler->build(edges);
}
void FeatureNode::to_buffer(char* buffer, bool need_feature) {
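// buffer layout: [id : 8 bytes][feat_num : 4 bytes] then, when need_feature,
// ([feat_len : 4 bytes][feat bytes]) for each feature; otherwise feat_num is 0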
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
int feat_len;
if (need_feature) {
feat_num += feature.size();
memcpy(buffer, &feat_num, sizeof(int));
buffer += sizeof(int);
for (int i = 0; i < feat_num; ++i) {
feat_len = feature[i].size();
memcpy(buffer, &feat_len, sizeof(int));
buffer += sizeof(int);
memcpy(buffer, feature[i].c_str(), feature[i].size());
buffer += feature[i].size();
}
} else {
memcpy(buffer, &feat_num, sizeof(int));
}
}
void FeatureNode::recover_from_buffer(char* buffer) {
int feat_num, feat_len;
memcpy(&id, buffer, id_size);
buffer += id_size;
memcpy(&feat_num, buffer, sizeof(int));
buffer += sizeof(int);
feature.clear();
for (int i = 0; i < feat_num; ++i) {
memcpy(&feat_len, buffer, sizeof(int));
buffer += sizeof(int);
char str[feat_len + 1];
memcpy(str, buffer, feat_len);
buffer += feat_len;
str[feat_len] = '\0';
feature.push_back(std::string(str));
}
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/table/graph_weighted_sampler.h"
namespace paddle {
namespace distributed {
class Node {
public:
Node() {}
Node(uint64_t id) : id(id) {}
virtual ~Node() {}
static int id_size, int_size, weight_size;
uint64_t get_id() { return id; }
void set_id(uint64_t id) { this->id = id; }
virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; }
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) { return std::string(""); }
virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; }
protected:
uint64_t id;
};
class GraphNode : public Node {
public:
GraphNode() : Node(), sampler(nullptr), edges(nullptr) {}
GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {}
virtual ~GraphNode();
virtual void build_edges(bool is_weighted);
virtual void build_sampler(std::string sample_type);
virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight);
}
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
protected:
Sampler *sampler;
GraphEdgeBlob *edges;
};
class FeatureNode : public Node {
public:
FeatureNode() : Node() {}
FeatureNode(uint64_t id) : Node(id) {}
virtual ~FeatureNode() {}
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) {
if (idx < (int)this->feature.size()) {
return this->feature[idx];
} else {
return std::string("");
}
}
virtual void set_feature(int idx, std::string str) {
if (idx >= (int)this->feature.size()) {
this->feature.resize(idx + 1);
}
this->feature[idx] = str;
}
virtual void set_feature_size(int size) { this->feature.resize(size); }
virtual int get_feature_size() { return this->feature.size(); }
template <typename T>
static std::string parse_value_to_bytes(std::vector<std::string> feat_str) {
T v;
size_t Tsize = sizeof(T) * feat_str.size();
char buffer[Tsize];
for (size_t i = 0; i < feat_str.size(); i++) {
std::stringstream ss(feat_str[i]);
ss >> v;
std::memcpy(buffer + sizeof(T) * i, (char *)&v, sizeof(T));
}
return std::string(buffer, Tsize);
}
template <typename T>
static std::vector<T> parse_bytes_to_array(std::string feat_str) {
T v;
std::vector<T> out;
size_t start = 0;
const char *buffer = feat_str.data();
while (start < feat_str.size()) {
std::memcpy((char *)&v, buffer + start, sizeof(T));
start += sizeof(T);
out.push_back(v);
}
return out;
}
protected:
std::vector<std::string> feature;
};
}
}
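A small round-trip sketch of the two static helpers in FeatureNode above: parse_value_to_bytes packs string-encoded numbers into a raw byte string, and parse_bytes_to_array recovers the numeric array (the values here are arbitrary):

```cpp
#include <string>
#include <vector>
#include "paddle/fluid/distributed/table/graph_node.h"

void RoundTrip() {
  std::vector<std::string> feat_str = {"0.5", "1.25", "-3.0"};
  std::string bytes =
      paddle::distributed::FeatureNode::parse_value_to_bytes<float>(feat_str);
  // bytes.size() == sizeof(float) * 3
  std::vector<float> values =
      paddle::distributed::FeatureNode::parse_bytes_to_array<float>(bytes);
  // values == {0.5f, 1.25f, -3.0f}
}
```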
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph_weighted_sampler.h"
#include <iostream>
#include <unordered_map>
namespace paddle {
namespace distributed {
void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
std::vector<int> RandomSampler::sample_k(int k) {
int n = edges->size();
if (k > n) {
k = n;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result;
std::unordered_map<int, int> replace_map;
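// replace_map simulates a Fisher-Yates shuffle without materializing the
// array: every drawn index is remapped to the shrinking tail, so the k
// returned indices are distinct (sampling without replacement in O(k) space)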
while (k--) {
int rand_int = rand() % n;
auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) {
sample_result.push_back(rand_int);
} else {
sample_result.push_back(iter->second);
}
iter = replace_map.find(n - 1);
if (iter == replace_map.end()) {
replace_map[rand_int] = n - 1;
} else {
replace_map[rand_int] = iter->second;
}
--n;
}
return sample_result;
}
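// The loop above is a sparse Fisher-Yates shuffle: each drawn index is
// logically swapped with the current tail element through replace_map, so k
// distinct indices are produced in O(k) expected time without copying the
// edge list. A minimal standalone sketch of the same idea (illustrative
// only, not part of the original interface; production code should prefer
// <random> to the global rand() state):
inline std::vector<int> sparse_fisher_yates_sample(int n, int k) {
  std::vector<int> result;
  std::unordered_map<int, int> replace_map;  // records the virtual swaps
  while (k-- > 0 && n > 0) {
    int r = rand() % n;  // draw from the not-yet-sampled prefix [0, n)
    auto it = replace_map.find(r);
    result.push_back(it == replace_map.end() ? r : it->second);
    // virtually move the (possibly already remapped) tail element to slot r
    auto tail = replace_map.find(n - 1);
    replace_map[r] = (tail == replace_map.end()) ? n - 1 : tail->second;
    --n;
  }
  return result;
}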
WeightedSampler::WeightedSampler() {
left = nullptr;
right = nullptr;
edges = nullptr;
}
WeightedSampler::~WeightedSampler() {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
}
void WeightedSampler::build(GraphEdgeBlob *edges) {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
return build_one((WeightedGraphEdgeBlob *)edges, 0, edges->size());
}
void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
int end) {
count = 0;
this->edges = edges;
if (start + 1 == end) {
left = right = nullptr;
idx = start;
count = 1;
weight = edges->get_weight(idx);
} else {
left = new WeightedSampler();
right = new WeightedSampler();
left->build_one(edges, start, start + (end - start) / 2);
right->build_one(edges, start + (end - start) / 2, end);
weight = left->weight + right->weight;
count = left->count + right->count;
}
}
std::vector<int> WeightedSampler::sample_k(int k) {
if (k > count) {
k = count;
}
std::vector<int> sample_result;
float subtract = 0;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract));
}
return sample_result;
}
int WeightedSampler::sample(
float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract) {
if (left == nullptr) {
subtract_weight_map[this] = weight;
subtract = weight;
subtract_count_map[this] = 1;
return idx;
}
int left_count = left->count - subtract_count_map[left];
int right_count = right->count - subtract_count_map[right];
float left_subtract = subtract_weight_map[left];
int return_idx;
if (right_count == 0 ||
(left_count > 0 && left->weight - left_subtract >= query_weight)) {
return_idx = left->sample(query_weight, subtract_weight_map,
subtract_count_map, subtract);
} else {
return_idx =
right->sample(query_weight - (left->weight - left_subtract),
subtract_weight_map, subtract_count_map, subtract);
}
subtract_weight_map[this] += subtract;
subtract_count_map[this]++;
return return_idx;
}
}
}
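// Usage sketch for the samplers above (illustrative; it assumes the
// WeightedGraphEdgeBlob interface from graph_edge.h, i.e.
// add_edge(uint64_t id, float weight) and get_id(int idx)).
// WeightedSampler::build constructs a balanced binary tree whose internal
// nodes cache the subtree weight and count; sample() then descends from the
// root while subtracting the weight of already-drawn leaves, so k draws
// without replacement cost O(k log n).
inline void weighted_sampler_example() {
  paddle::distributed::WeightedGraphEdgeBlob edges;
  edges.add_edge(101, 0.5f);  // id, weight
  edges.add_edge(102, 1.5f);
  edges.add_edge(103, 3.0f);
  paddle::distributed::WeightedSampler sampler;
  sampler.build(&edges);                   // O(n) tree construction
  for (int idx : sampler.sample_k(2)) {    // two distinct edge slots
    std::cout << "sampled edge id " << edges.get_id(idx) << std::endl;
  }
}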
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/table/graph_edge.h"
namespace paddle {
namespace distributed {
class Sampler {
public:
virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0;
};
class RandomSampler : public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k);
GraphEdgeBlob *edges;
};
class WeightedSampler : public Sampler {
public:
WeightedSampler();
virtual ~WeightedSampler();
WeightedSampler *left, *right;
float weight;
int count;
int idx;
GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k);
private:
int sample(float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract);
};
}
}
......@@ -36,7 +36,7 @@ class Table {
Table() {}
virtual ~Table() {}
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config) final;
const FsClientParameter &fs_config);
virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0;
......@@ -58,7 +58,9 @@ class Table {
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float **values,
size_t num){};
size_t num) {
return 0;
}
virtual int32_t push_sparse_param(const uint64_t *keys, const float *values,
size_t num) {
return 0;
......@@ -108,7 +110,7 @@ class Table {
virtual int32_t save(const std::string &path,
const std::string &converter) = 0;
virtual int32_t set_shard(size_t shard_idx, size_t shard_num) final {
virtual int32_t set_shard(size_t shard_idx, size_t shard_num) {
_shard_idx = shard_idx;
_shard_num = shard_num;
return initialize_shard();
......@@ -123,7 +125,7 @@ class Table {
protected:
virtual int32_t initialize() = 0;
virtual int32_t initialize_accessor() final;
virtual int32_t initialize_accessor();
virtual int32_t initialize_shard() = 0;
virtual std::string table_dir(const std::string &model_dir) {
return paddle::string::format_string("%s/%03d/", model_dir.c_str(),
......
......@@ -124,7 +124,6 @@ void testSingleSampleNeighboor(
for (auto g : s) {
ASSERT_EQ(true, s1.find(g) != s1.end());
}
VLOG(0) << "test single done";
s.clear();
s1.clear();
vs.clear();
......@@ -141,6 +140,57 @@ void testSingleSampleNeighboor(
}
}
void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0);
int total_num = 270000;
uint64_t id;
std::unordered_set<uint64_t> id_set;
for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end())
;
id_set.insert(id);
}
std::vector<uint64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait();
std::vector<uint64_t> ids[2];
for (int i = 0; i < 2; i++) {
auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
}
std::vector<uint64_t> remove_ids;
for (auto p : id_set_check) {
if (remove_ids.size() == 0)
remove_ids.push_back(p);
else if (remove_ids.size() < static_cast<size_t>(total_num) / 2 &&
rand() % 2 == 1) {
remove_ids.push_back(p);
}
}
for (auto p : remove_ids) id_set_check.erase(p);
status = worker_ptr_->remove_graph_node(0, remove_ids);
status.wait();
for (int i = 0; i < 2; i++) ids[i].clear();
for (int i = 0; i < 2; i++) {
auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
}
}
void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
......@@ -527,6 +577,7 @@ void RunBrpcPushSparse() {
std::remove(edge_file_name);
std::remove(node_file_name);
testAddNode(worker_ptr_);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
......
......@@ -27,6 +27,7 @@ add_subdirectory(fleet)
add_subdirectory(io)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(op_def_proto SRCS op_def.proto)
proto_library(heter_service_proto SRCS heter_service.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
......@@ -287,6 +288,15 @@ if(WITH_DISTRIBUTE)
graph_to_program_pass variable_helper timer monitor)
endif()
elseif(WITH_PSLIB)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(device_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(hetercpu_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heterxpu_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
......
......@@ -143,7 +143,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place) {
platform::Place place, bool always_copy) {
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Input tensor format is invalid. Input tensor should "
......@@ -177,7 +177,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
// output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims());
if (in_format != out_format) {
if ((in_format != out_format) || always_copy) {
void* in_data = GetDataFromTensor(in, in_type);
std::string key =
platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);
......
......@@ -78,7 +78,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place);
platform::Place place,
bool always_copy = false);
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/eigen_ext.h"
......@@ -30,6 +31,8 @@ struct bfloat16;
struct complex128;
struct complex64;
struct float16;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
......@@ -61,6 +64,10 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
......@@ -69,6 +76,10 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
......
......@@ -163,6 +163,11 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex128 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
float > : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
double > : omp_out += omp_in)
#endif
template <typename T>
......@@ -268,12 +273,69 @@ void CheckNanInf<paddle::platform::complex128>(
op_type));
}
}
template <>
void CheckNanInf<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
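// (v - v) is 0 for every finite v and NaN when v is NaN or Inf, so a single
// bad element poisons the reductions below.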
float real_sum = 0.0f;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
float imag_sum = 0.0f;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hotfix for a compile failure on gcc 4.8
// TODO: also print detailed info about the nan/inf elements later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
template <>
void CheckNanInf<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
double real_sum = 0.0;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
double imag_sum = 0.0;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hotfix for a compile failure on gcc 4.8
// TODO: also print detailed info about the nan/inf elements later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
#endif
template <>
template <typename T>
void TensorCheckerVisitor<platform::CPUDeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
// use env strategy control in future, -1=print_all.
int print_num = 3;
CheckNanInf(tensor_.data<T>(), tensor_.numel(), print_num, op_type_,
......
......@@ -123,7 +123,11 @@ __global__ void CheckNanInfKernel(const T* value, const size_t numel,
template <>
template <typename T>
void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
int print_num = 3;
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
......
......@@ -46,8 +46,12 @@ struct TensorCheckerVisitor {
}
template <typename T>
void apply(typename std::enable_if<std::is_floating_point<T>::value>::type* =
0) const;
void apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type* =
0) const;
std::string op_type_;
std::string var_name_;
......
......@@ -195,6 +195,9 @@ class DeviceWorker {
virtual void SetReaderPlace(const paddle::platform::Place& place) {
device_reader_->SetPlace(place);
}
virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) {
dev_ctx_ = dev_ctx;
}
virtual Scope* GetThreadScope() { return thread_scope_; }
DataFeed* device_reader_ = nullptr;
......@@ -221,6 +224,7 @@ class DeviceWorker {
int dump_mode_ = 0;
int dump_interval_ = 10000;
ChannelWriter<std::string> writer_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
class CPUWorkerBase : public DeviceWorker {
......@@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_;
#ifdef PADDLE_WITH_HETERPS
platform::DeviceContext* dev_ctx_ = nullptr;
#endif
};
class DownpourWorker : public HogwildWorker {
......@@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker {
gpuStream_t copy_stream_;
int batch_cnt_{0};
std::atomic<int> done_cnt_{0};
platform::DeviceContext* dev_ctx_ = nullptr;
double total_time_;
double read_time_;
......
......@@ -141,6 +141,7 @@ message PipelineConfig {
message TensorParallelConfig {
optional int32 tensor_parallel_degree = 1 [ default = 1 ];
optional int32 tensor_init_seed = 2 [ default = -1 ];
}
message DistributedStrategy {
......
......@@ -28,9 +28,19 @@ namespace internal {
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype;
if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
if (std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value ||
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value) {
// The current dlpack library version is v0.2 and does not define the
// kDLComplex value. kDLComplex is defined as 5U in v0.4, so we set
// dtype.code to 5U directly here. After the dlpack library is upgraded
// to v0.4, this should be written as:
// dtype.code = kDLComplex;
dtype.code = 5U;
} else if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) {
dtype.code = kDLUInt;
......
......@@ -28,6 +28,13 @@ namespace framework {
namespace { // NOLINT
template <typename T>
constexpr uint8_t GetDLDataTypeCode() {
if (std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value ||
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value) {
return static_cast<uint8_t>(5);
}
return std::is_same<platform::float16, T>::value ||
std::is_floating_point<T>::value
? static_cast<uint8_t>(kDLFloat)
......
......@@ -39,9 +39,6 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
for (int i = 0; i < param_.stat_var_names_size(); ++i) {
stat_var_name_map_[param_.stat_var_names(i)] = 1;
}
#ifdef PADDLE_WITH_HETERPS
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
#endif
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
......
......@@ -50,8 +50,9 @@ if (WITH_TESTING)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
......@@ -139,6 +140,7 @@ cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass)
cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_fc_lstm_fuse_pass_cc SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto)
cc_test(test_fc_gru_fuse_pass_cc SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto)
......
......@@ -17,7 +17,7 @@
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
......@@ -46,7 +46,7 @@ enum FuseOptions {
FUSE_MKLDNN // fusing will be done with MKL-DNN
};
class FusePassBase : public Pass {
class FusePassBase : public OpCompatSensiblePass {
public:
void Init(const std::string& repr, Graph* graph) const;
Scope* param_scope() const;
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle {
namespace framework {
namespace ir {
AttrCompat& AttrCompat::IsStringIn(const std::set<std::string>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr);
for (auto& str : candidates) {
if (str == value) {
return true;
}
}
return false;
});
return *this;
}
AttrCompat& AttrCompat::IsStringMatch(
const std::function<bool(const std::string&)>& func) {
conditions_.emplace_back([func](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr);
return func(value);
});
return *this;
}
AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
int value = BOOST_GET_CONST(int, attr);
return candidates.find(value) != candidates.end();
});
return *this;
}
//! TODO: complete the definition.
AttrCompat& AttrCompat::IsLeftDefault() {
const std::string& op_name = op_compat_->Name();
if (!OpInfoMap::Instance().Has(op_name)) {
VLOG(3) << "Op (" << op_name << ") is not registered!";
conditions_.emplace_back([](const Attribute& attr) { return false; });
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
if (attrs.find(attr_name_) == attrs.end()) {
VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
} else {
Attribute default_attr = attrs.at(attr_name_);
conditions_.emplace_back([default_attr](const Attribute& attr) -> bool {
return attr == default_attr;
});
}
return *this;
}
bool AttrCompat::operator()(const OpDesc& op_desc) {
if (conditions_.empty()) {
return true;
}
if (!op_desc.HasAttr(attr_name_)) {
return optional_;
}
const Attribute attr = op_desc.GetAttr(attr_name_);
for (auto& func : conditions_) {
if (!func(attr)) {
return false;
}
}
return true;
}
AttrCompat& AttrCompat::IsOptional() {
optional_ = true;
return *this;
}
AttrCompat& AttrCompat::IsBoolEQ(bool v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
bool value = BOOST_GET_CONST(bool, attr);
return value == v;
});
return *this;
}
InputOrOutputCompat& InputOrOutputCompat::IsTensor() {
conditions_.emplace_back([](const std::vector<std::string>& input) -> bool {
return input.size() == 1u;
});
return *this;
}
InputOrOutputCompat& InputOrOutputCompat::IsOptional() {
optional_ = true;
return *this;
}
bool InputOrOutputCompat::operator()(
const std::vector<std::string>& input) const {
if (input.empty()) return false;
for (auto& func : conditions_) {
if (!func(input)) {
return false;
}
}
return true;
}
AttrCompat& OpCompat::AddAttr(const std::string& attr_name) {
PADDLE_ENFORCE_EQ(
attr_compats_.find(attr_name), attr_compats_.end(),
platform::errors::InvalidArgument(
"The attrubute compat with the same name has been added"));
attr_compats_.emplace(attr_name, AttrCompat(attr_name, this));
return attr_compats_.at(attr_name);
}
InputOrOutputCompat& OpCompat::AddInput(const std::string& name) {
PADDLE_ENFORCE_EQ(input_compats_.find(name), input_compats_.end(),
platform::errors::InvalidArgument(
"The input with the same name has been added"));
input_compats_.emplace(name, InputOrOutputCompat(name, this));
return input_compats_.at(name);
}
InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
PADDLE_ENFORCE_EQ(output_compats_.find(name), output_compats_.end(),
platform::errors::InvalidArgument(
"The output with the same name has been added"));
output_compats_.emplace(name, InputOrOutputCompat(name, this));
return output_compats_.at(name);
}
bool OpCompat::Judge(const OpDesc& op_desc) {
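// Judge runs in four steps: (1) every attribute in op_desc that is not
// registered here must still equal its compile-time default, (2) every
// registered attribute constraint must hold, and (3)/(4) the input and
// output maps get the same treatment: unregistered names must be empty,
// and non-optional registered names must be present and pass their checks.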
for (auto& attr_map : op_desc.GetAttrMap()) {
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not equal to default value!";
return false;
}
}
}
for (auto& attr_compat : attr_compats_) {
if (!attr_compat.second(op_desc)) {
VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
return false;
}
}
const VariableNameMap& inputs_map = op_desc.Inputs();
for (auto& input_desc : inputs_map) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) {
VLOG(3) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") is not registered in OpCompat!";
return false;
}
}
}
for (auto& input_val : input_compats_) {
if (inputs_map.find(input_val.first) == inputs_map.end()) {
if (!input_val.second.Optional()) {
VLOG(3) << "The No optional Input (" << input_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false;
}
} else {
if (!input_val.second(inputs_map.at(input_val.first))) {
VLOG(3) << "The Input (" << input_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
}
const VariableNameMap& outputs_map = op_desc.Outputs();
for (auto& output_desc : outputs_map) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) {
VLOG(3) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") is not registered in OpCompat!";
return false;
}
}
}
for (auto& output_val : output_compats_) {
if (outputs_map.find(output_val.first) == outputs_map.end()) {
if (!output_val.second.Optional()) {
VLOG(3) << "The No optional Output (" << output_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false;
}
} else {
if (!output_val.second(outputs_map.at(output_val.first))) {
VLOG(3) << "The Output (" << output_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
}
return true;
}
OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) {
std::string name = op_compat.Name();
op_compat_judgers_[name].reset(new OpCompat(std::move(op_compat)));
return *(op_compat_judgers_[name]);
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class OpCompat;
class AttrCompat {
public:
AttrCompat(const std::string& attr_name, OpCompat* op_compat)
: optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}
// @{ String-related methods
//! Assert the attribute is a string in the `candidates` domain.
AttrCompat& IsStringIn(const std::set<std::string>& candidates);
//! Assert the attribute is a string and matches a custom judging function.
AttrCompat& IsStringMatch(
const std::function<bool(const std::string&)>& func);
// @}
//! Assert the attribute is an integer in the `candidates` domain.
AttrCompat& IsIntIn(const std::set<int>& candidates);
// @{ Number-related methods
//! Assert the attribute is a number and > `v`.
template <typename T>
AttrCompat& IsNumGT(T v);
//! Assert the attribute is a number and >= `v`.
template <typename T>
AttrCompat& IsNumGE(T v);
//! Assert the attribute is a number and < `v`.
template <typename T>
AttrCompat& IsNumLT(T v);
//! Assert the attribute is a number and <= `v`.
template <typename T>
AttrCompat& IsNumLE(T v);
//! Assert the attribute is a number and == `v`.
template <typename T>
AttrCompat& IsNumEQ(T v);
//! Assert the attribute is a number and matches a customized judging
//! function.
template <typename T>
AttrCompat& IsNumMatch(bool (*func)(T));
// @}
//! Assert the attribute is a boolean value equal to `v`.
AttrCompat& IsBoolEQ(bool v);
//! Tell whether this attribute is left at its default value.
AttrCompat& IsLeftDefault();
AttrCompat& IsOptional();
//! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; }
bool operator()(const OpDesc& op_desc);
private:
bool optional_;
std::string attr_name_;
OpCompat* op_compat_;
std::vector<std::function<bool(const Attribute&)>> conditions_;
};
class InputOrOutputCompat {
public:
InputOrOutputCompat(const std::string& name, OpCompat* op_compat)
: optional_(false), name_(name), op_compat_(op_compat) {}
InputOrOutputCompat& IsTensor();
InputOrOutputCompat& IsOptional();
bool Optional() const { return optional_; }
bool operator()(const std::vector<std::string>& input) const;
//! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; }
private:
bool optional_;
std::string name_;
OpCompat* op_compat_;
std::vector<std::function<bool(const std::vector<std::string>&)>> conditions_;
};
/**
* OpCompat is a helper class for defining a compatible Op definition.
*
* Usage:
* OpCompat compat("FC");
* compat.AddAttr("in_num_col_dims").IsNumLE(1).End()
* .AddAttr("activation_type").IsStringIn({"tanh", "sigmoid"}).End()
* .AddInput("Input").IsTensor().End()
* .AddInput("W").IsTensor().End()
* .AddInput("Bias").IsTensor().IsOptional().End()
* .AddOutput("Out").IsTensor().End()
*
* All the inference-aware Op definitions follow the pattern above; any
* attribute not contained in the definition should keep its default value,
* or the op will be judged incompatible.
*/
class OpCompat {
public:
explicit OpCompat(const std::string& op_name) : op_name_(op_name) {}
explicit OpCompat(std::string&& op_name) : op_name_(std::move(op_name)) {}
explicit OpCompat(const OpCompat&) = default;
explicit OpCompat(OpCompat&&) = default;
AttrCompat& AddAttr(const std::string& attr_name);
InputOrOutputCompat& AddInput(const std::string& name);
InputOrOutputCompat& AddOutput(const std::string& name);
//! Judge whether an OpDesc match the defined Op compatibility.
bool Judge(const OpDesc& op_desc);
const std::string& Name() const { return op_name_; }
private:
std::string op_name_;
std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
};
/**
* OpCompatSensiblePass is a base class for all the passes that are sensitive
* to Op definition updates.
* There are two methods to help tell the compatibility of an Op:
* bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, Graph* g);
* bool IsCompat(const OpDesc& op_desc);
*
* One can register the related Op compatibilities using
* void AddOpCompat(OpCompat&& judger);
*
* Most of the passes are used for fusing ops, so we define a method for such
* scenarios.
* void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
*                     Graph* g);
* It will check the Op compatibility automatically.
* For other scenarios, one should call `IsCompat` manually.
*
* A FC fuse pass example:
* class FcFusePass : public OpCompatSensiblePass {
* public:
* FcFusePass() {
* // define Mul op compatibility.
* AddOpCompat(OpCompat("Mul"))
* .AddInput("Input").IsTensor().End()
* .AddAttr("in_num_col_dims").IsNumGE(1);
* AddOpCompat(OpCompat("Add")). ...;
* // There are multiple activation implementations.
* AddOpCompat(OpCompat("Tanh")). ...;
* AddOpCompat(OpCompat("Sigmoid")). ...;
* }
*
* // override the subgraph access method
* virtual bool AccessSubgraphImpl(
* const GraphPatternDetector::subgraph_t& subgraph,
* Graph* g) override { ... }
*
* // Call the AccessSubgraph method in main procedure of this Pass.
* };
*/
class OpCompatSensiblePass : public Pass {
protected:
/**
* Developers should register a compatibility `teller` for each kind of Op in
* the subgraph.
* NOTE: one should add all the related op compatibilities in the constructor
* so that all the following methods are valid.
*/
OpCompat& AddOpCompat(OpCompat&& op_compat);
//! Tell the Op compatibility of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const {
CHECK(!op_compat_judgers_.empty())
<< "At least one OpCompat instance should be added in the "
"OpCompatSensiblePass.";
// Check that all the ops in the subgraph are covered by the registered
// op_compat judgers.
for (auto& node_pair : subgraph) {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
return false;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
//! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const {
if (!op_compat_judgers_.count(op_desc.Type())) return false;
return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc);
}
private:
std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_;
};
template <typename T>
AttrCompat& AttrCompat::IsNumGT(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value > v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumGE(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value >= v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumLT(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value < v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumLE(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value <= v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumEQ(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return value == v;
});
return *this;
}
template <typename T>
AttrCompat& AttrCompat::IsNumMatch(bool (*func)(T)) {
conditions_.emplace_back([func](const Attribute& attr) -> bool {
T value = BOOST_GET_CONST(T, attr);
return func(value);
});
return *this;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("fc");
compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2})
.IsNumLE(1)
.IsLeftDefault()
.End()
.AddAttr("activation_type")
.IsStringIn({"tanh", "sigmoid"})
.IsStringMatch(lambda)
.End()
.AddAttr("test_attr")
.IsBoolEQ(true)
.End()
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("Test")
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End();
OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
attr_map["test_attr"] = true;
fc_op.SetAttrMap(attr_map);
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
fc_op.SetInput("W", std::vector<std::string>{"test_input_0"});
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "fc");
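// Judge() is expected to fail here: IsLeftDefault() finds no "fc" entry in
// OpInfoMap at this point (it is only registered by the next test), so the
// in_num_col_dims chain can never be satisfied.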
EXPECT_FALSE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOpAttribute) {
OpCompat compat("fc");
OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
fc_op.SetAttrMap(attr_map);
OpInfo info;
info.checker_ = new OpAttrChecker();
OpInfoMap::Instance().Insert("fc", info);
EXPECT_FALSE(compat.Judge(fc_op));
info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1);
EXPECT_TRUE(compat.Judge(fc_op));
delete info.checker_;
}
TEST(OpCompatSensiblePass, compatOpAttributeOptional) {
OpCompat compat("fc");
compat.AddAttr("activation_type")
.IsOptional()
.IsStringIn({"tanh", "sigmoid"});
OpDesc fc_op;
EXPECT_TRUE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOpInput) {
OpCompat compat("fc");
OpDesc fc_op;
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
EXPECT_FALSE(compat.Judge(fc_op));
compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End();
EXPECT_FALSE(compat.Judge(fc_op));
fc_op.SetInput("Bias", std::vector<std::string>{"test_input", ""});
EXPECT_FALSE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOutput) {
OpCompat compat("fc");
OpDesc fc_op;
fc_op.SetOutput("Output", std::vector<std::string>{"test_output"});
EXPECT_FALSE(compat.Judge(fc_op));
compat.AddOutput("Output")
.IsTensor()
.End()
.AddOutput("Output_2")
.IsTensor()
.End();
EXPECT_FALSE(compat.Judge(fc_op));
fc_op.SetOutput("Output_2", std::vector<std::string>{"test_output", ""});
EXPECT_FALSE(compat.Judge(fc_op));
}
class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public:
OpCompatSensiblePassTest();
bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); }
};
OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("fc"))
.AddAttr("in_num_col_dims")
.IsNumLE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"tanh", "sigmoid"})
.End()
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor();
}
TEST(OpCompatSensiblePass, IsCompat) {
OpCompatSensiblePassTest test;
OpDesc fc_op;
fc_op.SetType("fc");
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
fc_op.SetAttrMap(attr_map);
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
fc_op.SetInput("W", std::vector<std::string>{"test_input_0"});
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
#ifdef PADDLE_WITH_HETERPS
workers_[i]->SetPlace(places_[i]);
workers_[i]->SetReaderPlace(places_[i]);
workers_[i]->SetDeviceContext(
platform::DeviceContextPool::Instance().Get(places_[i]));
#else
workers_[i]->SetPlace(place);
workers_[i]->SetReaderPlace(place);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "framework.proto";
package paddle.framework.proto;
message OpDef {
message VarDef {
required string name = 1;
// For the type of input / output variables.
reserved 2;
}
message AttrDef {
required string name = 1;
required AttrType type = 2;
}
message Desc {
repeated VarDef inputs = 1;
repeated VarDef outputs = 2;
repeated AttrDef attrs = 3;
}
required string type = 1;
required Desc def = 2;
optional Desc extra = 3;
}
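// Illustrative example (hypothetical op): the text-format OpDef for a "relu"
// operator with one core input, one core output, and one extra attribute
// could look like
//   type: "relu"
//   def {
//     inputs { name: "X" }
//     outputs { name: "Out" }
//   }
//   extra {
//     attrs { name: "use_mkldnn" type: BOOLEAN }
//   }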
......@@ -1229,6 +1229,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
// will be executed and a warning will be given at the same time.
if (SupportGPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
} else if (SupportNPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
} else {
expected_kernel_key.place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
......@@ -1300,7 +1302,11 @@ void OperatorWithKernel::TransferInplaceVarsBack(
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto original_dims = original_tensor->dims();
original_tensor->ShareDataWith(*transformed_tensor);
original_tensor->Resize(original_dims);
// Do not restore the original dims for reshape2 ops: otherwise the new
// output dims produced by the NPU reshape operator would be reverted when
// the op runs inplace.
if (type_ != "reshape2" && type_ != "reshape2_grad") {
original_tensor->Resize(original_dims);
}
}
}
......@@ -1550,10 +1556,10 @@ void OperatorWithKernel::ParseInputDataType(
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr.size(); j++) {
if (t_arr[j].IsInitialized()) {
t = &(t_arr[j]);
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
if (t_arr->at(j).IsInitialized()) {
t = &(t_arr->at(j));
}
}
}
......
......@@ -155,6 +155,7 @@ class OperatorBase {
std::string DebugString() const { return DebugStringEx(nullptr); }
virtual bool SupportGPU() const { return false; }
virtual bool SupportNPU() const { return false; }
const std::string& Type() const { return type_; }
......@@ -491,6 +492,13 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
bool SupportNPU() const override {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_npu_place(kern_pair.first.place_);
});
}
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
......
......@@ -110,8 +110,22 @@ void SectionWorker::TrainFiles() {
BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
}
}
#elif defined(PADDLE_WITH_ASCEND_CL)
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for NPU.";
gc.reset(new NPUUnsafeFastGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG(4) << "Use default stream gc for NPU.";
gc.reset(new NPUDefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
}
#endif
}
} // max_memory_size >= 0
if (schedule_mode_ == 0) {
// F-then-B scheduler which runs Forward phase for all microbatches,
......
......@@ -71,7 +71,7 @@ elseif (WIN32)
cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor benchmark ${inference_deps}
ARGS --dirname=${WORD2VEC_MODEL_DIR})
endif()
if(WITH_TESTING)
if(WITH_TESTING AND TEST test_api_impl)
if(NOT APPLE)
set_tests_properties(test_api_impl PROPERTIES TIMEOUT 120)
endif()
......
......@@ -650,13 +650,6 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
gflags.push_back("--cudnn_deterministic=True");
}
if (config.thread_local_stream_enabled()) {
gflags.push_back("--allocator_strategy=thread_local");
process_level_allocator_enabled = false;
} else {
process_level_allocator_enabled = true;
}
// TODO(wilber): jetson tx2 may fail to run the model due to insufficient memory
// under the native_best_fit strategy. Modify the default allocation strategy to
// auto_growth. TODO: find a more appropriate way to solve the problem.
......@@ -664,6 +657,15 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
gflags.push_back("--allocator_strategy=auto_growth");
#endif
// TODO(Shixiaowei02): Add a mandatory scheme to use the thread local
// allocator when multi-stream is enabled.
if (config.thread_local_stream_enabled()) {
gflags.push_back("--allocator_strategy=thread_local");
process_level_allocator_enabled = false;
} else {
process_level_allocator_enabled = true;
}
if (framework::InitGflags(gflags)) {
VLOG(3) << "The following gpu analysis configurations only take effect "
"for the first predictor: ";
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......@@ -161,8 +162,24 @@ void Tensor::CopyToCpu(T *data) {
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();
paddle::framework::Tensor out;
auto mem_allocation = std::make_shared<paddle::memory::Allocation>(
static_cast<void *>(data), ele_num * sizeof(T),
paddle::platform::CPUPlace());
out.ResetHolder(mem_allocation);
if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN(
tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout(),
*tensor, &out, paddle::platform::CPUPlace(), true);
else
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
} else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::platform::DeviceContextPool &pool =
......
......@@ -52,11 +52,6 @@ class ActivationOpConverter : public OpConverter {
engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_);
if (op_pair == ops.end()) {
PADDLE_THROW(platform::errors::Fatal(
"Wrong activation op type, the trt do not support the %s act type.",
op_type_));
}
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
......
......@@ -55,16 +55,6 @@ class AffineChannelOpConverter : public OpConverter {
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
PADDLE_ENFORCE_EQ(
data_layout, framework::DataLayout::kNCHW,
platform::errors::InvalidArgument(
"TensorRT affine channel converter can only convert NCHW format. "
"Other format should be run in fluid mode. Report a bug on github "
"issue if you see this line."));
// tensorrt scalend layer only support spatial dims >= 2,
// so nhwc is not availabe (spatial dims == 0)
const int channel_axis = engine_->with_dynamic_shape();
......
......@@ -25,10 +25,6 @@ static bool CheckDims(const nvinfer1::Dims& dims_x,
return false;
}
for (int i = 0; i < dims_x.nbDims; i++) {
// conservative judgment
if (dims_x.d[i] == -1 || dims_y.d[i] == -1) {
return false;
}
if (dims_x.d[i] != dims_y.d[i]) {
return false;
}
......
......@@ -143,6 +143,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) return false;
// strides > 1 is only supported by TRT 7.0 and above
#if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) {
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() is less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
}
#endif
}
if (op_type == "pool2d") {
......@@ -225,6 +238,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< desc.Output("Output").size() << " output.";
return false;
}
// strides > 1 is only supported by TRT 7.0 and above
#if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) {
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() is less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
}
#endif
}
if (op_type == "matmul") {
......
......@@ -176,7 +176,7 @@ if(NOT APPLE AND WITH_MKLML)
inference_analysis_api_test(test_analyzer_seq_pool1_fuse_compare_zero_copy ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_fuse_compare_zero_copy_tester.cc)
inference_analysis_api_test(test_analyzer_seq_pool1_fuse_statis ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_fuse_statis_tester.cc)
inference_analysis_api_test(test_analyzer_seq_pool1_profile ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_profile_tester.cc)
if(NOT WIN32)
if(NOT WIN32 AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
set_tests_properties(test_analyzer_seq_pool1_compare_determine PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_seq_pool1 PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_seq_pool1_fuse_compare_zero_copy PROPERTIES TIMEOUT 120)
......@@ -242,10 +242,10 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)
if(NOT WIN32 AND NOT APPLE)
if(NOT WIN32 AND NOT APPLE AND TEST test_analyzer_ernie_large)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 150 LABELS "RUN_TYPE=NIGHTLY")
endif()
if (WIN32)
if (WIN32 AND TEST test_analyzer_ernie_large)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 200)
endif()
......@@ -645,6 +645,10 @@ if(WITH_GPU)
ARGS --infer_model=${RESNET50_MODEL_DIR})
endif()
if("$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
return()
endif()
if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(trt_resnext_test PROPERTIES TIMEOUT 300)
set_tests_properties(trt_quant_int8_yolov3_r50_test PROPERTIES TIMEOUT 300)
......
......@@ -164,9 +164,9 @@ REGISTER_OP_CPU_KERNEL(
ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -174,9 +174,9 @@ REGISTER_OP_CPU_KERNEL(
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
abs_grad_grad,
......@@ -187,6 +187,6 @@ REGISTER_OP_CPU_KERNEL(
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -52,8 +52,9 @@ class AbsKernel<platform::CUDADeviceContext, T>
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = CudaAbsFunctor<T>();
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(
dev_ctx, ins, &outs, functor);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T,
math::Real<T>>(dev_ctx, ins, &outs,
functor);
}
};
......@@ -69,8 +70,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::AbsKernel<plat::CUDADeviceContext, int>,
ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
......@@ -78,8 +79,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::AbsGradKernel<plat::CUDADeviceContext, int>,
ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
......@@ -87,5 +88,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
......@@ -1315,8 +1316,8 @@ class ActivationCudaKernel
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
&outs, functor);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, functor);
}
};
......@@ -1345,16 +1346,16 @@ class ActivationGradCudaKernel
if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
// Only need forward output Out
ins.push_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
} else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(kDepX)) {
// Only need forward input X
ins.push_back(x);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, functor);
}
}
......@@ -1437,9 +1438,9 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -1448,6 +1449,36 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
#else
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::float16>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::float16>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */
/* =========================== tanh register ============================ */
......
......@@ -27,6 +27,9 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
AddAttr<int>("in_dtype", "input data type");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Cast Operator.
......@@ -50,6 +53,7 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
grad->SetOutput("Out", this->InputGrad("X"));
grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
grad->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
}
};
......@@ -77,6 +81,28 @@ class CastOp : public framework::OperatorWithKernel {
if (platform::is_cuda_pinned_place(tensor_place)) {
return framework::OpKernelType(tensor->type(), ctx.device_context());
}
#ifdef PADDLE_WITH_MKLDNN
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");
auto MKLDNNSupportsCast = [&]() -> bool {
int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);
if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
(out_dtype != dtype_fp32 && out_dtype != dtype_bf16))
return false;
return true;
};
if (this->CanMKLDNNBeUsed(ctx, tensor->type()) && MKLDNNSupportsCast()) {
return framework::OpKernelType(tensor->type(), ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}
};
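The MKLDNNSupportsCast lambda admits the oneDNN kernel only when both
endpoints of the cast are FP32 or BF16 (and CanMKLDNNBeUsed also holds);
every other dtype pair falls through to the plain CPU kernel. For example:

    // in_dtype -> out_dtype    kernel chosen
    // FP32     -> BF16         MKLDNN
    // BF16     -> FP32         MKLDNN
    // FP32     -> FP32         MKLDNN
    // FP32     -> INT32        plain CPU (lambda returns false)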
......@@ -90,13 +116,11 @@ REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
cast, ops::CastOpKernel<CPU, float>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
......@@ -95,6 +95,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
......@@ -105,6 +106,23 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
#else
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
#endif
......@@ -27,10 +27,11 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto x = ctx.Output<framework::LoDTensor>("Out");
void* ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel();
HcclDataType dtype = platform::ToHCCLDataType(x->type());
auto out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(out->dims(), ctx.GetPlace());
void* ptr = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));
int numel = out->numel();
HcclDataType dtype = platform::ToHCCLDataType(out->type());
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
......@@ -54,8 +55,10 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
int root = peer;
VLOG(3) << "begin hccl recv, parameter is: "
<< "root " << root << ", comm: " << comm->comm()
<< ", stream: " << stream;
<< "ring_id:" << ring_id << ", nranks:" << nranks
<< ", peer:" << peer << ", numel:" << numel << ", ptr:" << ptr
<< ", dtype:" << dtype << ", root:" << root
<< ", comm: " << comm->comm() << ", stream: " << stream;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(
ptr, numel, dtype, (uint32_t)root, comm->comm(), stream));
......
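Two details worth noting in this receive kernel: the receive itself is
expressed as an HcclBroadcast rooted at the sending rank, and the output
tensor must be allocated before its raw pointer is taken, which is what the
added mutable_data call guarantees:

    auto out = ctx.Output<framework::LoDTensor>("Out");
    out->mutable_data<T>(out->dims(), ctx.GetPlace());  // allocate first
    void* ptr = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));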
type: "while"
def {
inputs {
name: "X"
}
inputs {
name: "Condition"
}
outputs {
name: "Out"
}
outputs {
name: "StepScopes"
}
attrs {
name: "sub_block"
type: BLOCK
}
}
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "skip_eager_deletion_vars"
type: STRINGS
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
......@@ -233,7 +233,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
......@@ -242,4 +243,5 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>);
......@@ -23,7 +23,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
REGISTER_OP_CUDA_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
......@@ -31,4 +32,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>);
......@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
REGISTER_OP_CPU_KERNEL(
conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ConjKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>,
paddle::platform::complex<double>>,
ops::ConjKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConjKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConjKernel<paddle::platform::CPUDeviceContext, int>,
......
......@@ -13,15 +13,14 @@
// limitations under the License.
#include "paddle/fluid/operators/conj_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
conj, ops::ConjKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ConjKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>,
paddle::platform::complex<double>>,
ops::ConjKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConjKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConjKernel<paddle::platform::CUDADeviceContext, int>,
......
......@@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel {
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor);
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor);
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor);
paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor);
paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor);
......
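Judging by this fix, the second functor passed to REGISTER_COMPARE_KERNEL is
the comparison applied with swapped operands, so it must be the converse of
the first rather than its logical negation: x < y swaps to y > x, not
y >= x. A one-point check with x == y == 1:

    // LessThanFunctor()(x, y)      -> false
    // GreaterEqualFunctor()(y, x)  -> true   (old pairing: disagrees)
    // GreaterThanFunctor()(y, x)   -> false  (new pairing: agrees)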
......@@ -15,15 +15,15 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor);
paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor);
paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
......
......@@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel {
"Output(Out) of DotOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto x_rank = (size_t)x_dims.size();
auto x_rank = static_cast<size_t>(x_dims.size());
PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
platform::errors::PreconditionNotMet(
"ShapeError: The dimensions of input tensor X (%s) "
......@@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL(
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL(
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<float>>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(dot_grad,
ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::DotGradKernel<plat::CUDADeviceContext,
paddle::platform::complex<double>>);
......@@ -20,8 +20,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct complex128;
struct complex64;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
......@@ -135,9 +135,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -145,9 +145,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -159,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
// A specialization elementwise_add operator, used in gradient accumulation with
// inplace addto.
......@@ -178,9 +178,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_add)
.AddCheckpoint(
......
......@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -39,15 +38,24 @@ struct CudaAddFunctor {
};
template <typename T>
struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {z};
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
CudaAddFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
}
};
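The axis == -1 default resolves to the rank difference, which aligns the
lower-rank operand with the trailing dimensions of the other. For instance:

    // x: [2, 3, 4, 5], y: [4, 5], axis == -1
    //   -> axis = |4 - 2| = 2, so y broadcasts over x's last two dims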
......@@ -132,8 +140,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
......@@ -141,8 +149,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
......@@ -151,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex64>,
plat::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
......@@ -161,5 +171,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
......@@ -20,11 +20,13 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef __NVCC__
#include <cuda.h>
#include <cuda_fp16.h>
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
......@@ -38,9 +40,10 @@ namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
......@@ -68,12 +71,13 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
same_dims_add(ctx, x, y, z);
if (x->dims() == y->dims()) {
SameDimsElemwiseAdd<platform::CPUDeviceContext, T>
LaunchElementwiseCpuKernel;
LaunchElementwiseCpuKernel(ctx, x, y, z);
} else {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
LaunchBroadcastElementwiseCpuKernel<platform::CPUDeviceContext, T>(ctx, x,
y, z);
}
}
};
......@@ -459,8 +463,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
ddout->mutable_data<T>(ctx.GetPlace());
default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
ddout);
LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, &ddx_safe,
&ddy_safe, ddout);
}
}
};
......
......@@ -141,6 +141,7 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
}
}
const T* dz_data = dz->data<T>();
T* dx_data = nullptr;
T* dy_data = nullptr;
if (dx) {
......@@ -152,9 +153,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dx_data, dx_data,
dx_data, dz->data<T>(), dy_data,
dx_data, x_dims_vec, y_dims_vec);
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dz_data, dz_data,
dz_data, dz_data, dy_data, dx_data,
x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
......
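For z = x + y, both input gradients are dz reduced along the broadcast
dimensions, which is why the corrected call reads only from dz_data and
writes dx_data and dy_data (the old call read dx_data before anything had
been written to it):

    // dL/dx = reduce_sum(dL/dz over the axes broadcast in x)  -> shape of x
    // dL/dy = reduce_sum(dL/dz over the axes broadcast in y)  -> shape of y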
......@@ -17,8 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
......@@ -135,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -145,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad_grad,
......@@ -160,9 +159,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_div)
.AddCheckpoint(
......
......@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -76,18 +75,21 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
const paddle::platform::complex64* x, const paddle::platform::complex64* y,
const paddle::platform::complex64* out,
const paddle::platform::complex64* dout, int64_t size,
paddle::platform::complex64* dx, paddle::platform::complex64* dy) {
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* x,
const paddle::platform::complex<float>* y,
const paddle::platform::complex<float>* out,
const paddle::platform::complex<float>* dout, int64_t size,
paddle::platform::complex<float>* dx,
paddle::platform::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex64 o = dout[col];
paddle::platform::complex64 y_conj(y[col].real, -y[col].imag);
paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
paddle::platform::complex<float> o = dout[col];
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
......@@ -95,19 +97,21 @@ __global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex64>(
}
template <>
__global__ void SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex128>(
const paddle::platform::complex128* x,
const paddle::platform::complex128* y,
const paddle::platform::complex128* out,
const paddle::platform::complex128* dout, int64_t size,
paddle::platform::complex128* dx, paddle::platform::complex128* dy) {
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* x,
const paddle::platform::complex<double>* y,
const paddle::platform::complex<double>* out,
const paddle::platform::complex<double>* dout, int64_t size,
paddle::platform::complex<double>* dx,
paddle::platform::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
paddle::platform::complex128 o = dout[col];
paddle::platform::complex128 y_conj(y[col].real, -y[col].imag);
paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
paddle::platform::complex<double> o = dout[col];
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
......@@ -145,9 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -157,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -173,6 +177,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -74,23 +74,13 @@ struct DivGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
};
template <>
struct DivGradDX<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 y_conj(y.real, -y.imag);
return dout / y_conj;
}
};
template <>
struct DivGradDX<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 y_conj(y.real, -y.imag);
template <typename T>
struct DivGradDX<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return dout / y_conj;
}
};
......@@ -102,23 +92,13 @@ struct DivGradDY {
}
};
template <>
struct DivGradDY<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag);
return -dout * out_div_y_conj;
}
};
template <>
struct DivGradDY<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 out_div_y_conj((out / y).real,
template <typename T>
struct DivGradDY<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> out_div_y_conj((out / y).real,
-(out / y).imag);
return -dout * out_div_y_conj;
}
......
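Both specializations implement the conjugate convention for complex
gradients of division, for z = x / y:

    // dL/dx =  dout / conj(y)
    // dL/dy = -dout * conj(out / y)   // out = x / y, so out / y = x / y^2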
......@@ -16,8 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
......@@ -134,9 +133,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -144,9 +143,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -158,9 +157,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
......
......@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -76,31 +75,31 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
}
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex64>(
const plat::complex64* x, const plat::complex64* y,
const plat::complex64* out, const plat::complex64* dout, int64_t size,
plat::complex64* dx, plat::complex64* dy) {
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
const plat::complex<float>* x, const plat::complex<float>* y,
const plat::complex<float>* out, const plat::complex<float>* dout,
int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
plat::complex64 o = dout[col];
dx[col] = plat::complex64(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex64(x[col].real, -x[col].imag) * o;
plat::complex<float> o = dout[col];
dx[col] = plat::complex<float>(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex<float>(x[col].real, -x[col].imag) * o;
col += blockDim.x * gridDim.x;
}
}
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex128>(
const plat::complex128* x, const plat::complex128* y,
const plat::complex128* out, const plat::complex128* dout, int64_t size,
plat::complex128* dx, plat::complex128* dy) {
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
const plat::complex<double>* x, const plat::complex<double>* y,
const plat::complex<double>* out, const plat::complex<double>* dout,
int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
plat::complex128 o = dout[col];
dx[col] = plat::complex128(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex128(x[col].real, -x[col].imag) * o;
plat::complex<double> o = dout[col];
dx[col] = plat::complex<double>(y[col].real, -y[col].imag) * o;
dy[col] = plat::complex<double>(x[col].real, -x[col].imag) * o;
col += blockDim.x * gridDim.x;
}
}
......@@ -133,8 +132,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
......@@ -142,8 +141,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex128>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
......@@ -152,6 +153,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex64>,
plat::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
plat::complex<double>>);
......@@ -132,23 +132,13 @@ struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
};
template <>
struct MulGradDX<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
template <>
struct MulGradDX<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 y_conj(y.real, -y.imag);
template <typename T>
struct MulGradDX<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return dout * y_conj;
}
};
......@@ -158,23 +148,13 @@ struct MulGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
};
template <>
struct MulGradDY<paddle::platform::complex64> {
HOSTDEVICE paddle::platform::complex64 operator()(
paddle::platform::complex64 x, paddle::platform::complex64 y,
paddle::platform::complex64 out, paddle::platform::complex64 dout) const {
paddle::platform::complex64 x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
template <>
struct MulGradDY<paddle::platform::complex128> {
HOSTDEVICE paddle::platform::complex128 operator()(
paddle::platform::complex128 x, paddle::platform::complex128 y,
paddle::platform::complex128 out,
paddle::platform::complex128 dout) const {
paddle::platform::complex128 x_conj(x.real, -x.imag);
template <typename T>
struct MulGradDY<paddle::platform::complex<T>> {
HOSTDEVICE paddle::platform::complex<T> operator()(
paddle::platform::complex<T> x, paddle::platform::complex<T> y,
paddle::platform::complex<T> out,
paddle::platform::complex<T> dout) const {
paddle::platform::complex<T> x_conj(x.real, -x.imag);
return dout * x_conj;
}
};
......
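The same conjugate convention as in the division gradients, here for
z = x * y:

    // dL/dx = dout * conj(y)
    // dL/dy = dout * conj(x)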
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
namespace paddle {
namespace operators {
struct DimensionsTransform {
using DimVector = std::vector<int64_t>;
typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
int, int);
int64_t dim_size;
DimVector out_dims;
std::vector<DimVector> in_dims;
private:
// To pad each input whose rank is below the output rank, using the input
// variable 'axis' to locate its dimensions within the output.
void InputDimensionsExtend(int N, int axis) {
for (auto &in_dim : in_dims) {
int64_t in_idx = 0;
if (in_dim.size() < dim_size) {
DimVector tmp_dim(dim_size, 1);
do {
if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
tmp_dim[axis] = in_dim[in_idx];
in_idx++;
axis++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth dimension of input tensor is expected to be equal "
"with"
"the %dth dimension of output tensor %d or 1, but recieved "
"%d.\n",
in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
}
} while (in_idx < in_dim.size());
in_dim.resize(dim_size);
std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
} else {
do {
if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
in_idx++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth dimension of input tensor is expected to be equal "
"with"
"the %dth dimension of output tensor %d or 1, but recieved "
"%d.\n",
in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
}
} while (in_idx < dim_size);
}
std::reverse(in_dim.begin(), in_dim.end());
}
std::reverse(out_dims.begin(), out_dims.end());
}
template <typename MergeFunctor>
__inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
(*vec)[m_idx - 1] =
std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
std::multiplies<int64_t>());
vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
};
int64_t i = 0;
while (i < dim_size) {
int cnt = 0;
int low_idx = i;
bool equal = true;
do {
merge_func(equal, in_dims, out_dims, i, N);
if (equal) {
i++;
cnt++;
} else {
break;
}
} while (i < dim_size);
if (cnt > 1) {
for (auto &in_dim : in_dims) {
VectorReorganise(&in_dim, low_idx, i);
}
VectorReorganise(&out_dims, low_idx, i);
dim_size -= --cnt;
i -= cnt;
} else if (cnt < 1) {
i++;
}
}
}
public:
explicit DimensionsTransform(
const std::vector<const framework::Tensor *> &ins,
const framework::DDim &dims, int axis) {
const int N = ins.size();
dim_size = dims.size();
out_dims = framework::vectorize<int64_t>(dims);
in_dims.resize(N);
for (int j = 0; j < N; ++j) {
in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
}
InputDimensionsExtend(N, axis);
auto merge_sequential_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) {
for (int j = 1; j < num; ++j) {
equal &= (in_dims[0][i] == in_dims[j][i]);
}
};
auto merge_sequential_one_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) {
equal = in_dims[0][i] == 1;
if (equal) {
for (int j = 1; j < num; ++j) {
equal &= (in_dims[j][i] == out[i]);
}
}
};
// Merge consecutive dimensions that are equal across all input tensors.
MergeFunctor merge_ptr = merge_sequential_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
int min_idx = 0;
int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
std::multiplies<int64_t>());
for (int j = 1; j < N; ++j) {
int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
std::multiplies<int64_t>());
min_val = min_val > temp ? temp : min_val;
min_idx = min_val == temp ? j : min_idx;
}
std::swap(in_dims[0], in_dims[min_idx]);
// Merge consecutive dimensions in which the smallest input has size 1 while
// the other inputs match the output.
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[min_idx], in_dims[0]);
}
};
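A worked example of the two merge passes, with shapes chosen purely for
illustration: inputs x = [2, 3, 4, 5] and y = [2, 3, 1, 1] broadcasting to
[2, 3, 4, 5]. The dims are first reversed, the trailing equal run is merged,
and then the run in which the smaller input holds 1s is merged:

    // reversed:       x = [5, 4, 3, 2], y = [1, 1, 3, 2], out = [5, 4, 3, 2]
    // equal-dim pass: x = [5, 4, 6],    y = [1, 1, 6],    out = [5, 4, 6]
    // one-dim pass:   x = [20, 6],      y = [1, 6],       out = [20, 6]

so the 4-D broadcast collapses to a 2-D problem.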
struct StridesCalculation {
std::vector<std::vector<uint32_t>> strides;
std::vector<FastDivMod> divmoders;
private:
// To calculate the strides of each input_tensor.
__inline__ void CalculateStrides(
int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
for (int j = 0; j < N; ++j) {
for (int i = 0; i < dim_size; ++i) {
strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
strides[j][i] =
(i != 0 && strides[j][i] != 0)
? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
std::multiplies<int64_t>())
: strides[j][i];
}
}
}
public:
explicit StridesCalculation(const int64_t &dim_size,
const std::vector<std::vector<int64_t>> &in_dims,
const std::vector<int64_t> &out_dims) {
const auto N = in_dims.size();
divmoders.resize(dim_size);
strides.resize(N, std::vector<uint32_t>(dim_size, 1));
for (int i = 0; i < dim_size; ++i) {
divmoders[i] = FastDivMod(out_dims[i]);
}
CalculateStrides(N, dim_size, in_dims);
}
};
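Continuing the example above: StridesCalculation zeroes the stride of every
size-1 (broadcast) dimension and otherwise uses the running product of the
input's own dims, so a broadcast coordinate contributes nothing to the
offset:

    // x = [20, 6] -> strides [1, 20]
    // y = [1, 6]  -> strides [0, 1]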
template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWarpper {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
uint32_t scalar_cal_offset;
Functor func;
HOSTDEVICE BroadcastArgsWarpper(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int scalar_cal_offset, Functor func,
const StridesCalculation &offset_calculator)
: scalar_cal_offset(scalar_cal_offset), func(func) {
for (int j = 0; j < ET; ++j) {
in_data[j] = ins[j]->data<InT>();
vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
memcpy(strides[j], offset_calculator.strides[j].data(),
kDims * sizeof(uint32_t));
}
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(FastDivMod));
}
__device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
uint32_t offset = 0;
#pragma unroll(kDims)
for (int i = 0; i < kDims; ++i) {
auto fast_divmoder = divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
offset += fast_divmoder.val[1] * strides[in_idx][i];
}
return offset;
}
__device__ __forceinline__ void LoadVectorizedDataCommon(
InVecType *vector_args, int tid, int idx) {
*vector_args = vec_in_data[idx][tid];
}
__device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
int tid, int idx) {
int index = tid * VecSize;
#pragma unroll(VecSize)
for (int i = 0; i < VecSize; ++i) {
uint32_t offset = GetOffsetByDivmod(index + i, idx);
scalar_args[i] = in_data[idx][offset];
}
}
__device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
int idx) {
args[idx] = in_data[idx][tid + scalar_cal_offset];
}
__device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
int tid, int idx) {
auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
args[idx] = in_data[idx][offset];
}
__device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
LoadVectorizedDataCommon(vector_args, tid, j);
} else {
LoadVectorizedDataByDivmod(args[j], tid, j);
}
}
}
__device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
if (no_broadcast[j]) {
LoadScalarizedDataCommon(args, tid, j);
} else {
LoadScalarizedDataByDivmod(args, tid, j);
}
}
}
__device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
int tid) {
vec_out_data[tid] = vec_args_out;
}
__device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
out_data[scalar_cal_offset + tid] = args_out;
}
};
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
InT args[ET];
OutT args_out;
broadcast_warpper.LoadScalarizedData(args, tid);
#pragma unroll(ET)
for (int j = 1; j < ET; ++j) {
args_out = broadcast_warpper.func(args);
}
broadcast_warpper.StoreScalarizedData(args_out, tid);
}
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
using OutVecType = CudaAlignedVector<OutT, VecSize>;
OutVecType args_out;
InT ins[ET];
InT args[ET][VecSize];
broadcast_warpper.LoadVectorizedData(args, tid);
#pragma unroll(VecSize)
for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
for (int j = 0; j < ET; ++j) {
ins[j] = args[j][i];
}
args_out.val[i] = broadcast_warpper.func(ins);
}
broadcast_warpper.StoreVectorizedData(args_out, tid);
}
template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
// Vectorized computation of the main body of the data, whose length is the
// largest multiple of VecSize, e.g. the first 1024 of 1027 elements when
// VecSize is 4.
if (tid < main_tid) {
VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
broadcast_warpper, tid);
}
// Scalarized computation of the tail whose length is less than VecSize,
// e.g. the remaining 3 of 1027 elements when VecSize is 4.
if (tid < tail_tid) {
ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
broadcast_warpper, tid);
}
}
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
int numel = out->numel();
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / VecSize;
int tail_tid = numel % VecSize;
int vec_len = main_tid * VecSize;
auto stream = ctx.stream();
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
const auto offset_calculator = StridesCalculation(
merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
switch (merge_dims.dim_size) {
case 1: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 2: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 3: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 4: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 5: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 6: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 7: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
case 8: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
break;
}
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size, framework::DDim::kMaxRank));
}
}
}
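The eight cases differ only in the compile-time rank parameter. They could be
folded with a macro along these lines (a sketch; the macro name is made up
here):

    #define LAUNCH_BROADCAST_KERNEL_FOR_RANK(rank)                       \
      case rank: {                                                       \
        auto warpper = BroadcastArgsWarpper<InT, OutT, Functor, ET,      \
                                            VecSize, rank>(              \
            ins, out, vec_len, func, offset_calculator);                 \
        ElementwiseBroadcastKernel<InT, OutT, decltype(warpper), ET,     \
                                   VecSize><<<blocks, threads, 0,        \
                                              stream>>>(                 \
            warpper, main_tid, tail_tid);                                \
        break;                                                           \
      }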
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);
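  // Only inputs that already match the output shape constrain the vector
  // width (broadcast inputs are gathered element by element), so the launch
  // uses the most conservative width among those tensors and the output.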
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
axis, func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}
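// For reference, the per-tensor widths consumed above come from pointer
// alignment. A minimal sketch of such a check, assuming widths of 4, 2, and
// 1 (illustrative only, not the GetVectorizedSizeImpl defined elsewhere):
template <typename T>
int GetVectorizedSizeSketch(const T *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  if (address % (sizeof(T) * 4) == 0) return 4;  // 4-wide aligned loads ok
  if (address % (sizeof(T) * 2) == 0) return 2;  // fall back to 2-wide
  return 1;                                      // scalar access only
}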
template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
void LaunchElementwiseCudaKernel(
const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    // Broadcast is needed unless every input matches the first input's shape.
    no_broadcast_flag = no_broadcast_flag && (ins[0]->dims() == in->dims());
  }
  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutType>(cuda_ctx, ins, outs,
                                                          func);
  } else {
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutType>(cuda_ctx, ins, outs,
                                                           axis, func);
  }
}
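// Typical call site (hypothetical names, shown for illustration): an op's
// CUDA kernel collects its tensors and dispatches through the launcher
// above, e.g.
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, axis, SubFunctor<T>());
//
// where SubFunctor<T> is assumed to be a binary functor exposing
// "T operator()(const T &a, const T &b) const".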
} // namespace operators
} // namespace paddle
......@@ -15,8 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/fast_divmod.h"
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
......@@ -29,11 +28,6 @@ namespace operators {
enum ElementwiseType { kUnary = 1, kBinary = 2 };
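// The enum value doubles as the operand count, which is why launchers can
// compare ET against kBinary when they only support two inputs.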
template <typename T, int Size>
struct alignas(sizeof(T) * Size) CudaAlignedVector {
T val[Size];
};
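// e.g. CudaAlignedVector<float, 4> occupies 16 bytes with 16-byte alignment,
// letting the compiler move it with a single 128-bit load/store.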
template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
......@@ -181,7 +175,7 @@ __global__ void ScalarKernel(const InT *__restrict__ in0,
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
void LaunchSameDimsElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
......@@ -197,6 +191,7 @@ void LaunchElementwiseCudaKernel(
OutT *out = (*outs)[0]->data<OutT>();
// cuda kernel
auto stream = ctx.stream();
switch (vec_size) {
case 4:
VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
......
......@@ -20,8 +20,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct complex128;
struct complex64;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
......@@ -134,9 +134,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -144,9 +144,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -158,9 +158,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_sub)
.AddCheckpoint(
......
......@@ -14,8 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -103,9 +102,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -115,9 +114,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -129,6 +128,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -173,7 +173,9 @@ void FusedBatchNormActOpMaker::Make() {
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' should be between 0.0 and 0.001."));
"Attr(epsilon) should be between 0.0 and 0.001, "
"but received value is %f.",
epsilon));
});
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
......
This diff is collapsed.