Commit ec0e8391 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_embedding_to_phi

@@ -25,7 +25,7 @@ repos:
         description: Format files with ClangFormat.
         entry: bash ./tools/codestyle/clang_format.hook -i
         language: system
-        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
+        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps)$
 -   repo: local
     hooks:
     -   id: cpplint-cpp-source
@@ -48,7 +48,7 @@ repos:
         name: copyright_checker
         entry: python ./tools/codestyle/copyright.hook
         language: system
-        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
+        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
         exclude: |
             (?x)^(
                 paddle/utils/.*
...
@@ -36,7 +36,7 @@ ENDIF()
 if(NOT DEFINED XPU_BASE_URL)
   SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220215")
+  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220219")
 else()
   SET(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
...
@@ -125,6 +125,9 @@ function(op_library TARGET)
       if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.xpu)
         list(APPEND xpu_kp_cc_srcs ${TARGET}.xpu)
       endif()
+      if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+        list(APPEND xpu_kp_cc_srcs ${TARGET}.kps)
+      endif()
     endif()
     if(WITH_ASCEND_CL)
       string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
@@ -162,6 +165,8 @@ function(op_library TARGET)
         list(APPEND xpu_cc_srcs ${src})
       elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.xpu$")
         list(APPEND xpu_kp_cc_srcs ${src})
+      elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.kps$")
+        list(APPEND xpu_kp_cc_srcs ${src})
       elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
         list(APPEND npu_cc_srcs ${src})
       elseif(WITH_MLU AND ${src} MATCHES ".*_op_mlu.cc$")
@@ -384,7 +389,15 @@ function(op_library TARGET)
   # pybind USE_OP_DEVICE_KERNEL for XPU KP
   if (WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
-    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, KP);\n")
+    foreach(xpu_kp_src ${xpu_kp_cc_srcs})
+      set(op_name "")
+      find_register(${xpu_kp_src} "REGISTER_OP_KERNEL" op_name)
+      if(NOT ${op_name} EQUAL "")
+        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, KP);\n")
+        message(STATUS "Building KP Target: ${op_name}")
+        set(pybind_flag 1)
+      endif()
+    endforeach()
   endif()
   # pybind USE_OP_DEVICE_KERNEL for NPU
...
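For reference, a hedged sketch of the registration line the new find_register scan looks for inside a .kps source; the op and kernel names here are hypothetical placeholders, not taken from this commit:

// Hypothetical .kps source. find_register() extracts the first macro
// argument ("my_op"), so op_library now appends
//   USE_OP_DEVICE_KERNEL(my_op, KP);
// to the pybind file instead of one line per ${TARGET}.
REGISTER_OP_KERNEL(my_op, KP, plat::XPUPlace, ops::MyOpXPUKPKernel<float>);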
@@ -58,26 +58,32 @@ endfunction()
 function(kernel_declare TARGET_LIST)
   foreach(kernel_path ${TARGET_LIST})
     file(READ ${kernel_path} kernel_impl)
-    # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
-    # NOTE(chenweihang): now we don't recommend to use digit in kernel name
-    string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
+    string(REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[ \t\r\n\/]*[a-z0-9_]*" first_registry "${kernel_impl}")
     if (NOT first_registry STREQUAL "")
+      # some gpu kernels can only run on cuda, not on rocm, so we add this branch
+      if (WITH_ROCM)
+        string(FIND "${first_registry}" "cuda_only" pos)
+        if(pos GREATER 1)
+          continue()
+        endif()
+      endif()
       # parse the first kernel name
-      string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
-      string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
+      string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
+      string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
       string(REPLACE "," "" kernel_name "${kernel_name}")
       string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
+      string(REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name}")
       # append kernel declare into declarations.h
       # TODO(chenweihang): default declare ALL_LAYOUT for each kernel
       if (${kernel_path} MATCHES "./cpu\/")
-        file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
+        file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
       elseif (${kernel_path} MATCHES "./gpu\/")
-        file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
+        file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
       elseif (${kernel_path} MATCHES "./xpu\/")
-        file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
+        file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
       else ()
         # deal with device-independent kernels; use CPU temporarily for now
-        file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
+        file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
       endif()
     endif()
   endforeach()
@@ -285,9 +291,9 @@ endfunction()
 function(append_op_util_declare TARGET)
   file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
-  string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
-  string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
-  string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
+  string(REGEX MATCH "(PD_REGISTER_BASE_KERNEL_NAME|PD_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
+  string(REPLACE "PD_REGISTER_ARG_MAPPING_FN" "PD_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
+  string(REPLACE "PD_REGISTER_BASE_KERNEL_NAME" "PD_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
   string(APPEND util_declare ");\n")
   file(APPEND ${op_utils_header} "${util_declare}")
 endfunction()
...
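For concreteness, a hedged sketch of the registration line kernel_declare parses and the declaration it then emits; the kernel name, types, and path are illustrative assumptions rather than content of this commit:

// Illustrative phi kernel registration, e.g. in a cpu/ source file:
PD_REGISTER_KERNEL(scale, CPU, ALL_LAYOUT, phi::ScaleKernel, float, double) {}
// kernel_declare() extracts "scale" from the first registry match and
// appends the following line to declarations.h:
//   PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);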
@@ -17,7 +17,7 @@ if(NOT WITH_XPU_KP)
 endif()
 if(NOT XPU_TOOLCHAIN)
-  set(XPU_TOOLCHAIN /workspace/paddle/xpu-demo/XTDK)
+  set(XPU_TOOLCHAIN /workspace/output/XTDK-ubuntu_x86_64)
   get_filename_component(XPU_TOOLCHAIN ${XPU_TOOLCHAIN} REALPATH)
 endif()
 if(NOT IS_DIRECTORY ${XPU_TOOLCHAIN})
@@ -102,7 +102,7 @@ macro(compile_kernel COMPILE_ARGS)
   set(XTDK_DIR ${XPU_TOOLCHAIN})
   set(CXX_DIR ${HOST_SYSROOT})
-  set(XPU_CXX_FLAGS -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG )
+  set(XPU_CXX_FLAGS -fforce-enable-int128 -Wno-error=pessimizing-move -Wno-error=constant-conversion -Wno-error=c++11-narrowing -Wno-error=shift-count-overflow -Wno-error=unused-local-typedef -Wno-error=deprecated-declarations -Wno-deprecated-declarations -std=c++14 -m64 -fPIC -fno-omit-frame-pointer -Wall -Wno-inconsistent-missing-override -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=unused-local-typedefs -Wno-error=ignored-attributes -Wno-error=int-in-bool-context -Wno-error=parentheses -Wno-error=address -Wno-ignored-qualifiers -Wno-ignored-attributes -Wno-parentheses -DNDEBUG )
   #include path
   get_property(dirs DIRECTORY ${CMAKE_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
@@ -127,9 +127,11 @@ macro(compile_kernel COMPILE_ARGS)
     kernel_build/${kernel_name}.bin.o
     COMMAND
     ${CMAKE_COMMAND} -E make_directory kernel_build
+    COMMAND
+    cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
     COMMAND
     ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
-    -I. -o kernel_build/${kernel_name}.bin.o.sec ${kernel_path}/${kernel_name}.xpu
+    -I. -o kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.xpu
     --xpu-device-only -c -v
     COMMAND
     ${XTDK_DIR}/bin/xpu2-elfconv kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.bin.o ${XPU_CLANG} --sysroot=${CXX_DIR}
@@ -148,9 +150,11 @@ macro(compile_kernel COMPILE_ARGS)
     kernel_build/${kernel_name}.host.o
     COMMAND
     ${CMAKE_COMMAND} -E make_directory kernel_build
+    COMMAND
+    cp ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
     COMMAND
     ${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
-    -I. -o kernel_build/${kernel_name}.host.o ${kernel_path}/${kernel_name}.xpu
+    -I. -o kernel_build/${kernel_name}.host.o kernel_build/${kernel_name}.xpu
     --xpu-host-only -c -v
     WORKING_DIRECTORY
     ${CMAKE_CURRENT_BINARY_DIR}
@@ -185,7 +189,7 @@ macro(xpu_add_library TARGET_NAME)
   # Distinguish .xpu file from other files
   foreach(cur_xpu_src IN LISTS xpu_srcs_lists)
     get_filename_component(language_type_name ${cur_xpu_src} EXT)
-    if(${language_type_name} STREQUAL ".xpu")
+    if(${language_type_name} STREQUAL ".kps")
       list(APPEND xpu_kernel_lists ${cur_xpu_src})
     else()
       list(APPEND cc_kernel_lists ${cur_xpu_src})
...
+add_subdirectory(collective)
+add_subdirectory(store)
 if(NOT WITH_PSCORE)
   add_subdirectory(fleet_executor)
   return()
...
cc_library(processgroup SRCS ProcessGroup.cc DEPS pten pten_api eager_api)
if(WITH_NCCL)
cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context pten pten_api eager_api)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include <error.h>
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// NOTE(shenliang03): EventManager is a movable, non-copyable wrapper around a
// CUDA event. It differs from paddle::platform::CudaEvent: it is lazily
// initialized, creating the event only when Record() is called for the
// first time, and it tracks device information to ensure that the recorded
// stream and event are on the same device.
class EventManager {
public:
EventManager() {}
explicit EventManager(unsigned int flags) : flags_{flags} {}
~EventManager() {
if (is_created_) {
platform::CUDADeviceGuard guard(device_index_);
cudaEventDestroy(event_);
}
}
EventManager(const EventManager&) = delete;
EventManager& operator=(const EventManager&) = delete;
EventManager(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
EventManager& operator=(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
  int8_t DeviceId() const { return device_index_; }
gpuEvent_t GetRawCudaEvent() const { return event_; }
void Record(const paddle::platform::CUDADeviceContext& ctx) {
auto device_index = ctx.GetPlace().device;
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index, device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
platform::CUDADeviceGuard guard(device_index_);
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
}
bool Query() const {
gpuError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
} else if (err == cudaErrorNotReady) {
return false;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(err);
return false;
}
}
void Synchronize() const {
if (is_created_) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
}
}
void Block(const paddle::platform::CUDADeviceContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
platform::CUDADeviceGuard guard(device_index_);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
}
}
private:
unsigned int flags_ = cudaEventDefault;
bool is_created_{false};
gpuEvent_t event_{};
int8_t device_index_{0};
private:
void CreateEvent(int device_index) {
device_index_ = device_index;
platform::CUDADeviceGuard guard(device_index);
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
is_created_ = true;
}
};
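// Editor's usage sketch (not part of this commit): record on one context's
// stream and make another context's stream wait for that work. Both contexts
// are assumed to live on the same device; the names are placeholders.
inline void EventManagerExample(
    const paddle::platform::CUDADeviceContext& comm_ctx,
    const paddle::platform::CUDADeviceContext& compute_ctx) {
  EventManager ev;        // lazily initialized; no CUDA event created yet
  ev.Record(comm_ctx);    // first Record() creates the event on comm_ctx's device
  ev.Block(compute_ctx);  // compute stream waits until the recorded work is done
}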
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class NCCLCommManager {
public:
explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}
NCCLCommManager() : NCCLCommManager(nullptr) {}
~NCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (nccl_comm_) {
platform::dynload::ncclCommDestroy(nccl_comm_);
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks, int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(&(nccl_manager->nccl_comm_),
num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
return nccl_manager;
}
ncclUniqueId GetNcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_id_;
}
ncclComm_t GetNcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_comm_;
}
NCCLCommManager(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(NCCLCommManager&& other) = delete;
NCCLCommManager(NCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(nccl_comm_, other.nccl_comm_);
}
protected:
ncclComm_t nccl_comm_;
ncclUniqueId nccl_id_;
int rank_;
mutable std::mutex mutex_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
namespace paddle {
namespace distributed {
ProcessGroup::Task::Task(int rank, const std::vector<Tensor>& inputTensors,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::~Task() = default;
bool ProcessGroup::Task::IsCompleted() {
std::lock_guard<std::mutex> lock(mutex_);
return is_completed_;
}
bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
return false;
}
void ProcessGroup::Task::Synchronize() {}
ProcessGroup::ProcessGroup(int rank, int size) : rank_(rank), size_(size) {}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"
constexpr auto kWaitTimeout = std::chrono::milliseconds(0);
namespace paddle {
namespace distributed {
using Tensor = paddle::experimental::Tensor;
enum class CommType : std::uint8_t {
BROADCAST = 0,
ALLREDUCE = 1,
ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce
REDUCE = 3,
ALLGATHER = 4,
GATHER = 5,
SCATTER = 6,
REDUCE_SCATTER = 7,
ALLTOALL = 8,
SEND = 9,
RECV = 10,
BARRIER = 11,
UNKNOWN = 100,
};
struct ProcessGroupStrategy {
int nranks_{1};
int local_rank_{0};
std::vector<std::string> trainer_endpoints_{};
std::string current_endpoint_{""};
int nrings_{1};
};
class ProcessGroup {
public:
class Task {
public:
Task(int rank, const std::vector<Tensor>& inputTensors,
CommType opType = CommType::UNKNOWN);
virtual ~Task();
virtual bool IsCompleted();
virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
virtual void Synchronize();
protected:
const int rank_;
CommType comm_type_;
std::mutex mutex_;
bool is_completed_ = false;
};
explicit ProcessGroup(int rank, int size);
virtual ~ProcessGroup() {}
int GetRank() const { return rank_; }
int GetSize() const { return size_; }
virtual const std::string GetBackendName() const = 0;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& /* tensors */,
const AllreduceOptions& = AllreduceOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& /* tensors */,
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
}
protected:
const int rank_;
const int size_;
};
} // namespace distributed
} // namespace paddle
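The base class above only fixes the dispatch surface; a hedged sketch of the override pattern a concrete backend follows (the class name here is hypothetical, not from this commit):

// Editor's sketch of the extension contract.
class FakeProcessGroup : public paddle::distributed::ProcessGroup {
 public:
  FakeProcessGroup(int rank, int size) : ProcessGroup(rank, size) {}
  const std::string GetBackendName() const override { return "FAKE"; }
  // A backend overrides only the collectives it supports; anything left
  // unimplemented falls through to the base-class PADDLE_THROW.
};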
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
constexpr int64_t kWaitBlockTImeout = 10;
namespace paddle {
namespace distributed {
static ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
static const std::map<ReduceOp, ncclRedOp_t> red_type = {
{ReduceOp::MIN, ncclMin},
{ReduceOp::MAX, ncclMax},
{ReduceOp::SUM, ncclSum},
{ReduceOp::PRODUCT, ncclProd},
};
auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(it != red_type.end(), true,
platform::errors::InvalidArgument(
"Invalid nccl reduction. Must be ncclMin | ncclMax | "
"ncclProd | ncclSum"));
return it->second;
}
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
std::ostringstream oss;
for (auto i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
// Get the list of devices from list of tensors
std::vector<Place> GetPlaceList(const std::vector<Tensor>& tensors) {
std::vector<Place> places;
places.reserve(tensors.size());
for (auto& tensor : tensors) {
places.push_back(tensor.inner_place());
}
return places;
}
// Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places) {
std::string placeList;
for (auto& place : places) {
std::stringstream tmp;
tmp << place;
if (placeList.empty()) {
placeList += tmp.str();
} else {
placeList += "," + tmp.str();
}
}
return placeList;
}
bool CheckTensorsInCudaPlace(const std::vector<Tensor>& tensors) {
return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
return t.place() == PlaceType::kGPU;
});
}
void SyncDefaultStream(
const std::vector<Place>& places,
std::vector<EventManager>& ncclEvents, // NOLINT
std::vector<std::unique_ptr<CUDADeviceContext>>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
ncclEvents[i].Record(*dev_ctx[i]);
ncclEvents[i].Block(*default_ctx);
}
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
const std::vector<Tensor>& inputs) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(places, rank, comm_type,
inputs);
}
ProcessGroupNCCL::NCCLTask::NCCLTask(const std::vector<Place>& places, int rank,
CommType CommType,
const std::vector<Tensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
ncclComms_.resize(places.size());
}
ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
void ProcessGroupNCCL::NCCLTask::SetOutputs(
std::vector<Tensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<Tensor>>(outputs);
}
void ProcessGroupNCCL::NCCLTask::SynchronizeStreams() {
for (size_t i = 0; i < places_.size(); ++i) {
auto* default_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(places_[i]));
default_ctx->WaitEvent(control_events_[i].GetRawCudaEvent());
}
}
bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
for (size_t i = 0; i < places_.size(); ++i) {
if (!control_events_[i].Query()) {
return false;
}
}
return true;
}
// TODO(shenliang03): Add timeout for wait; the timeout parameter is currently unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
SynchronizeStreams();
if (FLAGS_nccl_blocking_wait) {
// NOTE(shenliang03): It will block host for sync
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
return true;
}
// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const ProcessGroupStrategy& strategy,
int rank, int size)
: ProcessGroup(rank, size), strategy_(strategy) {}
void ProcessGroupNCCL::BcastNCCLId(
std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto& ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
}
platform::SendBroadCastCommID(other_trainers, &nccl_ids);
} else {
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&nccl_ids);
}
}
void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
int server_fd = -1;
if (rank_ != 0) {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastNCCLId(nccl_ids, 0, server_fd);
}
// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false,
platform::errors::PreconditionNotMet(
"Not able to create/get the NCCL Communicator since "
"the GPU place are not known"));
std::vector<std::shared_ptr<NCCLCommManager>> nccl_comms;
nccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);
auto& nccl_id = nccl_ids.front();
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(nccl_ids);
VLOG(3) << "init nccl rank: " << strategy_.local_rank_
<< ", nranks: " << strategy_.nranks_ << ", place: " << places_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
std::vector<std::unique_ptr<CUDADeviceContext>> dev_ctx;
dev_ctx.resize(places.size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (size_t i = 0; i < places.size(); ++i) {
platform::CUDADeviceGuard guard(places[i]);
nccl_comms[i] = NCCLCommManager::Create(GetSize(), GetRank(), nccl_id);
dev_ctx[i].reset(new CUDADeviceContext(places[i]));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
std::vector<EventManager> events;
events.resize(places.size());
// These caches are used later by the sync/wait/communicate paths
places_to_events_.emplace(places_key, std::move(events));
places_to_ncclcomm_.emplace(places_key, std::move(nccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);
// construct an uninitialized guard for the device
platform::CUDADeviceGuard cuda_guard;
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl());
memory::RecordStream(dense_tensor->Holder(),
places_to_ctx_[key][i]->stream());
}
}
{
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclAllReduce(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<Tensor>& tensors, const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
const auto root = opts.source_rank * tensors.size() + opts.source_root;
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclBcast(
input_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), root, comm, stream);
},
CommType::BROADCAST);
}
} // namespace distributed
} // namespace paddle
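Putting the pieces together, a hedged sketch of how a launcher-driven process might use the group above for a two-rank all-reduce; the endpoints, rank, and input tensor are placeholder assumptions, not code from this commit:

// Editor's usage sketch.
void ExampleAllReduce(int rank, paddle::experimental::Tensor gpu_tensor) {
  using namespace paddle::distributed;
  ProcessGroupStrategy strategy;
  strategy.nranks_ = 2;
  strategy.local_rank_ = rank;  // 0 or 1, assumed supplied by the launcher
  strategy.trainer_endpoints_ = {"127.0.0.1:6170", "127.0.0.1:6171"};
  strategy.current_endpoint_ = strategy.trainer_endpoints_[rank];
  ProcessGroupNCCL pg(strategy, rank, /*size=*/2);
  std::vector<Tensor> tensors = {gpu_tensor};  // must reside on a GPU place
  auto task = pg.AllReduce(tensors);           // ReduceOp::SUM by default
  task->Wait();  // block until the collective finishes
}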
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
constexpr const char* NCCL_BACKEND_NAME = "NCCL";
namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using CUDAStream = platform::stream::CUDAStream;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
class ProcessGroupNCCL : public ProcessGroup {
public:
class NCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<NCCLTask> {
public:
NCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Tensor>& inputs);
bool IsCompleted();
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<Tensor>& outputs); // NOLINT
virtual ~NCCLTask();
std::vector<EventManager> control_events_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<Tensor>> outputs_;
private:
};
ProcessGroupNCCL(const ProcessGroupStrategy& strategy, int rank, int size);
const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& tensors,
const BroadcastOptions& = BroadcastOptions()) override;
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
const std::vector<Tensor>& inputs);
protected:
ProcessGroupStrategy strategy_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
places_to_ncclcomm_;
std::unordered_map<std::string, std::vector<EventManager>> places_to_events_;
std::unordered_map<std::string,
std::vector<std::unique_ptr<CUDADeviceContext>>>
places_to_ctx_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
Fn fn, CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {
// TODO(shenliang03): To support AVG for reduce
enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT };
struct AllreduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
};
struct BroadcastOptions {
int source_rank = 0;
int source_root = 0;
};
} // namespace distributed
} // namespace paddle
@@ -52,6 +52,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
     input_tensor_ptr = input_tensor->mutable_data<float>(dims, place);
   } else if (input_data.dtype == DistModelDataType::INT32) {
     input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
+  } else if (input_data.dtype == DistModelDataType::FLOAT16) {
+    input_tensor_ptr = input_tensor->mutable_data<float16>(dims, place);
   } else {
     LOG(ERROR) << "unsupported feed type " << input_data.dtype;
     return false;
@@ -412,6 +414,8 @@ bool DistModel::PrepareFeedAndFetch() {
       feeds_to_dtype_.insert({var_name, DistModelDataType::INT32});
     } else if (real_var->GetDataType() == framework::proto::VarType::INT64) {
       feeds_to_dtype_.insert({var_name, DistModelDataType::INT64});
+    } else if (real_var->GetDataType() == framework::proto::VarType::FP16) {
+      feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT16});
     } else {
       LOG(ERROR) << "Don't support feed var dtype for: "
                  << real_var->GetDataType();
@@ -503,9 +507,13 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
     } else if (type == framework::proto::VarType::INT32) {
       rst = FetchResult<int32_t>(fetch, output);
       output->dtype = DistModelDataType::INT32;
+    } else if (type == framework::proto::VarType::FP16) {
+      rst = FetchResult<float16>(fetch, output);
+      output->dtype = DistModelDataType::FLOAT16;
     } else {
       LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only "
-                    "supports float32, int64 and int32 fetch type for now.";
+                    "supports float32, float16, int64 and int32 fetch type "
+                    "for now.";
     }
     if (!rst) {
       LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx];
...
@@ -15,6 +15,7 @@
 #pragma once
 #include <string>
 #include <vector>
+#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/macros.h"
 namespace paddle {
@@ -40,6 +41,11 @@ constexpr DistModelDataType DistModelGetDtype<float>() {
   return DistModelDataType::FLOAT32;
 }
+template <>
+constexpr DistModelDataType DistModelGetDtype<platform::float16>() {
+  return DistModelDataType::FLOAT16;
+}
 class DistModelDataBuf {
  public:
   explicit DistModelDataBuf(size_t length)
...
@@ -31,7 +31,8 @@ struct CommContext {
               const std::vector<std::string> &origin_names, int id,
               bool merge_add_ = true, bool is_sparse_ = true,
               bool is_distributed_ = false, int table_id_ = -1,
-              bool is_tensor_table_ = false)
+              bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
+              int64_t program_id_ = -1)
       : var_name(name),
         splited_varnames(names),
         epmap(emap),
@@ -42,7 +43,9 @@ struct CommContext {
         is_sparse(is_sparse_),
         is_distributed(is_distributed_),
         table_id(table_id_),
-        is_tensor_table(is_tensor_table_) {}
+        program_id(program_id_),
+        is_tensor_table(is_tensor_table_),
+        is_datanorm_table(is_datanorm_table_) {}
   CommContext(const CommContext &ctx) {
     var_name = ctx.var_name;
@@ -55,7 +58,9 @@ struct CommContext {
     origin_varnames = ctx.origin_varnames;
     is_distributed = ctx.is_distributed;
     table_id = ctx.table_id;
+    program_id = ctx.program_id;
     is_tensor_table = ctx.is_tensor_table;
+    is_datanorm_table = ctx.is_datanorm_table;
   }
   std::string print() const {
@@ -78,7 +83,9 @@ struct CommContext {
     ss << " is_sparse: " << is_sparse;
     ss << " is_distributed: " << is_distributed << "\n";
     ss << " table_id: " << table_id << "\n";
+    ss << " program_id: " << program_id << "\n";
     ss << " is_tensor_table: " << is_tensor_table << "\n";
+    ss << " is_datanorm_table: " << is_datanorm_table << "\n";
     return ss.str();
   }
@@ -93,7 +100,9 @@ struct CommContext {
   bool is_sparse;
   bool is_distributed;
   int table_id;
+  int64_t program_id;
   bool is_tensor_table;
+  bool is_datanorm_table;
 };
 } // namespace distributed
...
cc_library(tcp_store SRCS tcp_store.cc tcp_utils.cc DEPS enforce glog)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/store/tcp_utils.h"
namespace paddle {
namespace distributed {
class Store {
public:
Store() = delete;
explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {}
virtual ~Store() = default;
virtual int64_t add(const std::string& key, int64_t value) = 0;
virtual std::vector<uint8_t> get(const std::string& key) = 0;
virtual void wait(const std::string& key) = 0;
virtual const std::chrono::seconds& timeout() const { return _timeout; }
private:
std::chrono::seconds _timeout;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono>
#include <iostream>
#include <thread>
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
namespace detail {
constexpr int INFTIME = -1;
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket) {
return std::make_unique<MasterDaemon>(socket);
}
MasterDaemon::MasterDaemon(SocketType socket) : _listen_socket(socket) {
_background_thread = std::thread{&MasterDaemon::run, this};
}
MasterDaemon::~MasterDaemon() {
_background_thread.join();
tcputils::close_socket(_listen_socket);
for (SocketType socket : _sockets) {
tcputils::close_socket(socket);
}
}
void MasterDaemon::_do_add(SocketType socket) {
int64_t new_value{};
std::string key = tcputils::receive_string(socket);
new_value = tcputils::receive_value<int64_t>(socket);
std::vector<uint8_t> old_value;
auto it = _store.find(key);
if (it != _store.end()) {
old_value = it->second;
char* buffer = reinterpret_cast<char*>(it->second.data());
size_t len = old_value.size();
new_value += std::stoll(std::string(buffer, len));
}
std::string new_value_str = std::to_string(new_value);
_store[key] =
std::vector<uint8_t>(new_value_str.begin(), new_value_str.end());
VLOG(3) << "TCPStore: new value (" << new_value << ") for key (" << key
<< ").";
tcputils::send_value<int64_t>(socket, new_value);
}
void MasterDaemon::_do_get(SocketType socket) {
std::string key = tcputils::receive_string(socket);
auto iter = _store.find(key);
PADDLE_ENFORCE_NE(
iter, _store.end(),
platform::errors::InvalidArgument("Key %s not found in TCPStore.", key));
std::vector<uint8_t> value = iter->second;
VLOG(3) << "TCPStore: value ("
<< std::stoll(std::string(reinterpret_cast<char*>(value.data()),
value.size()))
<< ") for key (" << key << ").";
tcputils::send_vector<uint8_t>(socket, value);
}
void MasterDaemon::_do_stop(SocketType socket) {
ReplyType value = ReplyType::STOP_WAIT;
_stop = true;
tcputils::send_value<ReplyType>(socket, value);
}
void MasterDaemon::_do_wait(SocketType socket) {
std::string key = tcputils::receive_string(socket);
auto iter = _store.find(key);
auto reply = ReplyType::STOP_WAIT;
if (iter == _store.end()) {
reply = ReplyType::WAITING;
}
VLOG(3) << "TCPStore: wait reply (" << static_cast<int>(reply)
<< ") for key (" << key << ").";
tcputils::send_value<ReplyType>(socket, reply);
}
void MasterDaemon::run() {
std::vector<struct pollfd> fds;
#ifdef _WIN32
fds.push_back({_listen_socket, POLLIN});
#else
fds.push_back({.fd = _listen_socket, .events = POLLIN, .revents = 0});
#endif
while (!_stop) {
for (size_t i = 0; i < fds.size(); i++) {
fds[i].revents = 0;
}
#ifdef _WIN32
::WSAPoll(fds.data(), fds.size(), INFTIME);
#else
::poll(fds.data(), fds.size(), INFTIME);
#endif
if (fds[0].revents != 0) {
auto socket = tcputils::tcp_accept(_listen_socket);
_sockets.emplace_back(socket);
#ifdef _WIN32
fds.push_back({socket, POLLIN});
#else
fds.push_back({.fd = socket, .events = POLLIN, .revents = 0});
#endif
}
for (size_t i = 1; i < fds.size(); i++) {
if (fds[i].revents == 0) {
continue;
}
Command command = tcputils::receive_value<Command>(fds[i].fd);
VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command) << ".";
switch (command) {
case Command::ADD:
_do_add(fds[i].fd);
break;
case Command::GET:
_do_get(fds[i].fd);
break;
case Command::WAIT:
_do_wait(fds[i].fd);
break;
case Command::STOP:
_do_stop(fds[i].fd);
break;
}
}
}
}
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port) {
int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET);
auto server = std::make_unique<TCPServer>();
server->_master_daemon = MasterDaemon::start(socket);
return server;
}
std::unique_ptr<TCPClient> TCPClient::connect(const std::string host,
uint16_t port) {
int socket = tcputils::tcp_connect(host, std::to_string(port), AF_INET);
return std::make_unique<TCPClient>(socket);
}
void TCPClient::send_command_for_key(Command type, const std::string& key) {
tcputils::send_value<Command>(_socket, type);
if (key.empty()) {
return;
}
tcputils::send_string(_socket, key);
}
template <typename T>
void TCPClient::send_value(const T& value) {
tcputils::send_bytes<T>(_socket, &value, 1);
}
template <typename T>
T TCPClient::receive_value() {
T res;
tcputils::receive_bytes<T>(_socket, &res, 1);
return res;
}
template <typename T>
void TCPClient::send_vector(const std::vector<T>& value) {
tcputils::send_vector<T>(_socket, value);
}
template <typename T>
std::vector<T> TCPClient::receive_vector() {
return tcputils::receive_vector<T>(_socket);
}
} // namespace detail
TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
size_t num_workers, std::chrono::seconds timeout)
: Store(timeout), _is_master(is_master), _num_workers(num_workers) {
if (_is_master) {
_server = detail::TCPServer::create(port);
}
_client = detail::TCPClient::connect(host, port);
waitWorkers();
}
void TCPStore::waitWorkers() {
if (_num_workers == 0) {
return;
}
add(_init_key, 1);
if (_server) {
auto begin = std::chrono::steady_clock::now();
do {
auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers;
if (completed >= _num_workers) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != tcputils::kNoTimeout && elapsed > _timeout) {
PADDLE_ENFORCE_EQ(
completed, _num_workers,
platform::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
}
VLOG(3) << "TCPStore initialized.";
}
int64_t TCPStore::add(const std::string& key, int64_t value) {
_client->send_command_for_key(Command::ADD, _key_prefix + key);
_client->send_value<std::int64_t>(value);
return _client->receive_value<std::int64_t>();
}
std::vector<uint8_t> TCPStore::get(const std::string& key) {
wait(key);
_client->send_command_for_key(Command::GET, _key_prefix + key);
VLOG(3) << "TCPStore get.";
return _client->receive_vector<uint8_t>();
}
void TCPStore::wait(const std::string& key) {
ReplyType reply;
do {
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>();
std::this_thread::sleep_for(std::chrono::milliseconds(500));
} while (reply != ReplyType::STOP_WAIT);
}
TCPStore::~TCPStore() {
_client->send_command_for_key(Command::STOP, "");
ReplyType ret = _client->receive_value<ReplyType>();
PADDLE_ENFORCE_EQ(ret, ReplyType::STOP_WAIT,
platform::errors::InvalidArgument(
"The reply for TCPStore destructure must be 0."));
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
namespace paddle {
namespace distributed {
enum class ReplyType { WAITING, STOP_WAIT };
enum class Command { ADD, GET, WAIT, STOP };
namespace detail {
class MasterDaemon {
public:
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket);
MasterDaemon() = delete;
explicit MasterDaemon(SocketType listen_socket);
~MasterDaemon();
private:
void run();
void _do_add(SocketType socket);
void _do_wait(SocketType socket);
void _do_get(SocketType socket);
void _do_stop(SocketType socket);
SocketType _listen_socket;
std::vector<SocketType> _sockets;
std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{};
bool _stop = false;
};
class TCPServer {
public:
TCPServer() = default;
static std::unique_ptr<TCPServer> create(std::uint16_t port);
private:
std::unique_ptr<MasterDaemon> _master_daemon;
};
class TCPClient {
public:
explicit TCPClient(SocketType socket) : _socket{socket} {}
static std::unique_ptr<TCPClient> connect(const std::string host,
uint16_t port);
~TCPClient() { tcputils::close_socket(_socket); }
void send_command_for_key(Command type, const std::string& key);
template <typename T>
void send_value(const T& value);
template <typename T>
void send_vector(const std::vector<T>& value);
template <typename T>
std::vector<T> receive_vector();
template <typename T>
T receive_value();
private:
SocketType _socket;
};
} // namespace detail
class TCPStore : public Store {
public:
static constexpr std::uint16_t kDefaultPort = 6170;
explicit TCPStore(std::string host, uint16_t port = kDefaultPort,
bool is_master = false, size_t num_workers = 1,
std::chrono::seconds timeout = tcputils::kDefaultTimeout);
~TCPStore();
int64_t add(const std::string& key, int64_t value) override;
std::vector<uint8_t> get(const std::string& key) override;
void wait(const std::string& key) override;
private:
void waitWorkers();
std::unique_ptr<detail::TCPServer> _server;
std::unique_ptr<detail::TCPClient> _client;
const std::string _init_key = "init/";
const std::string _key_prefix = "/";
std::chrono::seconds _timeout;
bool _is_master;
int _num_workers;
};
} // namespace distributed
} // namespace paddle
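A hedged rendezvous sketch for the TCPStore declared above: one process hosts the master daemon, every worker connects to it, and add/get behave as a shared counter. The host, port, and worker count are assumptions, not values from this commit:

// Editor's illustration.
void ExampleRendezvous(bool is_master) {
  paddle::distributed::TCPStore store("127.0.0.1", /*port=*/6170, is_master,
                                      /*num_workers=*/2);
  // The constructor blocks in waitWorkers() until num_workers processes have
  // checked in; afterwards the store acts as a shared key-value table.
  int64_t seen = store.add("barrier", 1);           // atomic add on the master
  std::vector<uint8_t> raw = store.get("barrier");  // waits for the key, then reads
}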
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include <cerrno>
#include <cstring>
#include <thread>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
namespace tcputils {
std::error_code socket_error() {
#ifdef _WIN32
return std::error_code{::WSAGetLastError(), std::generic_category()};
#else
return std::error_code{errno, std::generic_category()};
#endif
}
void close_socket(SocketType socket) {
#ifdef _WIN32
::closesocket(socket);
#else
::close(socket);
#endif
}
::addrinfo* get_addr_info(const std::string host, const std::string port,
int ai_flags, int family) {
::addrinfo hints{}, *res;
hints.ai_flags = ai_flags;
hints.ai_family = family;
hints.ai_socktype = SOCK_STREAM;
const char* node = host.empty() ? nullptr : host.c_str();
int n;
n = ::getaddrinfo(node, port.c_str(), &hints, &res);
const char* gai_err = ::gai_strerror(n);
const char* proto =
(family == AF_INET ? "IPv4" : family == AF_INET6 ? "IPv6" : "");
PADDLE_ENFORCE_EQ(
n, 0, platform::errors::InvalidArgument(
"%s network %s:%s cannot be obtained. Details: %s.", proto,
host, port, gai_err));
return res;
}
void free_addr_info(::addrinfo* hint) {
PADDLE_ENFORCE_NOT_NULL(
hint, platform::errors::InvalidArgument(
"The parameter for free_addr_info cannot be null."));
::freeaddrinfo(hint);
}
SocketType tcp_connect(const std::string host, const std::string port,
int family, std::chrono::seconds timeout) {
int ai_flags = AI_NUMERICSERV | AI_V4MAPPED | AI_ALL;
::addrinfo* res = get_addr_info(host, port, ai_flags, family);
SocketType sockfd = -1;
bool retry = true;
auto deadline = std::chrono::steady_clock::now() + timeout;
do {
for (::addrinfo* cur = res; cur != nullptr; cur = cur->ai_next) {
sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
PADDLE_ENFORCE_GT(sockfd, 0, platform::errors::InvalidArgument(
"Create socket to connect %s:%s failed. "
"Details: %s. ",
host, port, socket_error().message()));
if (::connect(sockfd, cur->ai_addr, cur->ai_addrlen) == 0) {
retry = false;
break;
}
VLOG(0) << "Retry to connect to " << host << ":" << port
<< " while the server is not yet listening.";
close_socket(sockfd);
sockfd = -1;
std::this_thread::sleep_for(kDelay);
if (timeout != kNoTimeout &&
std::chrono::steady_clock::now() >= deadline) {
retry = false;
break;
}
}
if (timeout != kNoTimeout && std::chrono::steady_clock::now() >= deadline) {
retry = false;
}
} while (retry);
free_addr_info(res);
PADDLE_ENFORCE_GT(sockfd, 0,
platform::errors::InvalidArgument(
"Network %s:%s cannot be connected.", host, port));
VLOG(0) << "Successfully connected to " << host << ":" << port;
return sockfd;
}
SocketType tcp_listen(const std::string host, const std::string port,
int family) {
int ai_flags = AI_PASSIVE | AI_NUMERICSERV;
::addrinfo* res = get_addr_info(host, port, ai_flags, family);
::addrinfo* cur = res;
SocketType sockfd{};
std::string node = host.empty() ? "IP_ANY" : host;
while (cur) {
sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
if (sockfd < 0) {
VLOG(0) << "Cannot create socket on " << node << ":" << port
<< ". Details: " << socket_error().message();
cur = cur->ai_next;
continue;
}
int on = 1;
#ifdef _WIN32
int ret = ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR,
reinterpret_cast<char*>(&on), sizeof(on));
#else
int ret = ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
#endif
if (ret < 0) {
VLOG(0) << "Set the address reuse option failed on the server.";
}
if (::bind(sockfd, cur->ai_addr, cur->ai_addrlen) == 0) {
break;
}
close_socket(sockfd);
sockfd = -1;
cur = cur->ai_next;
}
PADDLE_ENFORCE_GT(sockfd, 0,
platform::errors::InvalidArgument(
"Bind network on %s:%s failedd.", node, port));
::listen(sockfd, LISTENQ);
VLOG(0) << "The server starts to listen on " << node << ":" << port;
return sockfd;
}
SocketType tcp_accept(SocketType socket) {
::sockaddr_storage addr_s{};
::socklen_t addr_len = sizeof(addr_s);
SocketType new_socket =
::accept(socket, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len);
PADDLE_ENFORCE_GT(
new_socket, 0,
platform::errors::InvalidArgument(
"The server failed to accept a new connection. Details: %s.",
socket_error().message()));
#ifndef _WIN32
::fcntl(new_socket, F_SETFD, FD_CLOEXEC);
#endif
auto value = 1;
#ifdef _WIN32
::setsockopt(new_socket, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<const char*>(&value), sizeof(value));
#else
::setsockopt(new_socket, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value));
#endif
return new_socket;
}
void send_string(SocketType socket, const std::string& s) {
std::string::size_type size = s.size();
send_bytes<std::string::size_type>(socket, &size, 1);
send_bytes<const char>(socket, s.data(), size);
}
std::string receive_string(SocketType socket) {
std::string::size_type size;
receive_bytes<std::string::size_type>(socket, &size, 1);
std::vector<char> v(size);
receive_bytes<char>(socket, v.data(), size);
return std::string(v.data(), v.size());
}
} // namespace tcputils
} // namespace distributed
} // namespace paddle
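As a worked example of the length-prefixed framing implemented by send_string/receive_string above, a client-side sketch; the descriptor is hypothetical and the peer is assumed to echo the key back.
void EchoKey(paddle::distributed::SocketType fd) {
  // On the wire: sizeof(std::string::size_type) bytes holding the value 5,
  // followed by the 5 raw characters of "init/". No terminator is sent, so
  // both peers must mirror the call order exactly to stay framed.
  paddle::distributed::tcputils::send_string(fd, "init/");
  std::string key = paddle::distributed::tcputils::receive_string(fd);
  (void)key;
}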
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#else
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <chrono>
#include <iostream>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
// Utility functions for TCP sockets.
namespace paddle {
namespace distributed {
#ifdef _WIN32
using SocketType = SOCKET;
#else
using SocketType = int;
#endif
namespace tcputils {
constexpr int LISTENQ = 2048;
constexpr std::chrono::seconds kDelay = std::chrono::seconds(3);
constexpr std::chrono::seconds kNoTimeout = std::chrono::seconds::zero();
constexpr std::chrono::seconds kDefaultTimeout = std::chrono::seconds(360);
std::error_code socket_error();
void close_socket(SocketType socket);
::addrinfo* get_addr_info(const std::string host, const std::string port,
int ai_flags, int family);
void free_addr_info(::addrinfo*);
SocketType tcp_connect(const std::string host, const std::string port,
int family, std::chrono::seconds timeout = kNoTimeout);
SocketType tcp_listen(const std::string host, const std::string port,
int family);
SocketType tcp_accept(SocketType socket);
void send_string(SocketType socket, const std::string& s);
std::string receive_string(SocketType socket);
template <typename T>
void send_bytes(SocketType socket, const T* buffer, size_t len) {
size_t to_send = len * sizeof(T);
if (to_send == 0) {
return;
}
auto ptr = reinterpret_cast<const char*>(buffer);
while (to_send > 0) {
auto byte_sent = ::send(socket, ptr, to_send, 0);
PADDLE_ENFORCE_GT(byte_sent, 0, platform::errors::InvalidArgument(
"TCP send error. Details: %s.",
socket_error().message()));
to_send -= byte_sent;
ptr += byte_sent;
}
}
template <typename T>
void receive_bytes(SocketType socket, T* buffer, size_t len) {
size_t to_recv = len * sizeof(T);
if (to_recv == 0) {
return;
}
auto ptr = reinterpret_cast<char*>(buffer);
while (to_recv > 0) {
auto byte_received = ::recv(socket, ptr, to_recv, 0);
PADDLE_ENFORCE_GT(byte_received, 0, platform::errors::InvalidArgument(
"TCP receive error. Details: %s.",
socket_error().message()));
to_recv -= byte_received;
ptr += byte_received;
}
}
template <typename T>
void send_vector(SocketType socket, const std::vector<T>& v) {
size_t size = v.size();
send_bytes<size_t>(socket, &size, 1);
send_bytes<T>(socket, v.data(), size);
}
template <typename T>
std::vector<T> receive_vector(SocketType socket) {
size_t size;
receive_bytes<size_t>(socket, &size, 1);
std::vector<T> res(size);
receive_bytes<T>(socket, res.data(), size);
return res;
}
template <typename T>
void send_value(SocketType socket, const T& v) {
send_bytes<T>(socket, &v, 1);
}
template <typename T>
T receive_value(SocketType socket) {
T v;
receive_bytes<T>(socket, &v, 1);
return v;
}
} // namespace tcputils
} // namespace distributed
} // namespace paddle
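A hedged sketch of the templated helpers declared above: a vector travels as a size_t element count followed by the raw elements. No endianness or type tag is transmitted, so both peers must agree on T out of band.
void SendInts(paddle::distributed::SocketType fd) {
  std::vector<int64_t> payload = {1, 2, 3};
  // Wire format: one size_t (value 3), then 3 * sizeof(int64_t) bytes.
  paddle::distributed::tcputils::send_vector<int64_t>(fd, payload);
}
std::vector<int64_t> RecvInts(paddle::distributed::SocketType fd) {
  // Must use the sender's element type; a mismatch silently misframes the stream.
  return paddle::distributed::tcputils::receive_vector<int64_t>(fd);
}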
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include "glog/logging.h" #include "glog/logging.h"
namespace egr {
static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
const paddle::experimental::Tensor& t) { const paddle::experimental::Tensor& t) {
if (!tensor->defined() || !tensor->initialized()) { if (!tensor->defined() || !tensor->initialized()) {
...@@ -36,17 +38,10 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, ...@@ -36,17 +38,10 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
} }
} }
namespace egr {
void GradNodeAccumulation::RetainGrad(
const std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>& hook) {
retain_grad_hook_ = hook;
}
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation:: std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()( operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1, PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
"GradNodeAccumulation should take exactly 1 grad tensor" "GradNodeAccumulation should take exactly 1 grad tensor"
...@@ -58,17 +53,18 @@ operator()( ...@@ -58,17 +53,18 @@ operator()(
"However received: %d in slot %d .", "However received: %d in slot %d .",
grads[0].size(), 0)); grads[0].size(), 0));
// Apply Gradient Hooks // Apply Gradient Hooks
paddle::experimental::Tensor grad_out;
if (GradientHooksRegistered()) { if (GradientHooksRegistered()) {
std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads = std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
ApplyGradientHooks(grads); ApplyGradientHooks(grads);
// TODO(jiabin): It's little weird grad_out = hooked_grads[0][0];
CopyOrAddTensor(&accumulated_grad, hooked_grads[0][0]);
} else { } else {
CopyOrAddTensor(&accumulated_grad, grads[0][0]); grad_out = grads[0][0];
} }
if (retain_grad_hook_ != nullptr) { if (!weak_grad_.expired()) {
retain_grad_hook_(accumulated_grad); auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out);
} }
// Apply Reduce Hooks // Apply Reduce Hooks
...@@ -76,7 +72,7 @@ operator()( ...@@ -76,7 +72,7 @@ operator()(
ApplyReduceHooks(); ApplyReduceHooks();
} }
return {{accumulated_grad}}; return {{grad_out}};
} }
void GradNodeAccumulation::RegisterReduceHook( void GradNodeAccumulation::RegisterReduceHook(
......
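The weak_grad_ rewrite above is the standard std::weak_ptr check-and-lock idiom. A self-contained sketch, with an int standing in for the grad tensor:
#include <memory>
int main() {
  auto grad = std::make_shared<int>(0);  // plays the role of the grad tensor
  std::weak_ptr<int> weak_grad = grad;   // what GradNodeAccumulation now stores
  if (!weak_grad.expired()) {
    auto locked = weak_grad.lock();      // keeps the target alive during the write
    *locked += 1;                        // plays the role of CopyOrAddTensor
  }
  grad.reset();  // once the owner dies, weak_grad.expired() is true and the
                 // accumulation step is skipped instead of dereferencing junk
  return 0;
}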
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
namespace egr { namespace egr {
...@@ -21,7 +22,10 @@ namespace egr { ...@@ -21,7 +22,10 @@ namespace egr {
class GradNodeAccumulation : public GradNodeBase { class GradNodeAccumulation : public GradNodeBase {
public: public:
// Constructor: configure fwd input tensors to grad node // Constructor: configure fwd input tensors to grad node
GradNodeAccumulation() : GradNodeBase(1, 1) { SetDefaultGradInOutMeta(); } explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
weak_grad_ = meta->WeakGrad();
SetDefaultGradInOutMeta();
}
~GradNodeAccumulation() override = default; ~GradNodeAccumulation() override = default;
...@@ -30,10 +34,7 @@ class GradNodeAccumulation : public GradNodeBase { ...@@ -30,10 +34,7 @@ class GradNodeAccumulation : public GradNodeBase {
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
override; override;
void RetainGrad(const std::function<paddle::experimental::Tensor( std::string name() { return "GradNodeAccumulation"; }
const paddle::experimental::Tensor&)>& hook);
paddle::experimental::Tensor* Grad() { return &accumulated_grad; }
/** /**
* Register ReduceHook * Register ReduceHook
...@@ -47,7 +48,7 @@ class GradNodeAccumulation : public GradNodeBase { ...@@ -47,7 +48,7 @@ class GradNodeAccumulation : public GradNodeBase {
void ApplyReduceHooks(); void ApplyReduceHooks();
private: private:
paddle::experimental::Tensor accumulated_grad; std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
std::function<paddle::experimental::Tensor( std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)> const paddle::experimental::Tensor&)>
......
...@@ -52,9 +52,15 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor, ...@@ -52,9 +52,15 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
} }
} }
void RetainGradForTensor(const paddle::experimental::Tensor& tensor) { static void RetainGradForRegularNode(
// TODO(jiabin): Support More Tensor type here const paddle::experimental::Tensor& tensor) {
AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor); AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
if (meta->RetainGrads()) {
return;
} else {
meta->SetRetainGrads(true);
}
std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor = std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
meta->WeakGrad(); meta->WeakGrad();
...@@ -70,12 +76,8 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) { ...@@ -70,12 +76,8 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
grad_tensor->set_impl(t.impl()); grad_tensor->set_impl(t.impl());
return *grad_tensor.get(); return *grad_tensor.get();
} else { } else {
PADDLE_THROW(paddle::platform::errors::Fatal( VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
"Detected uninitialized variable, causing segmentation " return paddle::experimental::Tensor();
"fault "
"inside the hook."
"Tensor has to be initialized while we need to set it."
"please check tensor initialization status."));
} }
} else { } else {
VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook"; VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
...@@ -83,21 +85,17 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) { ...@@ -83,21 +85,17 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
} }
}; };
if (IsLeafTensor(tensor)) { // Append to GradientHooks
// Add RetainGrad as PostHook to AccumulationNode RegisterGradientHookForTensor(tensor, hook);
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor); }
PADDLE_ENFORCE(
grad_node.get() != nullptr,
paddle::platform::errors::Fatal("Detected NULL grad_node"
"Leaf tensor should have had grad_node "
"with type: GradNodeAccumulation"));
auto accumulation_grad_node =
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node);
accumulation_grad_node->RetainGrad(hook);
void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
if (IsLeafTensor(tensor)) {
// Leaf tensor's grad will always be retained
// Refer to implementation of AccumulationNode for more details
return;
} else { } else {
// Append to GradientHooks RetainGradForRegularNode(tensor);
RegisterGradientHookForTensor(tensor, hook);
} }
} }
......
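Two properties fall out of the rewrite above: leaf tensors need no retain-grad hook at all, and repeated calls on the same non-leaf tensor are harmless. A sketch, where t and leaf are hypothetical tensors:
RetainGradForTensor(t);     // non-leaf: registers the grad hook once
RetainGradForTensor(t);     // no-op: meta->RetainGrads() is already true
RetainGradForTensor(leaf);  // no-op: GradNodeAccumulation writes the grad itself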
...@@ -47,7 +47,7 @@ paddle::experimental::Tensor CreateTensorWithValue( ...@@ -47,7 +47,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
auto meta = EagerUtils::autograd_meta(&out); auto meta = EagerUtils::autograd_meta(&out);
if (is_leaf) { if (is_leaf) {
auto accumulation_node = std::make_shared<GradNodeAccumulation>(); auto accumulation_node = std::make_shared<GradNodeAccumulation>(meta);
meta->SetGradNode(accumulation_node); meta->SetGradNode(accumulation_node);
meta->SetStopGradient(false); meta->SetStopGradient(false);
} }
......
...@@ -554,6 +554,21 @@ static bool CheckOpProto(proto::OpProto* op_proto) { ...@@ -554,6 +554,21 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
return true; return true;
} }
static bool BeSameAsInput(const std::string& output_name,
const std::set<std::string>& input_names) {
if (output_name.size() < 4) {
return false;
}
if (output_name.substr(output_name.size() - 3, 3) == "Out") {
if (input_names.count(output_name.substr(0, output_name.size() - 3))) {
return true;
}
}
return false;
}
/* --------------------------------------- */ /* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */ /* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */ /* --------------------------------------- */
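Worked examples for BeSameAsInput above, written as assertions (a sketch; assumes the function is visible in scope):
#include <cassert>
#include <set>
#include <string>
void BeSameAsInputExamples() {
  std::set<std::string> inputs = {"X"};
  assert(BeSameAsInput("XOut", inputs));   // "X" + "Out", and "X" is an input
  assert(!BeSameAsInput("Out", inputs));   // rejected outright: length < 4
  assert(!BeSameAsInput("YOut", inputs));  // "Y" is not among the inputs
}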
...@@ -1016,33 +1031,20 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1016,33 +1031,20 @@ static std::string GenerateGradNodeCreationContent(
const std::string& output_name = output.name(); const std::string& output_name = output.name();
const std::string& output_autograd_name = "p_autograd_" + output_name; const std::string& output_autograd_name = "p_autograd_" + output_name;
// Skip Intermediate Tensor
if (output.duplicable()) { if (output.duplicable()) {
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE = const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = " " std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::autograd_meta(&%s);\n"; "egr::EagerUtils::autograd_meta(&%s);\n";
get_autograd_meta_str += paddle::string::Sprintf( get_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name); GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
if (op_passing_outs_map[op_type].count(output_name)) {
const std::string output_var_args_name = output_name + "Var";
const char* FWD_OUT_SYNC_BACK_TEMPLATE =
" egr::EagerUtils::OverwriteOutputs(%s, %s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
FWD_OUT_SYNC_BACK_TEMPLATE, output_name, output_var_args_name);
}
} else { } else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = " " egr::AutogradMeta* %s = "
"egr::EagerUtils::autograd_meta(&%s);\n"; "egr::EagerUtils::autograd_meta(&%s);\n";
get_autograd_meta_str += paddle::string::Sprintf( get_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name); GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
if (op_passing_outs_map[op_type].count(output_name)) {
const std::string output_var_args_name = output_name + "Var";
const char* FWD_OUT_SYNC_BACK_TEMPLATE =
" egr::EagerUtils::OverwriteOutputs(%s, %s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
FWD_OUT_SYNC_BACK_TEMPLATE, output_name, output_var_args_name);
}
} }
} }
VLOG(6) << "Generated outputs autograd_meta"; VLOG(6) << "Generated outputs autograd_meta";
...@@ -1145,6 +1147,8 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1145,6 +1147,8 @@ static std::string GenerateGradNodeCreationContent(
const std::string& output_autograd_name = "p_autograd_" + output_name; const std::string& output_autograd_name = "p_autograd_" + output_name;
size_t output_position = fwd_outputs_name_pos_map.at(output_name); size_t output_position = fwd_outputs_name_pos_map.at(output_name);
// Intermediate Tensor does not require SetHistory, nor RetainGrad
if (output.duplicable()) { if (output.duplicable()) {
pass_stop_gradient_args += ", &" + output_autograd_name; pass_stop_gradient_args += ", &" + output_autograd_name;
const char* SET_OUT_RANK_TEMPLATE = const char* SET_OUT_RANK_TEMPLATE =
...@@ -1180,11 +1184,13 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1180,11 +1184,13 @@ static std::string GenerateGradNodeCreationContent(
SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position); SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
} }
VLOG(6) << "Generated Call RetainGradForTensor"; if (!output.intermediate()) {
const char* RETAIN_GRAD_TEMPLATE = VLOG(6) << "Generated Call RetainGradForTensor";
" egr::EagerUtils::CheckAndRetainGrad(%s);\n"; const char* RETAIN_GRAD_TEMPLATE =
grad_node_creation_str += " egr::EagerUtils::CheckAndRetainGrad(%s);\n";
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, output_name); grad_node_creation_str +=
paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, output_name);
}
} }
VLOG(6) << "Generated SetGradIn/OutMeta"; VLOG(6) << "Generated SetGradIn/OutMeta";
...@@ -1324,19 +1330,21 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1324,19 +1330,21 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += "\n"; generated_function_body += "\n";
// Handle Dispensable Inputs // Handle Dispensable Inputs
std::set<std::string> input_names;
for (const proto::OpProto::Var& input : in_vars) { for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
input_names.insert(input_name);
if (input.dispensable()) { if (input.dispensable()) {
if (input.duplicable()) { if (input.duplicable()) {
const char* FWD_INS_CONTENT_TEMPLATE = const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.size() > 0) " " if(%s.size() > 0) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s)\n;"; "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
generated_function_body += paddle::string::Sprintf( generated_function_body += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
} else { } else {
const char* FWD_INS_CONTENT_TEMPLATE = const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.initialized()) " " if(%s.initialized()) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s)\n;"; "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
generated_function_body += paddle::string::Sprintf( generated_function_body += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
} }
...@@ -1372,11 +1380,21 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1372,11 +1380,21 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
core_ops_args_type_info[op_type].push_back("tensor"); core_ops_args_type_info[op_type].push_back("tensor");
} }
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
if (BeSameAsInput(output_name, input_names)) {
if (!output.dispensable()) {
std::string input_name =
output_name.substr(0, output_name.size() - 3);
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", ins[\"%s\"] },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, input_name);
}
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
}
core_ops_args_info[op_type].push_back(output_var_name); core_ops_args_info[op_type].push_back(output_var_name);
} else { } else {
...@@ -1415,6 +1433,23 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1415,6 +1433,23 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += outs_map_str; generated_function_body += outs_map_str;
generated_function_body += "\n"; generated_function_body += "\n";
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
if (op_passing_outs_map[op_type].count(output_name)) {
if (BeSameAsInput(output_name, input_names)) {
if (output.dispensable()) {
std::string input_name =
output_name.substr(0, output_name.size() - 3);
const char* FWD_OUTS_CONTENT_TEMPLATE =
" if (ins.count(\"%s\")) outs[\"%s\"] = ins[\"%s\"];\n";
generated_function_body += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, input_name, output_name, input_name);
}
}
}
}
generated_function_body += "\n";
VLOG(6) << "Generated Outs Map"; VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs // [Generation] Get Attrs
...@@ -1448,33 +1483,61 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1448,33 +1483,61 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::string output_varname = LegalizeVariableName(output_name); std::string output_varname = LegalizeVariableName(output_name);
if (output.duplicable()) { if (output.duplicable()) {
const char* FWD_OUT_TENSORS_TEMPLATE = if (op_passing_outs_map[op_type].count(output_name)) {
" std::vector<paddle::experimental::Tensor> %s = " if (output.dispensable()) {
"egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; const char* FWD_OUT_TENSORS_TEMPLATE =
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE, " std::vector<paddle::experimental::Tensor> %s;\n"
output_varname, output_name); " if (outs.count(\"%s\")) "
"egr::EagerUtils::GetOutputs(outs[\"%s\"], %s);\n"
" egr::EagerUtils::Output2Result(%s, &%s);\n";
out_tensor_str = paddle::string::Sprintf(
FWD_OUT_TENSORS_TEMPLATE, output_varname, output_name,
output_name, output_var_args_name, output_var_args_name,
output_varname);
} else {
const char* FWD_OUT_TENSORS_TEMPLATE =
" std::vector<paddle::experimental::Tensor> %s;\n"
" egr::EagerUtils::GetOutputs(outs[\"%s\"], %s);\n"
" egr::EagerUtils::Output2Result(%s, &%s);\n";
out_tensor_str = paddle::string::Sprintf(
FWD_OUT_TENSORS_TEMPLATE, output_varname, output_name,
output_var_args_name, output_var_args_name, output_varname);
}
} else {
const char* FWD_OUT_TENSORS_TEMPLATE =
" std::vector<paddle::experimental::Tensor> %s;\n"
" egr::EagerUtils::GetOutputs(outs[\"%s\"], &%s);\n";
out_tensor_str =
paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE, output_varname,
output_name, output_varname);
}
return_types[return_position] = return_types[return_position] =
"std::vector<paddle::experimental::Tensor>"; "std::vector<paddle::experimental::Tensor>";
if (op_passing_outs_map[op_type].count(output_name) &&
bwd_info.GenerateForwardOnly()) {
const char* FWD_OUT_SYNC_BACK_TEMPLATE =
" egr::EagerUtils::OverwriteOutputs(outs[\"%s\"], %s);\n";
out_tensor_str += paddle::string::Sprintf(
FWD_OUT_SYNC_BACK_TEMPLATE, output_name, output_var_args_name);
}
} else { } else {
const char* FWD_OUT_TENSOR_TEMPLATE = if (op_passing_outs_map[op_type].count(output_name)) {
" paddle::experimental::Tensor %s = " if (output.dispensable()) {
"egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n"; const char* FWD_OUT_TENSOR_TEMPLATE =
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, " if (outs.count(\"%s\")) "
output_varname, output_name); "egr::EagerUtils::GetOutput(outs[\"%s\"][0], %s);\n"
" paddle::experimental::Tensor& %s = *%s;\n";
if (op_passing_outs_map[op_type].count(output_name) && out_tensor_str = paddle::string::Sprintf(
bwd_info.GenerateForwardOnly()) { FWD_OUT_TENSOR_TEMPLATE, output_name, output_name,
const char* FWD_OUT_SYNC_BACK_TEMPLATE = output_var_args_name, output_varname, output_var_args_name);
" egr::EagerUtils::OverwriteOutputs(outs[\"%s\"][0], %s);\n"; } else {
out_tensor_str += paddle::string::Sprintf( const char* FWD_OUT_TENSOR_TEMPLATE =
FWD_OUT_SYNC_BACK_TEMPLATE, output_name, output_var_args_name); " egr::EagerUtils::GetOutput(outs[\"%s\"][0], %s);\n"
" paddle::experimental::Tensor& %s = *%s;\n";
out_tensor_str = paddle::string::Sprintf(
FWD_OUT_TENSOR_TEMPLATE, output_name, output_var_args_name,
output_varname, output_var_args_name);
}
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n"
" egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n";
out_tensor_str =
paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_varname,
output_name, output_varname);
} }
return_types[return_position] = "paddle::experimental::Tensor"; return_types[return_position] = "paddle::experimental::Tensor";
} }
...@@ -1494,6 +1557,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1494,6 +1557,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
GenerateGradNodeCreationContent(fwd_info, bwd_info); GenerateGradNodeCreationContent(fwd_info, bwd_info);
generated_function_body += grad_node_creation_body_str; generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n"; generated_function_body += "\n";
// [Generation] Call RetainGradForTensor // [Generation] Call RetainGradForTensor
VLOG(6) << "Generated GradNode Creation codes"; VLOG(6) << "Generated GradNode Creation codes";
} }
...@@ -1588,12 +1652,25 @@ static std::string GenerateSingleOpBase( ...@@ -1588,12 +1652,25 @@ static std::string GenerateSingleOpBase(
const std::string& attrs_name = "attrs_map" + std::to_string(*outs_size); const std::string& attrs_name = "attrs_map" + std::to_string(*outs_size);
// [Generation] Get Ins Map // [Generation] Get Ins Map
std::unordered_set<std::string> dispensable_input_name_set;
for (const auto& in : in_vars) {
if (in.dispensable()) dispensable_input_name_set.insert(in.name());
}
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
std::string ins_contents_str = ""; std::string ins_contents_str = "";
for (auto iter : grad_ins) { for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first; const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) { if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor // Fwd Tensor
const std::string& fwd_name =
grad_ins_fwd_slotname_map.at(grad_input_name);
if (dispensable_input_name_set.count(fwd_name)) {
continue;
}
std::string struct_fwd_input_name = std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE = const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
...@@ -1634,14 +1711,41 @@ static std::string GenerateSingleOpBase( ...@@ -1634,14 +1711,41 @@ static std::string GenerateSingleOpBase(
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str); paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str; generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map"; for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
// [Generation] Get Outs Map if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
std::unordered_set<std::string> duplicable_input_name_set; // Fwd Tensor
for (const auto& in : in_vars) { const std::string& fwd_name =
if (in.duplicable()) duplicable_input_name_set.insert(in.name()); grad_ins_fwd_slotname_map.at(grad_input_name);
if (dispensable_input_name_set.count(fwd_name)) {
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
if (duplicable_input_name_set.count(fwd_name)) {
const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
" if(this->%s.size() > 0) %s[\"%s\"] = "
"egr::EagerUtils::TrySyncToVars(egr::EagerUtils::"
"RecoverTensorWrapper(&this->%s, nullptr));\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, struct_fwd_input_name,
ins_name, grad_input_name, struct_fwd_input_name);
} else {
const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
" auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s, "
"nullptr);\n if(%s.initialized()) %s[\"%s\"] = "
"egr::EagerUtils::TrySyncToVars(%s);\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name,
struct_fwd_input_name, grad_input_name, ins_name, grad_input_name,
grad_input_name);
}
}
}
} }
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::string outs_contents_str = ""; std::string outs_contents_str = "";
for (auto iter : grad_outs) { for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first; const std::string& grad_output_name = iter.first;
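For a hypothetical dispensable, non-duplicable forward input Bias (with ins_name expanding to ins), the DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE in the recovery loop above emits roughly:
auto Bias = egr::EagerUtils::RecoverTensorWrapper(&this->Bias_, nullptr);
if(Bias.initialized()) ins["Bias"] = egr::EagerUtils::TrySyncToVars(Bias);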
...@@ -1987,6 +2091,7 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -1987,6 +2091,7 @@ static std::string GenerateGradNodeHeaderContents(
"%s\n" "%s\n"
" // SetAttrMap\n" " // SetAttrMap\n"
"%s\n" "%s\n"
" std::string name() { return \"GradNode%s\"; }\n"
"\n" "\n"
" private:\n" " private:\n"
" // TensorWrappers\n" " // TensorWrappers\n"
...@@ -2085,8 +2190,8 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2085,8 +2190,8 @@ static std::string GenerateGradNodeHeaderContents(
std::string grad_node_str = paddle::string::Sprintf( std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type,
set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str, set_tensor_wrappers_str, set_attr_map_str, op_type,
attr_members_str); tensor_wrapper_members_str, attr_members_str);
return grad_node_str; return grad_node_str;
} }
......
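With the new %s slot wired into GRAD_NODE_TEMPLATE above, a hypothetical op_type of "matmul_v2" would yield a method like the following, assuming the raw op name is concatenated verbatim as the Sprintf arguments suggest:
std::string name() { return "GradNodematmul_v2"; }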
...@@ -127,6 +127,40 @@ def ReadBwdFile(filepath): ...@@ -127,6 +127,40 @@ def ReadBwdFile(filepath):
###################### ######################
### Yaml Parsers ### ### Yaml Parsers ###
###################### ######################
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...]
"""
Check whether intermediate_outputs are positioned
at the very end of forward_returns_list
"""
intermediate_positions = range(
len(forward_returns_list) - len(intermediate_outputs),
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
def ParseDispensable(string):
# string: "X, Y"
return [v.strip() for v in string.split(",")]
def ParseIntermediate(string):
return [v.strip() for v in string.split(",")]
def ParseNoNeedBuffer(string):
# string: "x, y"
no_need_buffer_set = set()
for name in string.split(","):
no_need_buffer_set.add(name.strip())
return no_need_buffer_set
def ParseYamlArgs(string): def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y # Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
...@@ -397,7 +431,7 @@ def SlotNameMatching(backward_inputs_list, backward_returns_list, ...@@ -397,7 +431,7 @@ def SlotNameMatching(backward_inputs_list, backward_returns_list,
def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
backward_attrs_list): backward_attrs_list, no_need_buffer_set):
# Inputs: # Inputs:
# fwd_api_name = "" # fwd_api_name = ""
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
...@@ -410,15 +444,20 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -410,15 +444,20 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
set_tensor_wrapper_methods_str = "" set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = "" tensor_wrapper_members_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set:
no_need_buffer = "true"
else:
no_need_buffer = "false"
tensor_wrapper_name = GetSavedName(tname) tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved); {} = egr::TensorWrapper({}, full_reserved, {});
}} }}
""" """
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname) tname, tname, tensor_wrapper_name, tname, no_need_buffer)
PLAIN_TENSOR_MEMBER_TEMPLATE = """ PLAIN_TENSOR_MEMBER_TEMPLATE = """
egr::TensorWrapper {}; egr::TensorWrapper {};
...@@ -430,12 +469,12 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -430,12 +469,12 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{ void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
for(const auto& eager_tensor : {}) {{ for(const auto& eager_tensor : {}) {{
{}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved) ); {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
}}; }};
}} }}
""" """
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name) tname, tname, tname, tensor_wrapper_name, no_need_buffer)
VECTOR_TENSOR_MEMBER_TEMPLATE = """ VECTOR_TENSOR_MEMBER_TEMPLATE = """
std::vector<egr::TensorWrapper> {}; std::vector<egr::TensorWrapper> {};
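Assuming GetSavedName maps an input X to the member X_, the updated plain-tensor template above generates, for X listed in no_need_buffer:
void SetTensorWrapperX(const paddle::experimental::Tensor& X, bool full_reserved) {
  X_ = egr::TensorWrapper(X, full_reserved, true);
}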
...@@ -562,11 +601,11 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std: ...@@ -562,11 +601,11 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std:
return node_definition_str return node_definition_str
def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, def GenerateNodeCreationCodes(
forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list): backward_grad_output_map, backward_attrs_list, optional_inputs):
# fwd_api_name = "" # fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -640,10 +679,17 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -640,10 +679,17 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# SetTensorWrappers # SetTensorWrappers
set_tensor_wrappers_list = [] set_tensor_wrappers_list = []
for name, (_, is_fwd_input, _) in backward_fwd_input_map.items(): for name, (_, is_fwd_input, _) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input: if is_fwd_input:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else: else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
...@@ -732,7 +778,8 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -732,7 +778,8 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list): backward_grad_output_map, backward_attrs_list,
optional_inputs, intermediate_outputs):
# fwd_api_name = "" # fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -741,6 +788,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -741,6 +788,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...} # backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...} # backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = ["name0", ...]
# Get Function Args # Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys( num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
...@@ -750,17 +798,18 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -750,17 +798,18 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_call_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}" inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
inputs_args_definition_list[ if is_optional:
pos] = f"const paddle::experimental::Tensor& {name}" arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[ else:
pos] = f"const paddle::experimental::Tensor& {name}" arg_str = f"const paddle::experimental::Tensor& {name}"
else: else:
assert IsVectorTensorType(ttype) assert IsVectorTensorType(ttype)
inputs_args_definition_list[ arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[ inputs_args_definition_list[pos] = arg_str
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}" inputs_args_declaration_list[pos] = arg_str
for name, atype, default_val, pos in forward_attrs_list: for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name inputs_call_list[pos] = name
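A sketch of the generated signature once an input y is declared optional; the function name is illustrative, not the generator's actual naming:
paddle::experimental::Tensor hypothetical_dygraph_function(
    const paddle::experimental::Tensor& x,
    const paddle::optional<paddle::experimental::Tensor>& y);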
...@@ -776,13 +825,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -776,13 +825,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_call_args_str = ", ".join(inputs_call_list) inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic # Forward Full Logic
forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});" if len(intermediate_outputs) == 0:
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs # Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)] returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)] returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items(): for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1: if num_outputs == 1:
returns_list[0] = f"api_result" returns_list[0] = f"api_result"
else: else:
...@@ -808,7 +864,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -808,7 +864,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list, optional_inputs)
FORWARD_FUNCTION_TEMPLATE = """ FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{ {} {}({}) {{
...@@ -997,6 +1053,10 @@ if __name__ == "__main__": ...@@ -997,6 +1053,10 @@ if __name__ == "__main__":
assert 'output' in fwd_api.keys() assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys() assert 'backward' in fwd_api.keys()
no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer'])
fwd_api_name = fwd_api['api'] fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args'] fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output'] fwd_returns_str = fwd_api['output']
...@@ -1008,6 +1068,12 @@ if __name__ == "__main__": ...@@ -1008,6 +1068,12 @@ if __name__ == "__main__":
assert 'args' in bwd_api.keys() assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys() assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys() assert 'forward' in bwd_api.keys()
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
bwd_forward_str = bwd_api['forward'] bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args'] bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output'] bwd_returns_str = bwd_api['output']
...@@ -1019,6 +1085,12 @@ if __name__ == "__main__": ...@@ -1019,6 +1085,12 @@ if __name__ == "__main__":
print("Prased Forward Attrs List: ", forward_attrs_list) print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list) print("Parsed Forward Returns List: ", forward_returns_list)
intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])
IntermediateValidationCheck(intermediate_outputs, forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
...@@ -1062,7 +1134,8 @@ if __name__ == "__main__": ...@@ -1062,7 +1134,8 @@ if __name__ == "__main__":
# Node Declaration Generation # Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration( node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list) fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str) print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition( node_definition_str += GenerateNodeDefinition(
...@@ -1076,7 +1149,8 @@ if __name__ == "__main__": ...@@ -1076,7 +1149,8 @@ if __name__ == "__main__":
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str) print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0] forward_definition_str += definition_declaration_pair[0]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import argparse import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap from eager_gen import ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = { atype_to_parsing_function = {
"bool": "CastPyArg2Boolean", "bool": "CastPyArg2Boolean",
...@@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype): ...@@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype):
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map, def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map): forward_attrs_list, forward_outputs_position_map,
optional_inputs):
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = [name0, ...]
# Get EagerTensor from args # Get EagerTensor from args
# Get dygraph function call args # Get dygraph function call args
...@@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map, ...@@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
dygraph_function_call_list = ["" for i in range(num_args)] dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = "" get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n" is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
get_eager_tensor_str += f" auto {name} = GetTensorListFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
if is_optional:
get_eager_tensor_str += f" auto {name} = GetOptionalTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
get_eager_tensor_str += f" auto {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_list[pos] = f"{name}"
parse_attributes_str = "" parse_attributes_str = ""
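The three f-string branches above emit C++ along these lines for a hypothetical api foo with a required x, an optional y, and a duplicable zs:
auto x = GetTensorFromArgs("foo", "x", args, 0, false);
auto y = GetOptionalTensorFromArgs("foo", "y", args, 1, false);
auto zs = GetTensorListFromArgs("foo", "zs", args, 2, false);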
...@@ -267,6 +276,11 @@ if __name__ == "__main__": ...@@ -267,6 +276,11 @@ if __name__ == "__main__":
fwd_args_str = fwd_api['args'] fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output'] fwd_returns_str = fwd_api['output']
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward( forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
...@@ -283,7 +297,7 @@ if __name__ == "__main__": ...@@ -283,7 +297,7 @@ if __name__ == "__main__":
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction( python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list, fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map) forward_outputs_position_map, optional_inputs)
python_c_function_list.append(python_c_function_str) python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str) python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str) print("Generated Python-C Function: ", python_c_function_str)
......
...@@ -97,6 +97,7 @@ class AutogradMeta : public AbstractAutogradMeta { ...@@ -97,6 +97,7 @@ class AutogradMeta : public AbstractAutogradMeta {
"Should Not set NULL as GradNode pointer, since " "Should Not set NULL as GradNode pointer, since "
"our default Edge and autogradMeta has nullptr for " "our default Edge and autogradMeta has nullptr for "
"grad node. Set Nullptr will lead error.")); "grad node. Set Nullptr will lead error."));
grad_node_ = grad_node; grad_node_ = grad_node;
} }
...@@ -127,6 +128,12 @@ class AutogradMeta : public AbstractAutogradMeta { ...@@ -127,6 +128,12 @@ class AutogradMeta : public AbstractAutogradMeta {
stop_gradient_ = static_cast<int>(stop_gradient); stop_gradient_ = static_cast<int>(stop_gradient);
} }
void WeakSetStopGradient(bool stop_gradient) {
if (stop_gradient_ == -1) {
stop_gradient_ = static_cast<int>(stop_gradient);
}
}
bool Persistable() const { return persistable_; } bool Persistable() const { return persistable_; }
void SetPersistable(bool persistable) { persistable_ = persistable; } void SetPersistable(bool persistable) { persistable_ = persistable; }
......
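WeakSetStopGradient above only writes while stop_gradient_ still holds its unset sentinel (-1, as the check implies). A sketch of the effect, assuming a freshly constructed meta:
egr::AutogradMeta meta;           // assumed: stop_gradient_ starts at -1 (unset)
meta.WeakSetStopGradient(true);   // takes effect: -1 -> 1
meta.WeakSetStopGradient(false);  // ignored: the flag was already decided
// SetStopGradient(false) would still overwrite unconditionally.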
...@@ -53,7 +53,7 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) { ...@@ -53,7 +53,7 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} else { } else {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>()); meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} }
...@@ -69,13 +69,16 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) { ...@@ -69,13 +69,16 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
"adj_edges is designed to has the same size of grad " "adj_edges is designed to has the same size of grad "
"inputs's slot num.")); "inputs's slot num."));
if (meta && !meta->StopGradient()) { if (meta && !meta->StopGradient()) {
VLOG(6) << "Add Edges for slot: " << slot_id;
auto node = meta->GetMutableGradNode(); auto node = meta->GetMutableGradNode();
if (node) { if (node) {
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} else { } else {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>()); meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} }
......
...@@ -147,6 +147,8 @@ class GradNodeBase { ...@@ -147,6 +147,8 @@ class GradNodeBase {
std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks( std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors); const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);
virtual std::string name() { return "GradNodeBase"; }
private: private:
// TODO(jiabin): Use SmallVector instead after merge PR from develop // TODO(jiabin): Use SmallVector instead after merge PR from develop
......
...@@ -34,7 +34,8 @@ class TensorWrapper { ...@@ -34,7 +34,8 @@ class TensorWrapper {
public: public:
TensorWrapper() = default; TensorWrapper() = default;
explicit TensorWrapper(const paddle::experimental::Tensor& tensor, explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
bool full_reserved = false) { bool full_reserved = false,
bool no_need_buffer = false) {
/** /**
* Normally, we should fully reserved all non-output or non-leaf fwd tensor * Normally, we should fully reserved all non-output or non-leaf fwd tensor
* here. And for fwd output tensor, we should not reserve its autogradmeta, * here. And for fwd output tensor, we should not reserve its autogradmeta,
...@@ -48,16 +49,30 @@ class TensorWrapper { ...@@ -48,16 +49,30 @@ class TensorWrapper {
} }
// shallow copy tensor_impl here // shallow copy tensor_impl here
intermidiate_tensor_.set_impl(tensor.impl()); if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(tensor.impl().get());
auto tw_dense_tensor = std::make_shared<phi::DenseTensor>();
tw_dense_tensor->set_meta(dense_tensor->meta());
intermidiate_tensor_.set_impl(tw_dense_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
}
} else {
intermidiate_tensor_.set_impl(tensor.impl());
}
intermidiate_tensor_.set_name(tensor.name() + "@Saved"); intermidiate_tensor_.set_name(tensor.name() + "@Saved");
PADDLE_ENFORCE_NOT_NULL(
    EagerUtils::unsafe_autograd_meta(tensor),
    paddle::platform::errors::Fatal(
        "Full reserved Tensor should not have null autograd meta, since "
        "tensor_wrapper is used to build backward info. There is no way "
        "for us to build it with null autograd_meta."));
// copy output_rank
out_rank_info_ = EagerUtils::OutRankInfo(tensor);
// If an output is marked "intermediate", we won't create
// autograd_meta for it.
// In that case, simply skip the OutRankInfo copy.
if (EagerUtils::nullable_autograd_meta(tensor)) {
  out_rank_info_ = EagerUtils::OutRankInfo(tensor);
}
} }
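// A hedged usage sketch (not part of this diff): with no_need_buffer=true
// the wrapper stores a fresh DenseTensor that copies only the meta
// (shape/dtype), never the allocation; fwd_out is a placeholder name.
//
//   paddle::experimental::Tensor fwd_out = /* some forward tensor */;
//   egr::TensorWrapper wrapper(fwd_out, /*full_reserved=*/false,
//                              /*no_need_buffer=*/true);
//   // wrapper.recover(...) later rebuilds a tensor usable in backward.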
paddle::experimental::Tensor recover( paddle::experimental::Tensor recover(
......
...@@ -17,11 +17,13 @@ ...@@ -17,11 +17,13 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h" #include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
// TODO(jiabin): remove nolint here!!! // TODO(jiabin): remove nolint here!!!
...@@ -37,7 +39,7 @@ TEST(AccumulationNode, Tensor) { ...@@ -37,7 +39,7 @@ TEST(AccumulationNode, Tensor) {
.get(), .get(),
meta); meta);
dt0->mutable_data<paddle::platform::float16>( dt0->mutable_data<paddle::platform::float16>(
paddle::platform::CPUPlace())[0] = 10.0; paddle::platform::CPUPlace())[0] = paddle::platform::float16(10.0f);
paddle::experimental::Tensor et0 = paddle::experimental::Tensor(dt0); paddle::experimental::Tensor et0 = paddle::experimental::Tensor(dt0);
std::shared_ptr<phi::DenseTensor> dt1 = std::make_shared<phi::DenseTensor>( std::shared_ptr<phi::DenseTensor> dt1 = std::make_shared<phi::DenseTensor>(
...@@ -47,84 +49,100 @@ TEST(AccumulationNode, Tensor) { ...@@ -47,84 +49,100 @@ TEST(AccumulationNode, Tensor) {
meta); meta);
dt1->mutable_data<paddle::platform::float16>( dt1->mutable_data<paddle::platform::float16>(
paddle::platform::CPUPlace())[0] = 20.0; paddle::platform::CPUPlace())[0] = paddle::platform::float16(20.0f);
paddle::experimental::Tensor et1 = paddle::experimental::Tensor(dt1); paddle::experimental::Tensor et1 = paddle::experimental::Tensor(dt1);
std::shared_ptr<phi::DenseTensor> input_dt =
std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
paddle::experimental::Tensor input_et =
paddle::experimental::Tensor(input_dt);
auto grad_meta = EagerUtils::autograd_meta(&input_et);
// Initialize Grad Tensor
std::shared_ptr<phi::DenseTensor> grad_dt = std::shared_ptr<phi::DenseTensor> grad_dt =
std::make_shared<phi::DenseTensor>( std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>( std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()) paddle::platform::CPUPlace())
.get(), .get(),
meta); meta);
paddle::experimental::Tensor grad_et = paddle::experimental::Tensor(grad_dt);
// AccumulationNode
GradNodeAccumulation node = GradNodeAccumulation();
// Hook, RetainGrad
std::function<paddle::experimental::Tensor(
    const paddle::experimental::Tensor&)>
    hook = [&grad_et](const paddle::experimental::Tensor& t) {
      grad_et.set_impl(t.impl());
      return grad_et;
    };
node.RetainGrad(hook);
grad_dt->mutable_data<paddle::platform::float16>(
    paddle::platform::CPUPlace())[0] = paddle::platform::float16(0.0f);
grad_meta->MutableGrad()->set_impl(grad_dt);
// AccumulationNode
auto node = std::make_shared<GradNodeAccumulation>(grad_meta);
grad_meta->SetGradNode(node);
grad_meta->SetStopGradient(false);
// operator() // operator()
paddle::experimental::Tensor ret_et0 = node({{et0}})[0][0]; paddle::experimental::Tensor ret_et0 = node->operator()({{et0}})[0][0];
auto* ret_et0_ptr = auto* ret_et0_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(ret_et0.impl()) std::dynamic_pointer_cast<phi::DenseTensor>(ret_et0.impl())
->data<paddle::platform::float16>(); ->data<paddle::platform::float16>();
CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f)); CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f));
paddle::experimental::Tensor ret_et1 = node({{et1}})[0][0]; paddle::experimental::Tensor ret_et1 = node->operator()({{et1}})[0][0];
auto* ret_et1_ptr = auto* ret_et1_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(ret_et1.impl()) std::dynamic_pointer_cast<phi::DenseTensor>(ret_et1.impl())
->data<paddle::platform::float16>(); ->data<paddle::platform::float16>();
CHECK_EQ(ret_et1_ptr[0], paddle::platform::float16(30.0f)); CHECK_EQ(ret_et1_ptr[0], paddle::platform::float16(20.0f));
// Retain Grad
auto* ret_grad_et_ptr =
    std::dynamic_pointer_cast<phi::DenseTensor>(grad_et.impl())
        ->data<paddle::platform::float16>();
CHECK_EQ(ret_grad_et_ptr[0], paddle::platform::float16(30.0f));
// Check Retain Grad
CHECK_EQ(std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl())
             ->data<paddle::platform::float16>()[0],
         paddle::platform::float16(10.0f));
paddle::experimental::Tensor* grad = EagerUtils::mutable_grad(input_et);
auto* grad_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl())
                     ->data<paddle::platform::float16>();
CHECK_EQ(grad_ptr[0], paddle::platform::float16(30.0f));
// Reduce Hook case 1: Call RegisterReduceHook and run operator() // Reduce Hook case 1: Call RegisterReduceHook and run operator()
VLOG(6) << "Test Reduce Hook"; VLOG(6) << "Test Reduce Hook";
CHECK_EQ(std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl())
->data<paddle::platform::float16>()[0],
paddle::platform::float16(10.0f));
auto reduce_hook_1 = [&](void) -> void { auto reduce_hook_1 = [&](void) -> void {
auto* grad_et_ptr =
    std::dynamic_pointer_cast<phi::DenseTensor>(grad_et.impl())
        ->data<paddle::platform::float16>();
grad_et_ptr[0] = 36.0;
auto* input_et_ptr =
    std::dynamic_pointer_cast<phi::DenseTensor>(input_et.impl())
        ->mutable_data<paddle::platform::float16>(
            paddle::platform::CPUPlace());
input_et_ptr[0] = 36.0;
VLOG(6) << "Running Reduce Hook"; VLOG(6) << "Running Reduce Hook";
}; };
node.RegisterReduceHook(reduce_hook_1); node->RegisterReduceHook(reduce_hook_1);
// operator() // operator()
paddle::experimental::Tensor _ret = node({{et0}})[0][0]; paddle::experimental::Tensor _ret = node->operator()({{et0}})[0][0];
// Check operator() result, should be 36.0 // Check operator() result, should be 36.0
auto* _ret_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(_ret.impl()) auto* _ret_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(_ret.impl())
->data<paddle::platform::float16>(); ->data<paddle::platform::float16>();
CHECK_EQ(_ret_ptr[0], paddle::platform::float16(36.0f)); CHECK_EQ(_ret_ptr[0], paddle::platform::float16(10.0f));
// Check Retain Grad, should be 36.0 // Check Retain Grad, should be 36.0
auto* _ret_grad_et_ptr = auto* _ret_input_et_ptr =
std::dynamic_pointer_cast<phi::DenseTensor>(grad_et.impl()) std::dynamic_pointer_cast<phi::DenseTensor>(input_et.impl())
->data<paddle::platform::float16>(); ->data<paddle::platform::float16>();
CHECK_EQ(_ret_grad_et_ptr[0], paddle::platform::float16(36.0f)); CHECK_EQ(_ret_input_et_ptr[0], paddle::platform::float16(36.0f));
// Reduce Hook case 2: Call RegisterReduceHook and ApplyReduceHooks directly // Reduce Hook case 2: Call RegisterReduceHook and ApplyReduceHooks directly
VLOG(6) << "Test Reduce Hook"; VLOG(6) << "Test Reduce Hook";
auto reduce_hook_2 = [&](void) -> void { auto reduce_hook_2 = [&](void) -> void {
auto* ret_et0_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl()) auto* ret_et0_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl())
->data<paddle::platform::float16>(); ->mutable_data<paddle::platform::float16>(
paddle::platform::CPUPlace());
ret_et0_ptr[0] = 100.0; // set to 100.0 ret_et0_ptr[0] = 100.0; // set to 100.0
VLOG(6) << "Running Reduce Hook"; VLOG(6) << "Running Reduce Hook";
}; };
node.RegisterReduceHook(reduce_hook_2); node->RegisterReduceHook(reduce_hook_2);
node.ApplyReduceHooks(); node->ApplyReduceHooks();
// Check ApplyReduceHooks result // Check ApplyReduceHooks result
CHECK_EQ(std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl()) CHECK_EQ(std::dynamic_pointer_cast<phi::DenseTensor>(et0.impl())
......
...@@ -59,22 +59,18 @@ TEST(Backward, SingleNodeEmptyGrad) { ...@@ -59,22 +59,18 @@ TEST(Backward, SingleNodeEmptyGrad) {
auto_grad_meta->SetSingleOutRankWithSlot(0, 0); auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false); auto_grad_meta->SetStopGradient(false);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta // Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr = std::make_shared<egr::GradNodeAccumulation>(); auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
auto_grad_meta1->SetGradNode( auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
egr_utils_api::RetainGradForTensor(leaf_tensor); std::vector<egr::AutogradMeta*> res = {auto_grad_meta1};
// Connect Node0 -> AccumulationNode via Edge
auto meta = egr::AutogradMeta();
meta.SetStopGradient(false);
meta.SetSingleOutRankWithSlot(0, 0);
meta.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {&meta};
node0_ptr->AddEdges(&res, 0); node0_ptr->AddEdges(&res, 0);
} }
std::vector<paddle::experimental::Tensor> outs = {target_tensor}; std::vector<paddle::experimental::Tensor> outs = {target_tensor};
...@@ -123,22 +119,17 @@ TEST(Backward, SingleNodeCustomGrad) { ...@@ -123,22 +119,17 @@ TEST(Backward, SingleNodeCustomGrad) {
std::dynamic_pointer_cast<GradNodeBase>(node0_ptr)); std::dynamic_pointer_cast<GradNodeBase>(node0_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0); auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false); auto_grad_meta->SetStopGradient(false);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr = std::make_shared<egr::GradNodeAccumulation>();
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor); AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
auto_grad_meta1->SetGradNode( auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta1->SetStopGradient(false);
egr_utils_api::RetainGradForTensor(leaf_tensor); std::vector<egr::AutogradMeta*> res = {auto_grad_meta1};
// Connect Node0 -> AccumulationNode via Edge
auto meta = egr::AutogradMeta();
meta.SetStopGradient(false);
meta.SetSingleOutRankWithSlot(0, 0);
meta.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {&meta};
node0_ptr->AddEdges(&res, 0); node0_ptr->AddEdges(&res, 0);
} }
...@@ -201,22 +192,17 @@ TEST(Backward, LinearNodes) { ...@@ -201,22 +192,17 @@ TEST(Backward, LinearNodes) {
std::vector<egr::AutogradMeta*> res0 = {&meta0}; std::vector<egr::AutogradMeta*> res0 = {&meta0};
node0_ptr->AddEdges(&res0, 0); node0_ptr->AddEdges(&res0, 0);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta // Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr = std::make_shared<egr::GradNodeAccumulation>(); auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta1);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
auto_grad_meta1->SetGradNode( auto_grad_meta1->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
egr_utils_api::RetainGradForTensor(leaf_tensor); auto_grad_meta1->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res1 = {auto_grad_meta1};
// Connect Node1 -> AccumulationNode via Edge
auto meta1 = egr::AutogradMeta();
meta1.SetStopGradient(false);
meta1.SetSingleOutRankWithSlot(0, 0);
meta1.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res1 = {&meta1};
node1_ptr->AddEdges(&res1, 0); node1_ptr->AddEdges(&res1, 0);
} }
...@@ -311,22 +297,17 @@ TEST(Backward, WithAccumulation) { ...@@ -311,22 +297,17 @@ TEST(Backward, WithAccumulation) {
std::vector<egr::AutogradMeta*> res1 = {&meta1}; std::vector<egr::AutogradMeta*> res1 = {&meta1};
node1_ptr->AddEdges(&res1, 0); node1_ptr->AddEdges(&res1, 0);
AutogradMeta* auto_grad_meta2 = EagerUtils::autograd_meta(&leaf_tensor);
// Connect Tensor and AccumulationNode via AutoGradMeta // Connect Tensor and AccumulationNode via AutoGradMeta
auto acc_node_ptr = std::make_shared<egr::GradNodeAccumulation>(); auto acc_node_ptr =
std::make_shared<egr::GradNodeAccumulation>(auto_grad_meta2);
AutogradMeta* auto_grad_meta2 = EagerUtils::autograd_meta(&leaf_tensor);
auto_grad_meta2->SetGradNode( auto_grad_meta2->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta2->SetSingleOutRankWithSlot(0, 0); auto_grad_meta2->SetSingleOutRankWithSlot(0, 0);
egr_utils_api::RetainGradForTensor(leaf_tensor); auto_grad_meta2->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res2 = {auto_grad_meta2};
// Connect Node2 -> AccumulationNode via Edge
auto meta2 = egr::AutogradMeta();
meta2.SetStopGradient(false);
meta2.SetSingleOutRankWithSlot(0, 0);
meta2.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res2 = {&meta2};
node2_ptr->AddEdges(&res2, 0); node2_ptr->AddEdges(&res2, 0);
} }
......
...@@ -46,34 +46,26 @@ TEST(CrossBatchAccumulation, SingleScaleNode) { ...@@ -46,34 +46,26 @@ TEST(CrossBatchAccumulation, SingleScaleNode) {
paddle::experimental::Tensor& target_tensor = target_tensors[0]; paddle::experimental::Tensor& target_tensor = target_tensors[0];
paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor(); paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor();
{
auto scale_node_ptr = std::make_shared<GradNodeScale>(1, 1); auto scale_node_ptr = std::make_shared<GradNodeScale>(1, 1);
scale_node_ptr->SetAttributes_scale(5.0 /*scale*/); scale_node_ptr->SetAttributes_scale(5.0 /*scale*/);
scale_node_ptr->SetDefaultGradInOutMeta(); scale_node_ptr->SetDefaultGradInOutMeta();
auto acc_node_ptr = std::make_shared<GradNodeAccumulation>();
AutogradMeta* auto_grad_meta = EagerUtils::autograd_meta(&target_tensor);
auto_grad_meta->SetGradNode(
    std::dynamic_pointer_cast<GradNodeBase>(scale_node_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false);
egr_utils_api::RetainGradForTensor(target_tensor);  // result: 1.0
auto meta = AutogradMeta();
meta.SetSingleOutRankWithSlot(0, 0);
meta.SetStopGradient(false);
meta.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {&meta};
scale_node_ptr->AddEdges(&res, 0);
AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor);
auto_grad_meta1->SetGradNode(
    std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr));
auto_grad_meta1->SetSingleOutRankWithSlot(0, 0);
egr_utils_api::RetainGradForTensor(leaf_tensor);
}
AutogradMeta* auto_grad_meta = EagerUtils::autograd_meta(&target_tensor);
auto_grad_meta->SetGradNode(
    std::dynamic_pointer_cast<GradNodeBase>(scale_node_ptr));
auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
auto_grad_meta->SetStopGradient(false);
egr_utils_api::RetainGradForTensor(target_tensor);  // result: 1.0
AutogradMeta* meta = EagerUtils::autograd_meta(&leaf_tensor);
auto acc_node_ptr = std::make_shared<GradNodeAccumulation>(meta);
meta->SetStopGradient(false);
meta->SetSingleOutRankWithSlot(0, 0);
meta->SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {meta};
scale_node_ptr->AddEdges(&res, 0);
RunBackward(target_tensors, {}); RunBackward(target_tensors, {});
......
...@@ -159,7 +159,7 @@ TEST(EagerUtils, PassStopGradient) { ...@@ -159,7 +159,7 @@ TEST(EagerUtils, PassStopGradient) {
CHECK(auto_grad0->StopGradient() == false); CHECK(auto_grad0->StopGradient() == false);
egr::EagerUtils::PassStopGradient(true, auto_grad0.get(), auto_grad1.get(), egr::EagerUtils::PassStopGradient(true, auto_grad0.get(), auto_grad1.get(),
auto_grad2.get(), auto_grad3.get()); auto_grad2.get(), auto_grad3.get());
CHECK(auto_grad0->StopGradient() == true); CHECK(auto_grad0->StopGradient() == false);
CHECK(auto_grad1->StopGradient() == true); CHECK(auto_grad1->StopGradient() == true);
CHECK(auto_grad2->StopGradient() == true); CHECK(auto_grad2->StopGradient() == true);
CHECK(auto_grad3->StopGradient() == true); CHECK(auto_grad3->StopGradient() == true);
......
...@@ -79,9 +79,6 @@ TEST(RetainGrad, HookBeforeRetainGrad) { ...@@ -79,9 +79,6 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
// Set grad in/out meta for node0 // Set grad in/out meta for node0
scale_node_ptr->SetDefaultGradInOutMeta(); scale_node_ptr->SetDefaultGradInOutMeta();
// Create AccumulationNode
auto acc_node_ptr = std::make_shared<GradNodeAccumulation>();
// Connect Input Tensor and ScaleNode via AutoGradMeta // Connect Input Tensor and ScaleNode via AutoGradMeta
// Apply RetainGrad // Apply RetainGrad
{ {
...@@ -102,16 +99,8 @@ TEST(RetainGrad, HookBeforeRetainGrad) { ...@@ -102,16 +99,8 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook); egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook);
egr_utils_api::RetainGradForTensor( egr_utils_api::RetainGradForTensor(
target_tensor); // result: 1.0 + 3.0 = 4.0 target_tensor); // result: 1.0 + 3.0 = 4.0
} egr_utils_api::RetainGradForTensor(
target_tensor); // result: 1.0 + 3.0 = 4.0
// Connect ScaleNode -> AccumulationNode via Edge
{
auto meta = AutogradMeta();
meta.SetStopGradient(false);
meta.SetSingleOutRankWithSlot(0, 0);
meta.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {&meta};
scale_node_ptr->AddEdges(&res, 0);
} }
// Retain Grad for leaf tensor1 // Retain Grad for leaf tensor1
...@@ -123,9 +112,16 @@ TEST(RetainGrad, HookBeforeRetainGrad) { ...@@ -123,9 +112,16 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
hook = &hook_function; hook = &hook_function;
auto auto_grad_meta = std::make_shared<AutogradMeta>(); auto auto_grad_meta = std::make_shared<AutogradMeta>();
auto_grad_meta->SetGradNode(
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); auto acc_node_ptr =
std::make_shared<GradNodeAccumulation>(auto_grad_meta.get());
auto_grad_meta->SetStopGradient(false);
auto_grad_meta->SetGradNode(acc_node_ptr);
auto_grad_meta->SetSingleOutRankWithSlot(0, 0); auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
std::vector<egr::AutogradMeta*> res = {auto_grad_meta.get()};
scale_node_ptr->AddEdges(&res, 0);
leaf_tensor.set_autograd_meta( leaf_tensor.set_autograd_meta(
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>( std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
auto_grad_meta)); auto_grad_meta));
...@@ -160,8 +156,6 @@ TEST(RetainGrad, HookAfterRetainGrad) { ...@@ -160,8 +156,6 @@ TEST(RetainGrad, HookAfterRetainGrad) {
scale_node_ptr->SetAttributes_scale(5.0 /*scale*/); scale_node_ptr->SetAttributes_scale(5.0 /*scale*/);
// Set grad in/out meta for node0 // Set grad in/out meta for node0
scale_node_ptr->SetDefaultGradInOutMeta(); scale_node_ptr->SetDefaultGradInOutMeta();
// Create AccumulationNode
auto acc_node_ptr = std::make_shared<GradNodeAccumulation>();
// Connect Input Tensor and ScaleNode via AutoGradMeta // Connect Input Tensor and ScaleNode via AutoGradMeta
// Apply RetainGrad // Apply RetainGrad
...@@ -184,16 +178,6 @@ TEST(RetainGrad, HookAfterRetainGrad) { ...@@ -184,16 +178,6 @@ TEST(RetainGrad, HookAfterRetainGrad) {
egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook); egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook);
} }
// Connect ScaleNode -> AccumulationNode via Edge
{
auto meta = AutogradMeta();
meta.SetStopGradient(false);
meta.SetSingleOutRankWithSlot(0, 0);
meta.SetGradNode(acc_node_ptr);
std::vector<egr::AutogradMeta*> res = {&meta};
scale_node_ptr->AddEdges(&res, 0);
}
// Retain Grad for leaf tensor1 // Retain Grad for leaf tensor1
paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor(); paddle::experimental::Tensor leaf_tensor = paddle::experimental::Tensor();
{ {
...@@ -203,17 +187,18 @@ TEST(RetainGrad, HookAfterRetainGrad) { ...@@ -203,17 +187,18 @@ TEST(RetainGrad, HookAfterRetainGrad) {
hook = &hook_function; hook = &hook_function;
auto auto_grad_meta = std::make_shared<AutogradMeta>(); auto auto_grad_meta = std::make_shared<AutogradMeta>();
auto_grad_meta->SetGradNode( auto acc_node_ptr =
std::dynamic_pointer_cast<GradNodeBase>(acc_node_ptr)); std::make_shared<GradNodeAccumulation>(auto_grad_meta.get());
auto_grad_meta->SetGradNode(acc_node_ptr);
auto_grad_meta->SetStopGradient(false);
std::vector<egr::AutogradMeta*> res = {auto_grad_meta.get()};
scale_node_ptr->AddEdges(&res, 0);
auto_grad_meta->SetSingleOutRankWithSlot(0, 0); auto_grad_meta->SetSingleOutRankWithSlot(0, 0);
leaf_tensor.set_autograd_meta( leaf_tensor.set_autograd_meta(
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>( std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
auto_grad_meta)); auto_grad_meta));
egr_utils_api::RetainGradForTensor(
leaf_tensor); // RetainGrad for leaf tensor gets
// postponed, result: 4.0*5.0 + 3.0 =
// 23.0
egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook); egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook);
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/api/utils/hook_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/eager/tensor_wrapper.h"
...@@ -21,7 +22,6 @@ ...@@ -21,7 +22,6 @@
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
...@@ -109,6 +109,16 @@ std::shared_ptr<GradNodeBase> EagerUtils::grad_node( ...@@ -109,6 +109,16 @@ std::shared_ptr<GradNodeBase> EagerUtils::grad_node(
} }
} }
paddle::experimental::Tensor* EagerUtils::mutable_grad(
const paddle::experimental::Tensor& target) {
auto* meta = nullable_autograd_meta(target);
if (meta) {
return meta->MutableGrad();
} else {
return nullptr;
}
}
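// A minimal usage sketch (not part of this diff) for the new mutable_grad
// helper above: it returns the grad slot when autograd meta exists, else
// nullptr, so callers must check before dereferencing. OverwriteGrad and
// new_grad_impl are hypothetical names.
static void OverwriteGrad(
    const paddle::experimental::Tensor& t,
    const std::shared_ptr<phi::TensorBase>& new_grad_impl) {
  paddle::experimental::Tensor* grad = EagerUtils::mutable_grad(t);
  if (grad != nullptr) {
    grad->set_impl(new_grad_impl);  // overwrite the stored gradient in place
  }
}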
void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas, void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
const std::shared_ptr<GradNodeBase>& grad_node) { const std::shared_ptr<GradNodeBase>& grad_node) {
for (const auto& autograd_meta : *autograd_metas) { for (const auto& autograd_meta : *autograd_metas) {
...@@ -220,53 +230,62 @@ paddle::experimental::Tensor EagerUtils::GetOutput( ...@@ -220,53 +230,62 @@ paddle::experimental::Tensor EagerUtils::GetOutput(
return paddle::experimental::Tensor(out->GetTensorBase(), out->name()); return paddle::experimental::Tensor(out->GetTensorBase(), out->name());
} }
void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerVariable>& out, void EagerUtils::GetOutput(const std::shared_ptr<EagerVariable>& out,
paddle::experimental::Tensor* tensor) { paddle::experimental::Tensor* out_var) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
tensor, paddle::platform::errors::Fatal( out_var, paddle::platform::errors::Fatal(
"Tensor is null and cannot be copied. " "Tensor is null and cannot be copied. "
"We are tring to OverwriteOutput from its " "We are tring to OverwriteOutput from its "
"shared_ptr, this error may indicate some outputs " "shared_ptr, this error may indicate some outputs "
"are nullptr")); "are nullptr"));
tensor->set_impl(out->GetTensorBase()); out_var->set_impl(out->GetTensorBase());
} }
void EagerUtils::OverwriteOutputs( void EagerUtils::GetOutputs(
const std::vector<std::shared_ptr<EagerVariable>>& outs, const std::vector<std::shared_ptr<EagerVariable>>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors) { std::vector<paddle::experimental::Tensor>* result) {
PADDLE_ENFORCE_EQ(
outs.size(), tensors.size(),
paddle::platform::errors::Fatal(
"We are tring to OverwriteOutputs which passed in and it expected "
"elements num of outs and origin outputs are equal, but we got outs "
"size of: %d, and tensors passed in size is: %d",
outs.size(), tensors.size()));
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
OverwriteOutputs(outs[i], tensors[i]); result->emplace_back(outs[i]->GetTensorBase());
} }
} }
void EagerUtils::OverwriteOutputs(const paddle::experimental::Tensor& out, void EagerUtils::GetOutputs(
paddle::experimental::Tensor* tensor) { const std::vector<std::shared_ptr<EagerVariable>>& outs,
PADDLE_ENFORCE_NOT_NULL( const std::vector<paddle::experimental::Tensor*>& out_var) {
tensor, paddle::platform::errors::Fatal(
"Tensor is null and cannot be copied. "
"We are tring to OverwriteOutput from its "
"shared_ptr, this error may indicate some outputs "
"are nullptr"));
*tensor = out;
}
void EagerUtils::OverwriteOutputs(
const std::vector<paddle::experimental::Tensor>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
tensors[i], paddle::platform::errors::Fatal( out_var[i], paddle::platform::errors::Fatal(
"Tensor is null and cannot be copied. " "Tensor is null and cannot be copied. "
"We are tring to OverwriteOutput from its " "We are tring to OverwriteOutput from its "
"shared_ptr, this error may indicate some outputs " "shared_ptr, this error may indicate some outputs "
"are nullptr")); "are nullptr"));
*tensors[i] = outs[i]; out_var[i]->set_impl(outs[i]->GetTensorBase());
}
}
void EagerUtils::GetOutputs(const std::shared_ptr<EagerVariable>& out,
std::vector<paddle::experimental::Tensor>* result) {
result->emplace_back(out->GetTensorBase());
}
void EagerUtils::GetOutputs(
const std::shared_ptr<EagerVariable>& out,
const std::vector<paddle::experimental::Tensor*>& out_var) {
PADDLE_ENFORCE_NOT_NULL(
out_var[0], paddle::platform::errors::Fatal(
"Tensor is null and cannot be copied. "
"We are tring to OverwriteOutput from its "
"shared_ptr, this error may indicate some outputs "
"are nullptr"));
out_var[0]->set_impl(out->GetTensorBase());
}
void EagerUtils::Output2Result(
const std::vector<paddle::experimental::Tensor*>& out_var,
std::vector<paddle::experimental::Tensor>* result) {
result->reserve(out_var.size());
for (size_t i = 0; i < out_var.size(); i++) {
result->emplace_back(*out_var[i]);
} }
} }
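// A hedged usage sketch (not part of this diff) of the renamed helpers: the
// old OverwriteOutputs overloads become GetOutput/GetOutputs, which copy the
// underlying TensorBase impls out of EagerVariables rather than assigning
// whole Tensors. CollectOutputs is a hypothetical caller.
static std::vector<paddle::experimental::Tensor> CollectOutputs(
    const std::vector<std::shared_ptr<EagerVariable>>& outs) {
  std::vector<paddle::experimental::Tensor> result;
  EagerUtils::GetOutputs(outs, &result);  // result[i] wraps outs[i]'s impl
  return result;
}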
...@@ -333,7 +352,8 @@ std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode( ...@@ -333,7 +352,8 @@ std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
} else { } else {
if (!autograd_ptr->StopGradient()) { if (!autograd_ptr->StopGradient()) {
VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name(); VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
autograd_ptr->SetGradNode(std::make_shared<egr::GradNodeAccumulation>()); autograd_ptr->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(autograd_ptr));
return autograd_ptr->GetMutableGradNode(); return autograd_ptr->GetMutableGradNode();
} else { } else {
return nullptr; return nullptr;
......
...@@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper<AutogradMeta*> { ...@@ -77,7 +77,7 @@ class PassStopGradientIter : public IterHelper<AutogradMeta*> {
VLOG(2) << "Tensor is NULL"; VLOG(2) << "Tensor is NULL";
return; return;
} }
element->SetStopGradient(stop_gradient_); element->WeakSetStopGradient(stop_gradient_);
} }
bool stop_gradient_ = true; bool stop_gradient_ = true;
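// One plausible reading (an assumption, not shown in this diff) of switching
// from SetStopGradient to WeakSetStopGradient: the weak variant only fills
// the flag while it is still undecided, so an explicit user setting wins.
// That matches the PassStopGradient test change above, where auto_grad0
// keeps its explicit `false`. A sketch over a hypothetical tri-state member:
//
//   void WeakSetStopGradient(bool stop_gradient) {
//     if (stop_gradient_ == -1) {  // -1: not decided yet
//       stop_gradient_ = static_cast<int>(stop_gradient);
//     }
//   }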
...@@ -102,6 +102,8 @@ class EagerUtils { ...@@ -102,6 +102,8 @@ class EagerUtils {
static std::shared_ptr<GradNodeBase> grad_node( static std::shared_ptr<GradNodeBase> grad_node(
const paddle::experimental::Tensor& target); const paddle::experimental::Tensor& target);
static paddle::experimental::Tensor* mutable_grad(
const paddle::experimental::Tensor& target);
// Set history is used to set backward info during forward process, it will // Set history is used to set backward info during forward process, it will
// set forward var's autograd meta's grad node as current backward node. // set forward var's autograd meta's grad node as current backward node.
...@@ -173,17 +175,24 @@ class EagerUtils { ...@@ -173,17 +175,24 @@ class EagerUtils {
const std::vector<std::shared_ptr<EagerVariable>>& outs); const std::vector<std::shared_ptr<EagerVariable>>& outs);
static paddle::experimental::Tensor GetOutput( static paddle::experimental::Tensor GetOutput(
const std::shared_ptr<EagerVariable>& out); const std::shared_ptr<EagerVariable>& out);
// Sync Back to origin output Tensor static void GetOutput(const std::shared_ptr<EagerVariable>& out,
static void OverwriteOutputs(const std::shared_ptr<EagerVariable>& out, paddle::experimental::Tensor* out_var);
paddle::experimental::Tensor* tensor); static void GetOutputs(
static void OverwriteOutputs(const paddle::experimental::Tensor& out,
paddle::experimental::Tensor* tensor);
static void OverwriteOutputs(
const std::vector<std::shared_ptr<EagerVariable>>& outs, const std::vector<std::shared_ptr<EagerVariable>>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors); std::vector<paddle::experimental::Tensor>* result);
static void OverwriteOutputs( static void GetOutputs(
const std::vector<paddle::experimental::Tensor>& outs, const std::vector<std::shared_ptr<EagerVariable>>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors); const std::vector<paddle::experimental::Tensor*>& out_var);
static void GetOutputs(const std::shared_ptr<EagerVariable>& out,
std::vector<paddle::experimental::Tensor>* result);
static void GetOutputs(
const std::shared_ptr<EagerVariable>& out,
const std::vector<paddle::experimental::Tensor*>& out_var);
static void Output2Result(
const std::vector<paddle::experimental::Tensor*>& out_var,
std::vector<paddle::experimental::Tensor>* result);
// end Intermidate needed // end Intermidate needed
static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor); static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
......
...@@ -437,8 +437,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}") ...@@ -437,8 +437,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h) configure_file(commit.h.in commit.h)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS
  tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_kernel_info pten_api)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry pten_custom_kernel pten_tensor_raw)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
...@@ -459,4 +458,3 @@ else() ...@@ -459,4 +458,3 @@ else()
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
endif() endif()
cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils) cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils)
cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
...@@ -18,355 +18,24 @@ limitations under the License. */ ...@@ -18,355 +18,24 @@ limitations under the License. */
#endif #endif
#include "paddle/fluid/framework/custom_kernel.h" #include "paddle/fluid/framework/custom_kernel.h"
#include <dirent.h> #include "paddle/phi/core/custom_kernel.h"
#include <algorithm>
#include <regex>
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/ext/op_kernel_info.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// set phi::Kernel args_def_ from op_kernel_info,
// because we cannot set it directly on phi::Kernel without exposing
// phi::KernelArgsDef when parsing the custom user function
static void ParseArgs(const OpKernelInfo& op_kernel_info,
phi::KernelArgsDef* args_def) {
auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info);
auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info);
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
for (auto& input : input_defs) {
auto type_index =
input.is_vector
? std::type_index(typeid(const std::vector<phi::DenseTensor>&))
: std::type_index(typeid(const phi::DenseTensor&));
args_def->AppendInput(input.backend, input.layout, input.dtype, type_index);
}
for (auto& output : output_defs) {
auto type_index =
output.is_vector
? std::type_index(typeid(const std::vector<phi::DenseTensor>&))
: std::type_index(typeid(const phi::DenseTensor&));
args_def->AppendOutput(output.backend, output.layout, output.dtype,
type_index);
}
for (auto& attr : attribute_defs) {
args_def->AppendAttribute(attr.type_index);
}
}
// custom pten kernel call function define
static void RunKernelFunc(phi::KernelContext* ctx,
const OpKernelInfo& op_kernel_info) {
VLOG(3) << "[CUSTOM KERNEL] RunKernelFunc begin...";
// input and output sizes are not the params' num,
// but the actual Tensors' count
size_t input_size = ctx->InputsSize();
size_t output_size = ctx->OutputsSize();
size_t attr_size = ctx->AttrsSize();
// parameters' num of unified user kernel function
auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info);
auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info);
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
PADDLE_ENFORCE_GE(input_size, input_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx inputs size (%d) must be larger than "
"the size of kernel input_defs (%d).",
input_size, input_defs.size()));
PADDLE_ENFORCE_GE(output_size, output_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx outputs size (%d) must be larger than "
"the size of kernel output_defs (%d).",
output_size, output_defs.size()));
PADDLE_ENFORCE_EQ(attr_size, attribute_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx attribute size (%d) must be equal to "
"to the size of kernel attribute_defs (%d).",
attr_size, attribute_defs.size()));
VLOG(3) << "[CUSTOM KERNEL] Input num: " << input_defs.size()
<< "[tensor size:" << input_size << "]"
<< " Attribute num: " << attribute_defs.size()
<< " Output num: " << output_defs.size()
<< "[tensor size:" << output_size << "].";
// Inputs mapping
std::vector<paddle::experimental::Tensor> custom_ins;
std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
for (size_t in_idx = 0; in_idx < input_defs.size(); ++in_idx) {
VLOG(3) << "Mapping Input[" << in_idx << "]";
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
// is_vector tells if this Input is Tensor or std::vector<Tensor>
if (!input_defs.at(in_idx).is_vector) {
paddle::experimental::Tensor custom_t;
auto& ctx_tensor = ctx->InputAt<phi::DenseTensor>(range.first);
custom_t.set_impl(std::make_shared<phi::DenseTensor>(ctx_tensor));
custom_ins.emplace_back(custom_t);
} else {
std::vector<paddle::experimental::Tensor> custom_vec_in;
auto ctx_tensor_vec =
ctx->MoveInputsBetween<phi::DenseTensor>(range.first, range.second);
for (auto& ctx_tensor : ctx_tensor_vec) {
paddle::experimental::Tensor custom_t;
custom_t.set_impl(std::make_shared<phi::DenseTensor>(ctx_tensor));
custom_vec_in.emplace_back(custom_t);
}
custom_vec_ins.emplace_back(custom_vec_in);
}
VLOG(3) << "Mapped Input[" << in_idx << "] with range[" << range.first
<< "," << range.second << ").";
}
// Attributes mapping
std::vector<paddle::any> custom_attrs;
for (size_t attr_idx = 0; attr_idx < attribute_defs.size(); ++attr_idx) {
VLOG(3) << "Mapping Attribute[" << attr_idx << "]";
if (attribute_defs[attr_idx].type_index == std::type_index(typeid(bool))) {
bool arg = ctx->AttrAt<bool>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(int))) {
int arg = ctx->AttrAt<int>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(float))) {
float arg = ctx->AttrAt<float>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(double))) {
double arg = ctx->AttrAt<double>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(int64_t))) {
int64_t arg = ctx->AttrAt<int64_t>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(phi::dtype::float16))) {
phi::dtype::float16 arg = ctx->AttrAt<phi::dtype::float16>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(DataType))) {
DataType arg = ctx->AttrAt<DataType>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const Scalar&))) {
const Scalar& arg = ctx->AttrAt<const Scalar&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const std::vector<int64_t>&))) {
const std::vector<int64_t>& arg =
ctx->AttrAt<const std::vector<int64_t>&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const ScalarArray&))) {
const ScalarArray& arg = ctx->AttrAt<const ScalarArray&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const std::vector<int>&))) {
const std::vector<int>& arg =
ctx->AttrAt<const std::vector<int>&>(attr_idx);
custom_attrs.emplace_back(arg);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported attribute attribute_defs[%d].type_index", attr_idx));
}
VLOG(3) << "Mapped Attribute[" << attr_idx << "]";
}
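// A standalone illustration (not part of this diff; the helper name is
// hypothetical) of the std::type_index comparisons the attribute-mapping
// loop above relies on to dispatch untyped attributes:
//
//   const char* AttrTypeName(const std::type_index& t) {
//     if (t == std::type_index(typeid(bool))) return "bool";
//     if (t == std::type_index(typeid(int))) return "int";
//     if (t == std::type_index(typeid(float))) return "float";
//     if (t == std::type_index(typeid(int64_t))) return "int64_t";
//     return "unknown";  // mirrors the PADDLE_THROW fallback above
//   }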
// Outputs mapping
std::vector<paddle::experimental::Tensor*> custom_outs;
std::vector<std::vector<paddle::experimental::Tensor*>> custom_vec_outs;
std::vector<std::shared_ptr<phi::DenseTensor>> custom_outs_ptr;
std::vector<std::vector<std::shared_ptr<phi::DenseTensor>>>
custom_vec_outs_ptr;
for (size_t out_idx = 0; out_idx < output_defs.size(); ++out_idx) {
VLOG(3) << "Mapping Output[" << out_idx << "]";
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
// is_vector tells if this Output is Tensor or std::vector<Tensor>
if (!output_defs.at(out_idx).is_vector) {
auto* ctx_tensor = ctx->MutableOutputAt<phi::DenseTensor>(range.first);
auto* custom_t = new paddle::experimental::Tensor();
auto custom_t_ptr = std::make_shared<phi::DenseTensor>(*ctx_tensor);
custom_t->set_impl(custom_t_ptr);
custom_outs.emplace_back(custom_t);
custom_outs_ptr.emplace_back(custom_t_ptr);
} else {
std::vector<paddle::experimental::Tensor*> custom_vec_out;
std::vector<std::shared_ptr<phi::DenseTensor>> custom_vec_out_ptr;
auto ctx_tensor_vec = ctx->MutableOutputBetween<phi::DenseTensor>(
range.first, range.second);
for (auto ctx_tensor : ctx_tensor_vec) {
auto* custom_t = new paddle::experimental::Tensor();
auto custom_t_ptr = std::make_shared<phi::DenseTensor>(*ctx_tensor);
custom_t->set_impl(custom_t_ptr);
custom_vec_out.emplace_back(custom_t);
custom_vec_out_ptr.emplace_back(custom_t_ptr);
}
custom_vec_outs.emplace_back(custom_vec_out);
custom_vec_outs_ptr.emplace_back(custom_vec_out_ptr);
}
VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first
<< "," << range.second << ").";
}
// DeviceContext
// In pten, the first parameter XXContext is decided at registration time
// through a template param, but the custom kernel function uses a unified
// DeviceContext as the first parameter of user_kernel_fn, so we use the
// backend from OpKernelInfo to decide the XXContext. In this temporary
// simple DeviceContext, we just set the necessary info on dev_ctx (such as
// the stream in NPUContext); more related work should be done when
// phi::DeviceContext is exposed to the outside.
DeviceContext dev_ctx;
auto& backend = OpKernelInfoHelper::GetBackend(op_kernel_info);
if (backend == phi::Backend::CPU) {
// do nothing
} else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(phi::Backend::ALL_BACKEND);
std::string device_type = phi::GetGlobalDeviceType(device_type_id_);
if (!device_type.empty()) {
auto custom_ctx =
ctx->GetDeviceContext<paddle::platform::CustomDeviceContext>();
dev_ctx.set_stream(custom_ctx.stream());
return;
}
#endif
LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend
<< " with compiled Paddle.";
return;
}
auto& user_kernel_fn = OpKernelInfoHelper::GetKernelFn(op_kernel_info);
// call user function
user_kernel_fn(dev_ctx, custom_ins, custom_vec_ins, custom_attrs,
&custom_outs, &custom_vec_outs);
VLOG(3) << "[CUSTOM KERNEL] finished call user kernel function.";
// NOTE: Map back the output tensors with stored shared_ptrs.
for (int out_idx = output_defs.size() - 1; out_idx >= 0; --out_idx) {
VLOG(3) << "Mapping Back Output[" << out_idx << "]";
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
// is_vector tells if this Output is Tensor or std::vector<Tensor>
if (!output_defs.at(out_idx).is_vector) {
auto* ctx_tensor = ctx->MutableOutputAt<phi::DenseTensor>(range.first);
*ctx_tensor = *(custom_outs_ptr.back().get());
custom_outs_ptr.pop_back();
} else {
auto ctx_tensor_vec = ctx->MutableOutputBetween<phi::DenseTensor>(
range.first, range.second);
auto custom_vec_ptr_out = custom_vec_outs_ptr.back();
for (int idx = ctx_tensor_vec.size() - 1; idx >= 0; --idx) {
*(ctx_tensor_vec[idx]) = *(custom_vec_ptr_out.back().get());
custom_vec_ptr_out.pop_back();
}
custom_vec_outs_ptr.pop_back();
}
VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first
<< "," << range.second << "].";
}
// delete newed paddle::Tensor for outputs while calling user kernel function
for (size_t i = 0; i < custom_outs.size(); ++i) {
delete custom_outs[i];
}
for (size_t i = 0; i < custom_vec_outs.size(); ++i) {
for (size_t j = 0; j < custom_vec_outs[i].size(); ++j) {
delete custom_vec_outs[i][j];
}
}
}
void RegisterKernelWithMetaInfo(
const std::vector<OpKernelInfo>& op_kernel_infos) {
for (size_t i = 0; i < op_kernel_infos.size(); ++i) {
auto& kernel_info = op_kernel_infos[i];
auto op_type = OpKernelInfoHelper::GetOpName(kernel_info);
auto kernel_key = OpKernelInfoHelper::GetKernelKey(kernel_info);
VLOG(3) << "[CUSTOM KERNEL] registering [" << op_type << "]" << kernel_key;
// 1. Check whether this kernel is valid for a specific operator
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type), true,
platform::errors::InvalidArgument(
"[CUSTOM KERNEL] %s is not ready for custom kernel registering.",
op_type));
// 2. Check whether kernel_key has already been registered
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().kernels()[op_type].find(kernel_key),
phi::KernelFactory::Instance().kernels()[op_type].end(),
platform::errors::InvalidArgument(
"[CUSTOM KERNEL] The operator <%s>'s kernel: %s has been "
"already existed in Paddle, please contribute PR if need "
"to optimize the kernel code. Custom kernel do NOT support "
"to replace existing kernel in Paddle.",
op_type, kernel_key));
// phi::KernelFn
phi::KernelFn kernel_fn = [kernel_info](phi::KernelContext* ctx) {
VLOG(3) << "[CUSTOM KERNEL] run custom PTEN kernel func in lambda.";
RunKernelFunc(ctx, kernel_info);
};
// variadic_kernel_fn
void* variadic_kernel_fn =
OpKernelInfoHelper::GetVariadicKernelFn(kernel_info);
phi::Kernel kernel(kernel_fn, variadic_kernel_fn);
// args info
ParseArgs(kernel_info, kernel.mutable_args_def());
// register custom kernel to phi::KernelFactory
phi::KernelFactory::Instance().kernels()[op_type][kernel_key] = kernel;
VLOG(3) << "[CUSTOM KERNEL] Successed in registering operator <" << op_type
<< ">'s kernel " << kernel_key << " to Paddle. "
<< "It will be used like native ones.";
}
}
void RegisterKernelWithMetaInfoMap(
const paddle::OpKernelInfoMap& op_kernel_info_map) {
auto& kernel_info_map = op_kernel_info_map.GetMap();
VLOG(3) << "[CUSTOM KERNEL] size of op_kernel_info_map: "
<< kernel_info_map.size();
// pair: {op_type, OpKernelInfo}
for (auto& pair : kernel_info_map) {
VLOG(3) << "[CUSTOM KERNEL] pair first -> op name: " << pair.first;
RegisterKernelWithMetaInfo(pair.second);
}
}
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) { void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX #ifdef _LINUX
typedef OpKernelInfoMap& get_op_kernel_info_map_t(); typedef phi::CustomKernelMap& get_custom_kernel_map_t();
auto* func = reinterpret_cast<get_op_kernel_info_map_t*>( auto* func = reinterpret_cast<get_custom_kernel_map_t*>(
dlsym(dso_handle, "PD_GetOpKernelInfoMap")); dlsym(dso_handle, "PD_GetCustomKernelMap"));
if (func == nullptr) { if (func == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find " LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetOpKernelInfoMap symbol in this lib."; << "PD_GetCustomKernelMap symbol in this lib.";
return; return;
} }
auto& op_kernel_info_map = func(); auto& custom_kernel_map = func();
RegisterKernelWithMetaInfoMap(op_kernel_info_map); phi::RegisterCustomKernels(custom_kernel_map);
LOG(INFO) << "Succeeded in loading custom kernels in lib: " << dso_lib_path; LOG(INFO) << "Succeeded in loading custom kernels in lib: " << dso_lib_path;
#else #else
VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux."; VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux.";
......
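// A hedged sketch (not part of this diff) of the plugin side implied by the
// dlsym lookup above: a custom-kernel shared library exports a C symbol
// named PD_GetCustomKernelMap returning phi::CustomKernelMap&. How Paddle's
// registration macros actually generate this symbol is not shown here, so
// treat the body as illustrative only.
//
//   extern "C" phi::CustomKernelMap& PD_GetCustomKernelMap() {
//     static phi::CustomKernelMap custom_kernel_map;  // filled by kernel
//                                                     // registration
//     return custom_kernel_map;
//   }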
...@@ -14,22 +14,13 @@ limitations under the License. */ ...@@ -14,22 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/api/ext/op_kernel_info.h" #include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Load custom kernel lib and register
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle); void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
// Load custom kernel api: register kernel after user compiled
void LoadOpKernelInfoAndRegister(const std::string& dso_name);
// Register custom kernel api: register kernel directly
void RegisterKernelWithMetaInfoMap(
const paddle::OpKernelInfoMap& op_kernel_info_map);
// Interface for selective register custom kernel.
void RegisterKernelWithMetaInfo(
const std::vector<OpKernelInfo>& op_kernel_infos);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DECLARE_bool(sync_nccl_allreduce); DECLARE_bool(sync_nccl_allreduce);
...@@ -47,6 +48,8 @@ GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle( ...@@ -47,6 +48,8 @@ GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle(
#endif #endif
void GradMergeAllReduceOpHandle::RunImpl() { void GradMergeAllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(
Name(), platform::TracerEventType::Communication, 1);
PADDLE_ENFORCE_GT(local_scopes_.size(), 0, PADDLE_ENFORCE_GT(local_scopes_.size(), 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The number of local scope should be > 0, but got %zu.", "The number of local scope should be > 0, but got %zu.",
...@@ -96,6 +99,8 @@ FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle( ...@@ -96,6 +99,8 @@ FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle(
#endif #endif
void FusedGradMergeAllReduceOpHandle::RunImpl() { void FusedGradMergeAllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(
Name(), platform::TracerEventType::Communication, 1);
PADDLE_ENFORCE_GT(local_scopes_.size(), 0, PADDLE_ENFORCE_GT(local_scopes_.size(), 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The number of local scope should be > 0, but got %zu.", "The number of local scope should be > 0, but got %zu.",
......
...@@ -10,6 +10,8 @@ IF(WITH_GPU) ...@@ -10,6 +10,8 @@ IF(WITH_GPU)
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS}) nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
ENDIF() ENDIF()
IF(WITH_ROCM) IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "heter_comm.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
 \   |\         |\
 17  8 9        1 2
we save the nodes in arbitrary order,
in this example, the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
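// A small self-contained sketch (not part of this diff; BuildExampleGraph is
// a hypothetical helper) that materializes the example documented above,
// using only the fields declared in this header.
inline GpuPsCommGraph BuildExampleGraph() {
  static int64_t neighbors[16] = {3, 17, 3, 7, 7, 7, 1, 2,
                                  5, 0,  5, 8, 9, 3, 3, 0};
  static GpuPsGraphNode nodes[9];
  const int64_t u_id[9] = {0, 5, 1, 2, 7, 3, 8, 9, 17};
  const int sizes[9] = {2, 2, 1, 1, 3, 4, 1, 1, 1};
  const int offsets[9] = {0, 2, 4, 5, 6, 9, 13, 14, 15};
  for (int i = 0; i < 9; ++i) {
    nodes[i].node_id = u_id[i];
    nodes[i].neighbor_size = sizes[i];
    nodes[i].neighbor_offset = offsets[i];
  }
  return GpuPsCommGraph(neighbors, nodes, /*neighbor_size_=*/16,
                        /*node_size_=*/9);
}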
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource)
: HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25;
}
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult *graph_neighbor_sample(int gpu_id, int64_t *key,
int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
int sample_size, int *h_left,
int *h_right,
int64_t *src_sample_res,
int *actual_sample_size);
private:
std::vector<GpuPsCommGraph> gpu_graph_list;
};
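// A hedged usage sketch (it mirrors the unit test later in this commit; the
// variable names are assumptions, not part of this header):
//   auto resource = std::make_shared<HeterPsResource>(dev_ids);
//   resource->enable_p2p();
//   GpuPsGraphTable table(resource);
//   table.build_graph_from_cpu(cpu_graphs);
//   auto* sample = table.graph_neighbor_sample(0, d_keys, sample_size, len);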
}  // namespace framework
}  // namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h"
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
/*
comment 0
this kernel just serves as an example of how to sample nodes' neighbors;
feel free to modify it.
index[0,len) saves the nodes' indexes;
actual_size[0,len) saves the sample size of each node:
for the ith node in index, actual_size[i] = min(node i's neighbor size, sample_size).
sample_result saves the neighbor sampling result; its size is len *
sample_size.
*/
__global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
int* actual_size,
int64_t* sample_result, int sample_size,
int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto node_index = index[i];
actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size
? graph.node_list[node_index].neighbor_size
: sample_size;
int offset = graph.node_list[node_index].neighbor_offset;
for (int j = 0; j < actual_size[i]; j++) {
sample_result[sample_size * i + j] = graph.neighbor_list[offset + j];
}
}
}
/*
comment 1
gpu i triggers a neighbor_sample task;
when this task is done,
this function is called to move the sample results on the other gpus back
to gpu i and aggregate them.
the sample results are saved in src_sample_res and the actual sample size for
each node is saved in actual_sample_size.
the actual number of sampled neighbors of
key[x] (refer to comment 2 for the definition of key)
is saved in actual_sample_size[x], since the neighbor size of key[x] might be
smaller than sample_size;
key[x]'s sample result is saved in src_sample_res[x*sample_size, x*sample_size +
actual_sample_size[x]).
before each gpu runs the neighbor_sample task, the key array is shuffled,
but we have the idx array to save the original order:
when gpu i gets all the sample results from the other gpus, it relies on the
idx array to recover the original order.
that's what fill_dvalues does.
*/
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
int64_t* src_sample_res, int* actual_sample_size) {
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto shard_len = h_right[i] - h_left[i] + 1;
// int cur_step = path_[gpu_id][i].nodes_.size() - 1;
// auto& node = path_[gpu_id][i].nodes_[cur_step];
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
node.val_storage + sizeof(int64_t) * shard_len,
node.val_bytes_len - sizeof(int64_t) * shard_len, cudaMemcpyDefault,
node.out_stream);
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
}
}
/*
TODO:
how to optimize it to eliminate the for loop;
a possible flat-indexing sketch follows fill_dvalues below.
*/
__global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
int* d_shard_actual_sample_size,
int* d_actual_sample_size, int* idx,
int sample_size, int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
// d_vals[idx[i]] = d_shard_vals[i];
for (int j = 0; j < sample_size; j++) {
d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
}
}
}
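/*
a hedged sketch of one way to address the TODO above: launch one thread per
(node, slot) pair so the per-thread loop disappears. this kernel is not part
of the original source; its name and launch configuration ((len * sample_size
- 1) / block_size + 1 blocks) are assumptions.
*/
__global__ void fill_dvalues_flat(int64_t* d_shard_vals, int64_t* d_vals,
                                  int* d_shard_actual_sample_size,
                                  int* d_actual_sample_size, int* idx,
                                  int sample_size, int len) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < static_cast<size_t>(len) * sample_size) {
    int i = tid / sample_size;  // which node
    int j = tid % sample_size;  // which sample slot of that node
    // let one thread per node copy the actual sample size
    if (j == 0) d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
    d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
  }
}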
__global__ void node_query_example(GpuPsCommGraph graph, int start, int size,
int64_t* res) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < size) {
res[i] = graph.node_list[start + i].node_id;
}
}
void GpuPsGraphTable::clear_graph_info() {
if (tables_.size()) {
for (auto table : tables_) delete table;
}
tables_.clear();
for (auto graph : gpu_graph_list) {
if (graph.neighbor_list != NULL) {
cudaFree(graph.neighbor_list);
}
if (graph.node_list != NULL) {
cudaFree(graph.node_list);
}
}
gpu_graph_list.clear();
}
/*
the parameter std::vector<GpuPsCommGraph> cpu_graph_list is generated by the cpu;
it saves the graph to be stored on each gpu.
for the ith GpuPsCommGraph, every node's key satisfies key % gpu_number == i.
In this function, memory is allocated on each gpu to save the graphs;
gpu i saves the ith graph from cpu_graph_list.
*/
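/*
for example (a hedged illustration that matches the unit test later in this
commit): with 3 gpus, node keys {0,3,6,9} land in cpu_graph_list[0],
{1,4,7} in cpu_graph_list[1], and {2,5,8} in cpu_graph_list[2].
*/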
void GpuPsGraphTable::build_graph_from_cpu(
std::vector<GpuPsCommGraph>& cpu_graph_list) {
PADDLE_ENFORCE_EQ(
cpu_graph_list.size(), resource_->total_gpu(),
platform::errors::InvalidArgument("the cpu node list size doesn't match "
"the number of gpu on your machine."));
clear_graph_info();
for (int i = 0; i < cpu_graph_list.size(); i++) {
platform::CUDADeviceGuard guard(resource_->dev_id(i));
gpu_graph_list.push_back(GpuPsCommGraph());
auto table =
new Table(std::max(1, cpu_graph_list[i].node_size) / load_factor_);
tables_.push_back(table);
if (cpu_graph_list[i].node_size > 0) {
std::vector<int64_t> keys;
std::vector<int> offset;
cudaMalloc((void**)&gpu_graph_list[i].node_list,
cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode));
cudaMemcpy(gpu_graph_list[i].node_list, cpu_graph_list[i].node_list,
cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode),
cudaMemcpyHostToDevice);
for (int j = 0; j < cpu_graph_list[i].node_size; j++) {
keys.push_back(cpu_graph_list[i].node_list[j].node_id);
offset.push_back(j);
}
build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8);
gpu_graph_list[i].node_size = cpu_graph_list[i].node_size;
} else {
gpu_graph_list[i].node_list = NULL;
gpu_graph_list[i].node_size = 0;
}
if (cpu_graph_list[i].neighbor_size) {
cudaMalloc((void**)&gpu_graph_list[i].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(int64_t));
cudaMemcpy(gpu_graph_list[i].neighbor_list,
cpu_graph_list[i].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(int64_t),
cudaMemcpyHostToDevice);
gpu_graph_list[i].neighbor_size = cpu_graph_list[i].neighbor_size;
} else {
gpu_graph_list[i].neighbor_list = NULL;
gpu_graph_list[i].neighbor_size = 0;
}
}
cudaDeviceSynchronize();
}
NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int64_t* key,
int sample_size,
int len) {
/*
comment 2
this function shares some kernels with heter_comm_inl.h
arguments definitions:
gpu_id: the id of the gpu.
len: how many keys are used (the length of the key array).
sample_size: how many neighbors should be sampled for each node in key.
the code below shuffles the key array so that keys
belonging to the same gpu card stay together;
the shuffled result is saved in d_shard_keys.
if the ith element in d_shard_keys_ptr comes
from the jth element in the original key array, then idx[i] = j;
idx can be used to recover the original array.
if keys in range [a,b] belong to the ith gpu, then h_left[i] = a, h_right[i] =
b;
if no keys are allocated to the ith gpu, then h_left[i] == h_right[i] == -1.
for example, suppose key = [0,1,2,3,4,5,6,7,8], gpu_num = 2;
when we run this neighbor_sample function,
the key array is shuffled to [0,2,4,6,8,1,3,5,7]:
the first part (0,2,4,6,8) % 2 == 0 and thus should be handled by gpu 0;
the rest should be handled by gpu 1, because (1,3,5,7) % 2 == 1;
h_left = [0,5], h_right = [4,8]
*/
NeighborSampleResult* result = new NeighborSampleResult(sample_size, len);
if (len == 0) {
return result;
}
cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_id, 0);
int grid_size = (len - 1) / block_size_ + 1;
int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT
auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
//
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
fill_shard_key<<<grid_size, block_size_, 0, stream>>>(d_shard_keys_ptr, key,
d_idx_ptr, len);
cudaStreamSynchronize(stream);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
/*
comment 3
shard_len denotes the number of keys on the i-th gpu here.
when we sample on the i-th gpu, we allocate shard_len * (1 + sample_size)
int64_t units
of memory; we use alloc_mem_i to denote it. the range [0,shard_len) is reserved
for the respective nodes' indexes
and actual sample sizes;
with the nodes' indexes we can locate the nodes to sample.
since the size of int64_t is 8 bytes while the size of int is 4,
the range [0,shard_len) contains shard_len * 2 int units;
the values of the first half of this range will be updated by
the k-v map on the i-th gpu,
and the second half of this range is reserved for the actual sample size of
each node.
for node x,
its sampling result is saved in the range
[shard_len + sample_size * x, shard_len + sample_size * x +
actual_sample_size_of_x)
of alloc_mem_i, where actual_sample_size_of_x equals ((int
*)alloc_mem_i)[shard_len + x]
*/
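    /*
    layout sketch of alloc_mem_i in int64_t units (a hedged illustration,
    assuming shard_len = 3 and sample_size = 2):
      units [0,3): the shard's indexes and actual sample sizes, viewed
                   as 6 ints (3 indexes followed by 3 actual sizes)
      units [3,5): samples of node 0; [5,7): node 1; [7,9): node 2
    */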
create_storage(gpu_id, i, shard_len * sizeof(int64_t),
shard_len * (1 + sample_size) * sizeof(int64_t));
}
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
// auto& node = path_[gpu_id][i].nodes_.back();
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
// use the key-value map to update alloc_mem_i[0,shard_len)
tables_[i]->rwlock_->RDLock();
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
reinterpret_cast<int*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id));
}
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
// cudaStreamSynchronize(resource_->remote_stream(i, num));
// tables_[i]->rwlock_->UNLock();
platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto& node = path_[gpu_id][i].nodes_.front();
auto shard_len = h_right[i] - h_left[i] + 1;
auto graph = gpu_graph_list[i];
int* res_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = res_array + shard_len;
int64_t* sample_array = (int64_t*)(res_array + shard_len * 2);
neighbor_sample_example<<<grid_size, block_size_, 0,
resource_->remote_stream(i, gpu_id)>>>(
graph, res_array, actual_size_array, sample_array, sample_size,
shard_len);
}
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
tables_[i]->rwlock_->UNLock();
}
// walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
d_idx_ptr, sample_size, len);
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
destroy_storage(gpu_id, i);
}
return result;
}
NodeQueryResult* GpuPsGraphTable::graph_node_sample(int gpu_id,
                                                    int sample_size) {
  // not implemented yet; return null explicitly so this non-void function
  // does not fall off the end (undefined behavior if the result were used)
  return nullptr;
}
NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start,
int query_size) {
NodeQueryResult* result = new NodeQueryResult();
if (query_size <= 0) return result;
int& actual_size = result->actual_sample_size;
actual_size = 0;
cudaMalloc((void**)&result->val, query_size * sizeof(int64_t));
int64_t* val = result->val;
int dev_id = resource_->dev_id(gpu_id);
platform::CUDADeviceGuard guard(dev_id);
std::vector<int> idx, gpu_begin_pos, local_begin_pos, sample_size;
int size = 0;
/*
if idx[i] = a, gpu_begin_pos[i] = p1,
local_begin_pos[i] = p2,
and sample_size[i] = s,
then on gpu a, the nodes at positions [p1, p1 + s) should be returned
and saved starting at position p2 of the result array
for example:
suppose
gpu 0 saves [0,2,4,6,8], gpu1 saves [1,3,5,7]
start = 3, query_size = 5
we know [6,8,1,3,5] should be returned;
idx = [0,1]
gpu_begin_pos = [3,0]
local_begin_pos = [0,3]
sample_size = [2,3]
*/
for (int i = 0; i < gpu_graph_list.size() && query_size != 0; i++) {
auto graph = gpu_graph_list[i];
if (graph.node_size == 0) {
continue;
}
if (graph.node_size + size > start) {
int cur_size = min(query_size, graph.node_size + size - start);
query_size -= cur_size;
idx.emplace_back(i);
gpu_begin_pos.emplace_back(start - size);
local_begin_pos.emplace_back(actual_size);
start += cur_size;
actual_size += cur_size;
sample_size.emplace_back(cur_size);
create_storage(gpu_id, i, 1, cur_size * sizeof(int64_t));
}
size += graph.node_size;
}
for (int i = 0; i < idx.size(); i++) {
int dev_id_i = resource_->dev_id(idx[i]);
platform::CUDADeviceGuard guard(dev_id_i);
auto& node = path_[gpu_id][idx[i]].nodes_.front();
int grid_size = (sample_size[i] - 1) / block_size_ + 1;
node_query_example<<<grid_size, block_size_, 0,
resource_->remote_stream(idx[i], gpu_id)>>>(
gpu_graph_list[idx[i]], gpu_begin_pos[i], sample_size[i],
(int64_t*)node.val_storage);
}
for (int i = 0; i < idx.size(); i++) {
cudaStreamSynchronize(resource_->remote_stream(idx[i], gpu_id));
auto& node = path_[gpu_id][idx[i]].nodes_.front();
cudaMemcpyAsync(reinterpret_cast<char*>(val + local_begin_pos[i]),
node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < idx.size(); i++) {
auto& node = path_[gpu_id][idx[i]].nodes_.front();
cudaStreamSynchronize(node.out_stream);
}
return result;
}
}  // namespace framework
}  // namespace paddle
#endif
...@@ -173,16 +173,18 @@ class HeterComm { ...@@ -173,16 +173,18 @@ class HeterComm {
void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right, void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
ValType* src_val); ValType* src_val);
private: protected:
using Table = HashTable<KeyType, ValType>; using Table = HashTable<KeyType, ValType>;
int block_size_{256};
float load_factor_{0.75};
std::vector<Table*> tables_; std::vector<Table*> tables_;
std::shared_ptr<HeterPsResource> resource_; std::shared_ptr<HeterPsResource> resource_;
CustomGradMerger merger_;
int topo_aware_{0};
std::vector<std::vector<Path>> path_; std::vector<std::vector<Path>> path_;
float load_factor_{0.75};
int block_size_{256};
private:
std::vector<LocalStorage> storage_; std::vector<LocalStorage> storage_;
CustomGradMerger merger_;
int topo_aware_{0};
int feanum_{1800 * 2048}; int feanum_{1800 * 2048};
int multi_node_{0}; int multi_node_{0};
std::vector<ncclComm_t> nccl_inner_comms_; std::vector<ncclComm_t> nccl_inner_comms_;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
TEST(TEST_FLEET, graph_comm) {
int gpu_count = 3;
std::vector<int> dev_ids;
dev_ids.push_back(0);
dev_ids.push_back(1);
dev_ids.push_back(2);
std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(dev_ids);
resource->enable_p2p();
GpuPsGraphTable g(resource);
int node_count = 10;
std::vector<std::vector<int64_t>> neighbors(node_count);
int ind = 0;
int64_t node_id = 0;
std::vector<GpuPsCommGraph> graph_list(gpu_count);
while (ind < node_count) {
int neighbor_size = ind + 1;
graph_list[ind % gpu_count].node_size++;
graph_list[ind % gpu_count].neighbor_size += neighbor_size;
while (neighbor_size--) {
neighbors[ind].push_back(node_id++);
}
ind++;
}
std::vector<int> neighbor_offset(gpu_count, 0), node_index(gpu_count, 0);
for (int i = 0; i < graph_list.size(); i++) {
graph_list[i].node_list = new GpuPsGraphNode[graph_list[i].node_size];
graph_list[i].neighbor_list = new int64_t[graph_list[i].neighbor_size];
}
for (int i = 0; i < node_count; i++) {
ind = i % gpu_count;
graph_list[ind].node_list[node_index[ind]].node_id = i;
graph_list[ind].node_list[node_index[ind]].neighbor_offset =
neighbor_offset[ind];
graph_list[ind].node_list[node_index[ind]].neighbor_size =
neighbors[i].size();
for (auto x : neighbors[i]) {
graph_list[ind].neighbor_list[neighbor_offset[ind]++] = x;
}
node_index[ind]++;
}
g.build_graph_from_cpu(graph_list);
/*
gpu 0:
0,3,6,9
gpu 1:
1,4,7
gpu 2:
2,5,8
query(2,6) returns nodes [6,9,1,4,7,2]
*/
int64_t answer[6] = {6, 9, 1, 4, 7, 2};
int64_t *res = new int64_t[6];
auto query_res = g.query_node_list(0, 2, 6);
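  // 48 bytes = 6 expected results * sizeof(int64_t)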
cudaMemcpy(res, query_res->val, 48, cudaMemcpyDeviceToHost);
ASSERT_EQ(query_res->actual_sample_size, 6);
for (int i = 0; i < 6; i++) {
ASSERT_EQ(res[i], answer[i]);
}
delete[] res;
delete query_res;
/*
node x's neighbor list = [(1+x)*x/2,(1+x)*x/2 + 1,.....,(1+x)*x/2 + x]
so node 6's neighbors are [21,22...,27]
node 7's neighbors are [28,29,..35]
node 0's neighbors are [0]
query([7,0,6], sample_size=3) should return [28,29,30,0,x,x,21,22,23]
(x denotes an unfilled slot); each key's index within its gpu's table:
7 --index--> 2 (on gpu 1)
0 --index--> 0 (on gpu 0)
6 --index--> 2 (on gpu 0)
*/
int64_t cpu_key[3] = {7, 0, 6};
void *key;
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
res = new int64_t[9];
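  // 72 bytes = 3 keys * sample_size(3) * sizeof(int64_t)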
cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost);
int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
for (int i = 0; i < 9; i++) {
if (expected_sample_val[i] != -1) {
ASSERT_EQ(res[i], expected_sample_val[i]);
}
}
delete[] res;
delete neighbor_sample_res;
}
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
...@@ -376,47 +377,101 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -376,47 +377,101 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name)); attr_name));
} }
} }
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::Scalar))) {
if (ctx->HasAttr(attr_name)) {
// TODO(chentianyu03): support other attrs later
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(float, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"InferMetaContext.",
attr_name));
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarFromVar(*var)));
} else {
phi::Scalar tensor_scalar(-1);
tensor_scalar.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid input.size() when cast op attribute `%s` to Scalar, "
"expected 1, but actually is %d .",
attr_name, infershape_input.size()));
}
}
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr. // Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) == std::type_index(typeid(int))) { } else if (attr_defs[i].type_index == std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
std::type_index(typeid(int64_t))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) { std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<bool>))) { std::type_index(typeid(std::vector<bool>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr)); BOOST_GET_CONST(std::vector<bool>, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr)); BOOST_GET_CONST(std::vector<int>, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr( if (std::type_index(attr.type()) ==
BOOST_GET_CONST(std::vector<int64_t>, attr)); std::type_index(typeid(std::vector<int>))) {
} else if (std::type_index(attr.type()) == // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
infer_meta_context.EmplaceBackAttr(vector_int64_attr);
} else {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) { std::type_index(typeid(std::vector<float>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr)); BOOST_GET_CONST(std::vector<float>, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<double>))) { std::type_index(typeid(std::vector<double>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr)); BOOST_GET_CONST(std::vector<double>, attr));
} else if (std::type_index(attr.type()) == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) { std::type_index(typeid(std::vector<std::string>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr)); BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
infer_meta_context.EmplaceBackAttr(data_type);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported attribute type is received when call " "Unsupported attribute type is received when call "
......
...@@ -118,7 +118,7 @@ REGISTER_OPERATOR(infer_shape_utils_test, ...@@ -118,7 +118,7 @@ REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOpMaker, paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor); InferShapeUtilsTestInferShapeFunctor);
PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT, PD_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
paddle::framework::InferShapeUtilsTestKernel, int) {} paddle::framework::InferShapeUtilsTestKernel, int) {}
TEST(InferShapeUtilsTest, ALL) { TEST(InferShapeUtilsTest, ALL) {
......
...@@ -147,7 +147,7 @@ if(WITH_IPU) ...@@ -147,7 +147,7 @@ if(WITH_IPU)
pass_library(ipu_runtime_replacer_pass base DIR ipu) pass_library(ipu_runtime_replacer_pass base DIR ipu)
pass_library(inference_process_pass base DIR ipu) pass_library(inference_process_pass base DIR ipu)
pass_library(inference_postprocess_pass base DIR ipu) pass_library(inference_postprocess_pass base DIR ipu)
pass_library(popart_canonicalization_pass base DIR ipu) pass_library(popart_canonicalization_pass base DIR ipu DEPS paddle_ipu)
pass_library(ipu_inplace_pass base DIR ipu) pass_library(ipu_inplace_pass base DIR ipu)
pass_library(infer_shape_pass base DIR ipu) pass_library(infer_shape_pass base DIR ipu)
pass_library(delete_scale_op_pass base DIR ipu) pass_library(delete_scale_op_pass base DIR ipu)
......
...@@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() { ...@@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op; return op;
} }
PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}
PDNode *patterns::MKLDNNInPlace::operator()() { PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = { const std::unordered_set<std::string> &supported_op_types = {
"abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"}; "abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"};
......
...@@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase { ...@@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op); PATTERN_DECL_NODE(op);
}; };
struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
};
// Pattern used for enforcing inplace computation for in-place computation // Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm // supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase { struct MKLDNNInPlace : public PatternBase {
......
...@@ -56,7 +56,7 @@ const bool is_regularization_op(const std::string& op_namescope) { ...@@ -56,7 +56,7 @@ const bool is_regularization_op(const std::string& op_namescope) {
} }
void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
// the op built here conforms to popart's definition; some of the values involved need to be obtained in LowerOptimizer // optimizer values will be extracted when lowering optimizer in ipu_backend
OpDesc new_op("popart_optimizer", {}, {}, {}); OpDesc new_op("popart_optimizer", {}, {}, {});
new_op.SetAttr("op_role", 0); new_op.SetAttr("op_role", 0);
new_op.SetAttr("with_lr_sched", false); new_op.SetAttr("with_lr_sched", false);
...@@ -86,7 +86,7 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { ...@@ -86,7 +86,7 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
bool is_regularization = is_regularization_op(op_namescope); bool is_regularization = is_regularization_op(op_namescope);
VLOG(10) << "found optimizer releated op: " << op_type; VLOG(10) << "found optimizer releated op: " << op_type;
// initial learning_rate will be set in LowerOptimizer // initial learning_rate will be set in ipu_backend
set_ops.insert(op_type); set_ops.insert(op_type);
if (op_type == "sgd") { if (op_type == "sgd") {
auto type = std::string{"sgd"}; auto type = std::string{"sgd"};
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h" #include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" #include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
...@@ -28,11 +29,8 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const { ...@@ -28,11 +29,8 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
auto custom_ops = Get<std::unordered_set<std::string>>("custom_ops"); auto custom_ops = Get<std::unordered_set<std::string>>("custom_ops");
std::vector<std::string> missing_ops; std::vector<std::string> missing_ops;
auto nodes = graph->Nodes(); auto sorted_ops = TopologySortOperations(*graph);
for (auto* node : nodes) { for (auto* node : sorted_ops) {
if (!node->IsOp()) {
continue;
}
auto* op = node->Op(); auto* op = node->Op();
auto op_type = op->Type(); auto op_type = op->Type();
......
...@@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) { ...@@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) {
} }
void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int* quantize_counter) { int& quantize_counter) {
std::vector<std::string> input_names; std::vector<std::string> input_names;
// Find the name of the input linking op to op_in // Find the name of the input linking op to op_in
...@@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, ...@@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
IR_NODE_LINK_TO(op_in, quantize_op); IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node); IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op); IR_NODE_LINK_TO(quantize_out_node, op);
(*quantize_counter)++; quantize_counter++;
} }
void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) {
auto inputs = op->inputs; auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1, PADDLE_ENFORCE_GE(inputs.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { ...@@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
IR_NODE_LINK_TO(inputs[i], quantize_op); IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]); IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op); IR_NODE_LINK_TO(quantize_out_nodes[i], op);
(*quantize_counter)++; quantize_counter++;
} }
op->Op()->SetInput("X", quantize_out_node_names); op->Op()->SetInput("X", quantize_out_node_names);
...@@ -136,7 +136,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { ...@@ -136,7 +136,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
// Operators like Concat and Sum have a single input name X, which actually // Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find // consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops. // pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(), patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"}; "duplicated_inputs"};
...@@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { ...@@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
// Adding quantize ops before all operators except Concat and Sum, which have // Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs // already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"}; "first_bfloat16_ops"};
...@@ -169,60 +169,134 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { ...@@ -169,60 +169,134 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const { void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0; int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter); AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, &quantize_counter); AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op", PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter); quantize_counter);
} }
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
if (op->Op()->Type() == "prior_box") return;
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);
if (output_names.empty()) return;
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(), outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));
OpDesc deq_desc;
deq_desc.SetType("dequantize");
std::vector<Node*> dequantize_in_nodes(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name();
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));
deq_desc.SetAttr("Scale", 1.f);
deq_desc.SetAttr("Shift", 0.0f);
deq_desc.SetAttr("bfloat16", true);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
dequantize_counter++;
}
op->Op()->SetOutput("Out", dequantize_in_node_names);
}
// Operators like split have a single output name Out, which actually
// consists of multiple outputs. Such operators require a different way to find
// pattern and add dequantize ops.
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}
// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"}; "last_bfloat16_ops"};
bfloat16_ops(); bfloat16_ops();
int dequantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops); GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "prior_box") { if (op->Op()->Type() != "split") {
// Find the name of the output linking op to op_out AddDequantize(g, op, op_out, dequantize_counter);
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);
if (output_names.empty()) return;
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op =
g->CreateOpNode(&deq_desc); // OpDesc will be copied.
for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
} }
}; };
gpd(graph, handler); gpd(graph, handler);
}
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
int dequantize_counter = 0;
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op", PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter); dequantize_counter);
} }
......
...@@ -45,7 +45,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -45,7 +45,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type); op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "concat" || type == "sum") { } else if (type == "concat" || type == "sum" || type == "split") {
op->SetInput("X", inputs); op->SetInput("X", inputs);
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type); op->SetAttr("mkldnn_data_type", mkldnn_data_type);
...@@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) { ...@@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) {
bool use_mkldnn = true; bool use_mkldnn = true;
int quant_op = 3; int quant_op = 3;
int dequant_op = 3; int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2; int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes); MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
} }
...@@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) { ...@@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) {
bool use_mkldnn = true; bool use_mkldnn = true;
int quant_op = 4; int quant_op = 4;
int dequant_op = 3; int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2; int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op, MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
added_nodes); added_nodes);
...@@ -164,11 +166,35 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) { ...@@ -164,11 +166,35 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) {
bool use_mkldnn = true; bool use_mkldnn = true;
int quant_op = 5; int quant_op = 5;
int dequant_op = 3; int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2; int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op, MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
added_nodes); added_nodes);
} }
ProgramDesc BuildProgramDescDuplicatedOutput(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32");
SetOp(&prog, "split", "Split", {"b"}, {"c", "d"}, use_mkldnn, "bfloat16");
SetOp(&prog, "transpose2", "Transpose", {"c"}, {"e"}, use_mkldnn, "float32");
SetOp(&prog, "reshape2", "Reshape", {"d"}, {"f"}, use_mkldnn, "bfloat16");
return prog;
}
TEST(CpuBfloat16Pass, duplicated_output_ops) {
bool use_mkldnn = true;
int quant_op = 2;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedOutput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}
ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) { ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names) { for (auto& v : variable_names) {
...@@ -190,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) { ...@@ -190,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) {
bool use_mkldnn = true; bool use_mkldnn = true;
int quant_op = 3; int quant_op = 3;
int dequant_op = 3; int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2; int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op, MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
added_nodes); added_nodes);
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644