提交 f0d48e2a 编写于 作者: L lujun

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into hot-fix-utest

......@@ -47,34 +47,37 @@ find_package(Threads REQUIRED)
include(simd)
################################ Configurations #######################################
################################ Exposed Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_ANAKIN "Compile with Anakin library" OFF)
################################ Internal Configurations #######################################
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_NGRAPH "Compile PaddlePaddle with nGraph support." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF)
option(WITH_JEMALLOC "Compile PaddlePaddle with jemalloc" OFF)
option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF)
option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
# PY_VERSION
if(NOT PY_VERSION)
......@@ -194,9 +197,14 @@ if(WITH_GPU)
include(anakin_subgraph)
endif()
if(WITH_GPU AND NOT WIN32)
if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
set(WITH_DGC OFF)
endif()
if(WITH_DGC)
message(STATUS "add dgc lib.")
include(external/dgc)
add_definitions(-DPADDLE_WITH_DGC)
endif()
if(WITH_MKL OR WITH_MKLML)
......
......@@ -221,6 +221,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=lib
-DBUILD_SHARED_LIBS=OFF
CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR}
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
......@@ -131,15 +131,6 @@ elseif (NOT CBLAS_FOUND OR WIN32)
)
endif ()
if (WITH_GPU AND NOT WIN32)
set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
copy(dgc_lib
SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
DSTS ${dgc_dir} ${dgc_dir}
DEPS dgc)
endif()
if (WITH_MKLDNN)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
copy(mkldnn_lib
......
此差异已折叠。
......@@ -72,7 +72,6 @@ bool DataFeed::PickOneFile(std::string* filename) {
}
VLOG(3) << "file_idx_=" << *file_idx_;
*filename = filelist_[(*file_idx_)++];
// LOG(ERROR) << "pick file:" << *filename;
return true;
}
......@@ -242,6 +241,11 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
#ifdef _LINUX
......@@ -361,8 +365,13 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_);
std::vector<int> send_index(trainer_num_);
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
for (auto& vec : send_vec) {
vec.reserve(fleet_send_batch_size_);
vec.reserve(reserve_len);
}
for (int i = 0; i < trainer_num_; ++i) {
send_index[i] = i;
}
std::vector<std::future<int32_t>> total_status;
auto interval = GetMemoryDataInterval();
......@@ -375,7 +384,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_vec.size(); ++j) {
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
......@@ -388,7 +400,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
}
}
}
for (int j = 0; j < send_vec.size(); ++j) {
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (send_vec[j].size() != 0) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
......@@ -450,6 +465,17 @@ void MultiSlotDataFeed::Init(
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
// for batch size holder if is_dense
if (slot.shape(0) > 0) {
local_shape.push_back(0);
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
......@@ -505,7 +531,7 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
char* endptr = const_cast<char*>(str);
int len = line.length();
for (size_t i = 0; i < all_slots_.size(); ++i) {
int num = strtol(endptr, &endptr, 10);
auto num = strtol(endptr, &endptr, 10);
if (num < 0) {
VLOG(0) << "error: the number of ids is a negative number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<"
......@@ -736,8 +762,8 @@ void MultiSlotDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
use_slots_shape_[i][0] = batch_size_;
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
......@@ -769,6 +795,16 @@ void MultiSlotInMemoryDataFeed::Init(
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
if (slot.shape(0) > 0) {
local_shape.push_back(0);
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
......@@ -924,8 +960,8 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
use_slots_shape_[i][0] = batch_size_;
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
......
......@@ -94,6 +94,8 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
// This function will do nothing at default
virtual void SetFleetSendBatchSize(int64_t size) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
......@@ -140,6 +142,7 @@ class DataFeed {
// object)
std::vector<std::string> all_slots_;
std::vector<std::string> all_slots_type_;
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_
......@@ -212,6 +215,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
......
......@@ -19,6 +19,7 @@ message Slot {
required string type = 2;
optional bool is_dense = 3 [ default = false ];
optional bool is_used = 4 [ default = false ];
repeated int32 shape = 5; // we can define N-D Tensor
}
message MultiSlotDesc { repeated Slot slots = 1; }
......
......@@ -64,6 +64,17 @@ void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
}
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetFleetSendBatchSize
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
for (auto reader : readers_) {
reader->SetFleetSendBatchSize(size);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
......
......@@ -47,6 +47,8 @@ class Dataset {
virtual void SetThreadNum(int thread_num) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0;
// set fleet send batch size
virtual void SetFleetSendBatchSize(int64_t size) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
......@@ -59,6 +61,8 @@ class Dataset {
virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get fleet send batch size
virtual int64_t GetFleetSendBatchSize() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
......@@ -98,6 +102,7 @@ class DatasetImpl : public Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
......@@ -105,6 +110,7 @@ class DatasetImpl : public Dataset {
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
......@@ -137,6 +143,7 @@ class DatasetImpl : public Dataset {
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
int64_t fleet_send_batch_size_;
};
// use std::vector<MultiSlotType> as data type
......
......@@ -24,15 +24,19 @@ if(WITH_DISTRIBUTE)
endif()
endif()
set(all_reduce_deps all_reduce_op_handle)
if(WITH_GPU)
set(dgc_deps "")
if(NOT WIN32)
set(dgc_deps dgc)
endif()
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor ${dgc_deps})
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
set(all_reduce_deps sparse_all_reduce_op_handle)
endif()
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
......@@ -80,7 +84,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
......
......@@ -17,11 +17,6 @@
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -40,14 +35,11 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs,
bool is_encoded, int nranks)
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
......@@ -62,92 +54,8 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void AllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> ins;
std::vector<LoDTensor *> outs;
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var);
auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in);
auto *out = local_scope->FindVar(out_var_handles[i]->name())
->GetMutable<LoDTensor>();
outs.emplace_back(out);
if (k < 0) {
k = GetKValue(in_var_handles[i]->name());
}
}
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t in_numel = 0;
size_t out_numel = 0;
PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
PADDLE_ENFORCE(in_numel % 2 == 0);
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
<< ", k:" << k << ", place:" << place << ", dtype:" << dtype;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
stream));
});
}
void AllReduceOpHandle::RunAllReduceFuncs(
const std::vector<std::function<void()>> &all_reduce_calls) {
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
......@@ -178,68 +86,9 @@ void AllReduceOpHandle::RunImplEncoded() {
}
}
}
int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
bool AllReduceOpHandle::IsEncoded() {
if (!is_encoded_) {
return false;
}
auto counter_name = g_dgc_counter_name;
auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) {
PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
step_var);
}
float count = *count_var->Get<LoDTensor>().data<float>();
float step = *step_var->Get<LoDTensor>().data<float>();
if (static_cast<int>(count) < static_cast<int>(step)) {
VLOG(10) << "in all_reduce currentstep:" << count
<< " < rampup_begin_step:" << step
<< " so not use sparse all reduce";
return false;
}
return true;
}
#else
bool AllReduceOpHandle::IsEncoded() { return false; }
#endif
void AllReduceOpHandle::RunImpl() {
if (!IsEncoded()) {
RunImplNormal();
return;
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
RunImplEncoded();
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
}
void AllReduceOpHandle::RunImplNormal() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
......@@ -300,27 +149,7 @@ void AllReduceOpHandle::RunImplNormal() {
comm, stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
cudaStreamSynchronize(stream);
}
}
RunAllReduceFuncs(all_reduce_calls);
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......
......@@ -28,19 +28,12 @@ namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";
constexpr char g_dgc_k[] = "__dgc_k__";
#endif
struct AllReduceOpHandle : public OpHandleBase {
class AllReduceOpHandle : public OpHandleBase {
public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1);
const platform::NCCLContextMap *ctxs);
#else
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
......@@ -54,18 +47,13 @@ struct AllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
private:
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunImplEncoded();
void RunAllReduceFuncs(
const std::vector<std::function<void()>> &all_reduce_calls);
const platform::NCCLContextMap *nccl_ctxs_;
bool is_encoded_{false};
int nranks_{-1};
int GetKValue(const std::string &grad_name);
#endif
void RunImplNormal();
bool IsEncoded();
};
} // namespace details
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const;
void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const;
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -64,9 +64,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
epmap, height_section,
trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
......@@ -75,9 +78,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
epmap, {}, trainer_id);
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
......
......@@ -101,8 +101,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} else {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
......@@ -142,6 +140,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("memory_optimize_pass");
}
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
if (strategy_.cache_expected_kernel_) {
VLOG(10) << "Add expected_kernel_cache_pass";
AppendPass("expected_kernel_cache_pass");
}
AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) {
......@@ -243,7 +254,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "apply " << pass->Type();
VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......@@ -328,3 +339,5 @@ USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
......@@ -83,11 +83,11 @@ struct BuildStrategy {
bool sync_batch_norm_{false};
bool memory_optimize_{true};
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
// memory_early_delete_ true by default
bool enable_inplace_{true};
// FIXME(liuwei1031) disable memory_optimzie and enable_inplace in 1.4
// to open them by default, we need to solve the fetch variable issue
bool memory_optimize_{false};
bool enable_inplace_{false};
bool enable_sequential_execution_{false};
......@@ -107,6 +107,9 @@ struct BuildStrategy {
std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{true};
bool cache_runtime_context_{false};
bool cache_expected_kernel_{true};
// NOTE:
// Before you add new options, think if it's a general strategy that works
// with other strategy. If not, the strategy should be created through
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace framework {
namespace details {
constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_u[] = "__dgc_u__";
constexpr char g_dgc_v[] = "__dgc_v__";
constexpr char g_dgc_k[] = "__dgc_k__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -24,7 +24,7 @@ namespace details {
const std::string FuseAdamOpPass::GetOpType() const { return "adam"; }
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const {
return {"Param", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
}
void FuseAdamOpPass::FuseOptimizerOps(
......@@ -77,16 +77,16 @@ void FuseAdamOpPass::FuseAdamOps(
VLOG(10) << "Insert adam to graph ";
OpDesc adam_desc(adam_ops[0]->Op()->Block());
adam_desc.SetType("adam");
adam_desc.SetInput("Param", {fused_vars_name.at("Param")});
adam_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
adam_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
adam_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
adam_desc.SetInput("Moment1", {fused_vars_name.at("Moment1")});
adam_desc.SetInput("Moment2", {fused_vars_name.at("Moment2")});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
adam_desc.SetInput("LearningRate", adam_ops[0]->Op()->Input("LearningRate"));
adam_desc.SetInput(kLearningRate, adam_ops[0]->Op()->Input(kLearningRate));
adam_desc.SetInput("Beta1Pow", adam_ops[0]->Op()->Input("Beta1Pow"));
adam_desc.SetInput("Beta2Pow", adam_ops[0]->Op()->Input("Beta2Pow"));
adam_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
adam_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
adam_desc.SetOutput("Moment1Out", {fused_vars_name.at("Moment1")});
adam_desc.SetOutput("Moment2Out", {fused_vars_name.at("Moment2")});
adam_desc.SetAttr("beta1", beta1);
......
......@@ -29,7 +29,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
const std::string fuse_op_type = GetOpType();
const std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
aux_var_names.emplace_back(kParam);
aux_var_names.emplace_back(kGrad);
// Step 1: Get the specified op and auxiliary variables.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
......@@ -61,7 +63,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
result.Set(kFusedVars, new FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size() + 1);
fused_vars_name.reserve(aux_var_names.size());
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
......@@ -75,39 +77,103 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
}
// Step 3: Get the fused Gradient's name
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (!result.Has(kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before this "
"pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name.emplace("Grad", fused_grad);
// Step 4: Sort the parameters and auxiliary variables according
// to parameters' name to make variables' name correspond correctly.
PADDLE_ENFORCE(result.Has(kParamsAndGrads), "Does't find kParamsAndGrads.");
PADDLE_ENFORCE_EQ(params_grads.size(), aux_var_set.begin()->second.size(),
"The size of params_grads and aux_var_set are not equal.");
SortParametersAndAuxVars(params_grads, &aux_var_set, &opt_ops);
// Step 5: Alloc continuous space for Parameters and AuxiliaryVar(e.g.
bool grad_fused = false;
if (result.Has(kParamsAndGrads)) {
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
PADDLE_ENFORCE_EQ(
params_grads.size(), aux_var_set.at(kGrad).size(),
"The number of gradients and optimizer ops is not equal.");
std::unordered_set<std::string> opt_grad_set(aux_var_set.at(kGrad).begin(),
aux_var_set.at(kGrad).end());
size_t same_grad_num = 0;
for (auto &p_g : params_grads) {
if (opt_grad_set.count(p_g.second)) {
++same_grad_num;
}
}
// NOTE(zcd): the gradient of kParamsAndGrads may be different with the
// kGrad.
if (same_grad_num == aux_var_set.at(kGrad).size()) {
if (!result.Has(kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before "
"this pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name[kGrad] = fused_grad;
// Sort the parameters and auxiliary variables according
// to parameters' name to make variables' name correspond correctly.
SortParametersAndAuxVars(params_grads, &aux_var_set, &opt_ops);
grad_fused = true;
}
}
// Step 4: Alloc continuous space for Parameters and AuxiliaryVar(e.g.
// Moment1, Moment2, Beta1Pow, Beta2Pow) of all the optimizer ops separately.
aux_var_names.pop_back();
if (!grad_fused) {
InitFusedGradsAndAllocSpaceForGrads(
places, local_scopes, aux_var_set.at(kParam), aux_var_set.at(kGrad),
fused_vars_name.at(kGrad), &result);
}
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, aux_var_names,
aux_var_set, fused_vars_name);
// Step 6: Fuse optimizer Ops and Scale Ops
// Step 5: Fuse optimizer Ops and Scale Ops
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_ops, &result);
// Step 7: Remove optimizer Ops
// Step 6: Remove optimizer Ops
for (auto &opt_op : opt_ops) {
graph->RemoveNode(opt_op);
}
}
void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &params,
const std::vector<std::string> &grads, const std::string &fused_grad_name,
ir::Graph *result) const {
// Get Var Nodes
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result->Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may have the same name node. For example, parameter
// is the input of operator and it also is the output of optimizer;
vars.emplace(node->Var()->Name(), node);
}
}
// Init Grads
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
VLOG(10) << "Init " << fused_grad_name;
PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
"%s has existed in scope.", fused_grad_name);
scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
for (auto &grad_var_name : grads) {
auto iter = vars.find(grad_var_name);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(grad_var_name)->GetMutable<LoDTensor>();
}
}
// Define Ops
ProgramDesc program_desc;
auto *global_block = program_desc.MutableBlock(0);
AppendAllocContinuousSpace(params, grads, fused_grad_name, global_block,
false, false);
// Run Ops
RunInitOps(places, local_scopes, *global_block);
}
void FuseOptimizerOpPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
......@@ -115,37 +181,49 @@ void FuseOptimizerOpPass::InitFusedVarsAndAllocSpaceForVars(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) {
auto &scope = *iter;
for (auto &var_name : aux_var_names) {
auto fused_var_name = fused_vars_name.at(var_name);
VLOG(10) << "Init " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
}
// Init Vars
for (auto &var_name : aux_var_names) {
auto &fused_var_name = fused_vars_name.at(var_name);
InitVars(local_scopes, fused_var_name);
}
// Define Ops
ProgramDesc program_desc;
auto *global_block = program_desc.MutableBlock(0);
for (auto &var_name : aux_var_names) {
AppendAllocContinuousSpace(aux_var_set.at(var_name),
fused_vars_name.at(var_name), true,
global_block);
AppendAllocContinuousSpace(
aux_var_set.at(var_name), aux_var_set.at(var_name),
fused_vars_name.at(var_name), global_block, true);
}
// Run Ops
RunInitOps(places, local_scopes, *global_block);
}
void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const BlockDesc &global_block) const {
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : global_block->AllOps()) {
for (auto &op_desc : global_block.AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void FuseOptimizerOpPass::InitVars(const std::vector<Scope *> &local_scopes,
const std::string &fused_var_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) {
auto &scope = *iter;
VLOG(10) << "Init " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
}
}
void FuseOptimizerOpPass::SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_vars_set,
......@@ -203,15 +281,16 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
}
void FuseOptimizerOpPass::AppendAllocContinuousSpace(
const std::vector<std::string> &args, const std::string &out_arg,
bool copy_data, BlockDesc *global_block) const {
const std::vector<std::string> &in_args,
const std::vector<std::string> &out_args, const std::string &fused_out_arg,
BlockDesc *global_block, bool copy_data, bool check_name) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", args);
op_desc->SetOutput("Output", args);
op_desc->SetOutput("FusedOutput", {out_arg});
op_desc->SetInput("Input", in_args);
op_desc->SetOutput("Output", out_args);
op_desc->SetOutput("FusedOutput", {fused_out_arg});
op_desc->SetAttr("copy_data", copy_data);
op_desc->SetAttr("check_name", true);
op_desc->SetAttr("check_name", check_name);
}
void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
......
......@@ -27,6 +27,10 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kGrad[] = "Grad";
constexpr char kParam[] = "Param";
constexpr char kLearningRate[] = "LearningRate";
class FuseOptimizerOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
......@@ -56,9 +60,18 @@ class FuseOptimizerOpPass : public ir::Pass {
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const;
void AppendAllocContinuousSpace(const std::vector<std::string> &args,
const std::string &out_arg, bool copy_data,
BlockDesc *global_block) const;
void AppendAllocContinuousSpace(const std::vector<std::string> &in_args,
const std::vector<std::string> &out_args,
const std::string &fused_out_arg,
BlockDesc *global_block, bool copy_data,
bool check_name = true) const;
void InitFusedGradsAndAllocSpaceForGrads(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &params,
const std::vector<std::string> &grads, const std::string &fused_grad_name,
ir::Graph *result) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
......@@ -68,6 +81,13 @@ class FuseOptimizerOpPass : public ir::Pass {
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name)
const;
void RunInitOps(const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const BlockDesc &global_block) const;
void InitVars(const std::vector<Scope *> &local_scopes,
const std::string &fused_var_name) const;
};
} // namespace details
......
......@@ -24,7 +24,7 @@ namespace details {
const std::string FuseSgdOpPass::GetOpType() const { return "sgd"; }
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const {
return {"Param"};
return {};
}
void FuseSgdOpPass::FuseOptimizerOps(
......@@ -50,12 +50,12 @@ void FuseSgdOpPass::FuseSgdOps(
// Add fused scale
OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
Sgd_desc.SetType("sgd");
Sgd_desc.SetInput("Param", {fused_vars_name.at("Param")});
Sgd_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
Sgd_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
Sgd_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
Sgd_desc.SetInput("LearningRate", sgd_ops[0]->Op()->Input("LearningRate"));
Sgd_desc.SetInput(kLearningRate, sgd_ops[0]->Op()->Input(kLearningRate));
// NOTE: multi_devices_pass requires that every op should have a role.
Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
......
......@@ -305,6 +305,12 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
if (var_nodes_[in_var_name].back() != in_node) {
VLOG(4) << "SKIP since " << in_var_name
<< " is also used as output by other ops";
continue;
}
bool can_replace = true;
if (in_var_name == out_var_name) {
can_replace = false;
......@@ -527,6 +533,9 @@ void GraphView::Build(ir::Graph* g) {
};
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
// avoid optimize the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) update_skip_set(node);
if (node->Name() == "send") update_skip_set(node);
if (node->Name() == "recv") update_skip_set(node);
if (node->Name() == "prefetch") update_skip_set(node);
......
......@@ -34,6 +34,10 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/math/math_function.h"
#if defined(PADDLE_WITH_DGC)
#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
#endif
namespace paddle {
namespace framework {
namespace details {
......@@ -438,12 +442,22 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
auto append_allreduce_op = [&](
const std::vector<Scope *> &scopes,
const std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_DGC)
if (is_encoded) {
result->Get<GraphOps>(kGraphOps).emplace_back(new SparseAllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_, is_encoded,
static_cast<int>(strategy_.trainers_endpoints_.size()) *
places_.size()));
} else {
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_));
}
#elif defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_, is_encoded,
static_cast<int>(strategy_.trainers_endpoints_.size()) *
places_.size()));
scopes, places, nccl_ctxs_));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
......@@ -561,7 +575,11 @@ void AllReduceSSAGraphBuilder::InsertCollectiveOp(
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
#if defined(PADDLE_WITH_DGC)
CreateAllReduceOp(result, g_name, IsEncoded(p_name));
#else
CreateAllReduceOp(result, g_name);
#endif
}
}
......@@ -965,8 +983,9 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
auto u_name = p_name + "__dgc_u__";
#if defined(PADDLE_WITH_DGC)
bool AllReduceSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
auto u_name = p_name + g_dgc_u;
auto it = all_vars_.find(u_name);
if (it == all_vars_.end()) {
VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
......@@ -975,6 +994,11 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
return true;
}
#else
bool AllReduceSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
return false;
}
#endif
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
......@@ -992,11 +1016,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
CreateAllReduceOp(result, g_name, IsEncoded(p_name));
#else
PADDLE_ENFORCE(false, "Compiled withoud cuda!");
#endif
CreateAllReduceOp(result, g_name);
}
break;
default:
......
......@@ -113,6 +113,8 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
const std::string &g_name) const;
virtual void InsertPostprocessOps(ir::Graph *result) const {}
bool IsEncoded(const std::string &p_name) const;
};
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
......@@ -203,8 +205,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
bool IsEncoded(const std::string &p_name) const;
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
......@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
}
};
// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
void operator()(const char* op_type, OpInfo* info) const {}
};
} // namespace details
} // namespace framework
......
......@@ -106,7 +106,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
executors_.emplace_back(new details::FastThreadedSSAGraphExecutor(
strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get()));
}
}
......
......@@ -14,12 +14,12 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
......@@ -48,7 +48,8 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<platform::Place> places_;
std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
std::vector<std::unique_ptr<details::FastThreadedSSAGraphExecutor>>
executors_;
ExceptionHolder exception_holder_;
};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
#include <algorithm>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(sync_nccl_allreduce);
namespace paddle {
namespace framework {
namespace details {
SparseAllReduceOpHandle::SparseAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks)
: AllReduceOpHandle(node, local_scopes, places, ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
// TODO(gongwb) :polish them!
if (is_encoded) {
VLOG(1) << "Use dgc allreduce mode";
}
}
void SparseAllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> ins;
std::vector<LoDTensor *> outs;
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in);
auto *out = local_scope->FindVar(out_var_handles[i]->name())
->GetMutable<LoDTensor>();
outs.emplace_back(out);
if (k < 0) {
k = GetKValue(in_var_handles[i]->name());
}
}
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t in_numel = 0;
size_t out_numel = 0;
PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
PADDLE_ENFORCE(in_numel % 2 == 0);
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
<< ", k:" << k << ", place:" << place << ", dtype:" << dtype;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
stream));
});
}
RunAllReduceFuncs(all_reduce_calls);
}
int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
}
bool SparseAllReduceOpHandle::IsEncoded() {
if (!is_encoded_) {
return false;
}
auto counter_name = g_dgc_counter_name;
auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) {
PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
step_var);
}
float count = *count_var->Get<LoDTensor>().data<float>();
float step = *step_var->Get<LoDTensor>().data<float>();
if (static_cast<int>(count) < static_cast<int>(step)) {
VLOG(10) << "in all_reduce currentstep:" << count
<< " < rampup_begin_step:" << step
<< " so not use sparse all reduce";
return false;
}
return true;
}
void SparseAllReduceOpHandle::RunImpl() {
if (!IsEncoded()) {
AllReduceOpHandle::RunImpl();
return;
}
RunImplEncoded();
}
std::string SparseAllReduceOpHandle::Name() const {
return "sparse_all_reduce";
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/dgc_const_values.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace framework {
namespace details {
class SparseAllReduceOpHandle : public AllReduceOpHandle {
public:
SparseAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1);
std::string Name() const override;
protected:
void RunImpl() override;
int GetKValue(const std::string &grad_name);
bool IsEncoded();
void RunImplEncoded();
private:
bool is_encoded_{false};
int nranks_{-1};
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -21,40 +21,40 @@ namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
for (int i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
TableParameter table = param_.sparse_table(i);
sparse_key_names_[table_id].resize(table.sparse_key_name_size());
for (size_t j = 0; j < table.sparse_key_name_size(); ++j) {
for (int j = 0; j < table.sparse_key_name_size(); ++j) {
sparse_key_names_[table_id][j] = table.sparse_key_name(j);
}
sparse_value_names_[table_id].resize(table.sparse_value_name_size());
for (size_t j = 0; j < table.sparse_value_name_size(); ++j) {
for (int j = 0; j < table.sparse_value_name_size(); ++j) {
sparse_value_names_[table_id][j] = table.sparse_value_name(j);
}
sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
for (size_t j = 0; j < table.sparse_grad_name_size(); ++j) {
for (int j = 0; j < table.sparse_grad_name_size(); ++j) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
}
label_var_name_[table_id] = table.label_var_name();
}
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
for (int i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
dense_value_names_[table_id].resize(table.dense_value_name_size());
for (size_t j = 0; j < table.dense_value_name_size(); ++j) {
for (int j = 0; j < table.dense_value_name_size(); ++j) {
dense_value_names_[table_id][j] = table.dense_value_name(j);
}
dense_grad_names_[table_id].resize(table.dense_grad_name_size());
for (size_t j = 0; j < table.dense_grad_name_size(); ++j) {
for (int j = 0; j < table.dense_grad_name_size(); ++j) {
dense_grad_names_[table_id][j] = table.dense_grad_name(j);
}
}
skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) {
for (int i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i);
}
......@@ -83,14 +83,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label_ptr = tensor->data<int64_t>();
int global_index = 0;
size_t global_index = 0;
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
VLOG(3) << "sparse_key_names_[" << i
<< "]: " << sparse_key_names_[table_id][i];
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
size_t fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1
for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
......@@ -138,7 +138,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
for (auto index = 0u; index < len; ++index) {
for (int index = 0; index < len; ++index) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
sizeof(float) * table.emb_dim());
......@@ -192,7 +192,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
VLOG(3) << "program config size: " << param_.program_config_size();
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
......@@ -244,8 +244,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
if (need_to_push_sparse_) {
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
......@@ -268,8 +268,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
if (need_to_push_dense_) {
timeline.Start();
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
......@@ -315,8 +315,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
......@@ -362,7 +362,7 @@ void DownpourWorker::TrainFiles() {
int cur_batch;
while ((cur_batch = device_reader_->Next()) > 0) {
// pull sparse here
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(i));
......@@ -397,8 +397,8 @@ void DownpourWorker::TrainFiles() {
if (need_to_push_sparse_) {
// push gradients here
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
......@@ -416,8 +416,8 @@ void DownpourWorker::TrainFiles() {
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
......@@ -461,8 +461,8 @@ void DownpourWorker::TrainFiles() {
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
......
......@@ -3,3 +3,5 @@ if(WITH_PSLIB)
else()
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_PSLIB)
cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
......@@ -237,6 +237,7 @@ void FleetWrapper::PushDenseParamSync(
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel());
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/nccl_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
bool NCCLWrapper::is_initialized_ = false;
void NCCLWrapper::InitNCCL() {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
&(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
nccl_info_.my_global_rank_));
#endif
return;
}
void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
#ifdef PADDLE_WITH_CUDA
nccl_info_.nccl_id_ = nccl_info.nccl_id_;
#endif
return;
}
NCCLInfo NCCLWrapper::GetNCCLId() {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
#endif
return nccl_info_;
}
void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
const int ranks) {
#ifdef PADDLE_WITH_CUDA
nccl_info_.local_rank_ = local_rank;
nccl_info_.my_global_rank_ = global_rank;
nccl_info_.global_ranks_ = ranks;
PADDLE_ENFORCE(cudaSetDevice(local_rank));
PADDLE_ENFORCE(cudaStreamCreate(&(nccl_info_.stream_)));
#endif
return;
}
void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_CUDA
for (auto& name : var_names) {
auto var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int32_t total_size = tensor->numel();
PADDLE_ENFORCE(platform::dynload::ncclBcast(
reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
root_rank, nccl_info_.comm_, nccl_info_.stream_));
cudaStreamSynchronize(nccl_info_.stream_);
}
#endif
return;
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
class NCCLInfo {
public:
NCCLInfo() {}
virtual ~NCCLInfo() {}
public:
int local_rank_;
int global_ranks_;
int my_global_rank_;
#ifdef PADDLE_WITH_CUDA
ncclUniqueId nccl_id_;
ncclComm_t comm_;
cudaStream_t stream_;
#endif
};
class NCCLWrapper {
public:
virtual ~NCCLWrapper() {}
NCCLWrapper() {}
void InitNCCL();
void SetNCCLId(const NCCLInfo& nccl_info);
NCCLInfo GetNCCLId();
void SetRankInfo(const int local_rank, const int global_rank,
const int ranks);
void SyncVar(const int root_rank, const Scope& scope,
const std::vector<std::string>& var_names);
static std::shared_ptr<NCCLWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::NCCLWrapper());
}
return s_instance_;
}
public:
NCCLInfo nccl_info_;
private:
static std::shared_ptr<NCCLWrapper> s_instance_;
protected:
static bool is_initialized_;
DISABLE_COPY_AND_ASSIGN(NCCLWrapper);
};
} // end namespace framework
} // end namespace paddle
......@@ -126,7 +126,7 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
......
......@@ -68,6 +68,7 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/expected_kernel_cache_pass.h"
#include <memory>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
void ExpectedKernelCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Expected Kernel Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheExpectedKernel, true);
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(expected_kernel_cache_pass,
paddle::framework::ir::ExpectedKernelCachePass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class ExpectedKernelCachePass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -31,10 +31,10 @@ namespace paddle {
namespace framework {
namespace ir {
namespace {
void SortHelper(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
void SortHelper(const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>,
ir::NodeComp> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
......@@ -50,7 +50,8 @@ void SortHelper(
bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
&adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace,
std::vector<std::vector<ir::Node *>> *circles) {
......@@ -84,7 +85,8 @@ bool HasCircleHelper(
}
bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
&adj_list,
std::vector<std::vector<ir::Node *>> *circles) {
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
......@@ -107,8 +109,8 @@ bool FindCircleSubGraph(const Graph &graph,
}
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph);
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list = BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
......@@ -117,34 +119,30 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
SortHelper(adj_list, adj.first, &visited, &ret);
}
}
return ret;
}
// Build operator inlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
BuildOperationAdjList(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list;
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
adj_list[n] = std::set<ir::Node *, ir::NodeComp>();
}
std::vector<ir::Node *> nodes;
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
nodes.push_back(adj_n);
adj_list[n].insert(adj_n);
}
}
std::sort(nodes.begin(), nodes.end(), [](ir::Node *node1, ir::Node *node2) {
return node1->id() > node2->id();
});
adj_list[n].insert(std::make_move_iterator(nodes.begin()),
std::make_move_iterator(nodes.end()));
}
return adj_list;
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -25,6 +26,13 @@ namespace paddle {
namespace framework {
namespace ir {
// Compare nodes via node id.
struct NodeComp {
bool operator()(ir::Node *const &node1, ir::Node *const &node2) const {
return node1->id() < node2->id();
}
};
// Test if the graph contains circle.
bool HasCircle(const Graph &graph);
......@@ -57,8 +65,8 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph, SortKind sort_kind);
void CleanIndividualNodes(Graph *graph);
// Build an adjacency list of operations for the `graph`.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
BuildOperationAdjList(const Graph &graph);
template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
......
......@@ -23,7 +23,7 @@ namespace ir {
void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
}
}
......
......@@ -241,6 +241,7 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
block_ = nullptr;
}
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
......
......@@ -880,7 +880,16 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
if (!HasAttr(kEnableCacheRuntimeContext)) {
// To reduce the elapsed time of HasAttr, we use bool variable to record the
// result of HasAttr.
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
enable_cache_runtime_context = true;
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
enable_cache_expected_kernel = true;
if (!all_kernels_must_compute_runtime_shape &&
HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape = true;
if (!enable_cache_runtime_context) {
RuntimeContext ctx(Inputs(), Outputs(), scope);
RunImpl(scope, place, &ctx);
} else {
......@@ -899,60 +908,33 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
if (!enable_cache_expected_kernel || !kernel_type_) {
ChooseKernel(*runtime_ctx, scope, place);
}
OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key));
}
std::vector<KernelConfig>* kernel_configs =
GetKernelConfig(expected_kernel_key);
std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = PrepareData(scope, expected_kernel_key,
&transfered_inplace_vars, runtime_ctx);
auto* transfer_scope =
PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
(transfer_scope == nullptr ? scope : *transfer_scope);
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_);
if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(kernel_type_->place_);
}
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
if (!all_kernels_must_compute_runtime_shape) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
this->InferShape(&infer_shape_ctx);
}
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
*runtime_ctx, kernel_configs));
(*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
kernel_configs));
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
......@@ -978,6 +960,46 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
const Scope& scope,
const platform::Place& place) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
}
OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key));
}
kernel_type_.reset(new OpKernelType(expected_kernel_key));
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
}
void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const {
......
......@@ -70,6 +70,12 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
/// this Op's execution to save the elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
/// If an Op has attribtue kEnableCacheExpectedKernel, it means that in a same
/// name scope and same place, since the expected kerenl of this Op does not
/// change in the execution, it could be recorded only at the first iteration of
/// this Op's execution to save the elapsed time.
constexpr char kEnableCacheExpectedKernel[] = "@ENABLE_CACHE_EXPECTED_KERNEL@";
/// If an Op has this attribute, all its kernels should calculate output
/// variable's shape in the corresponding Compute() function. And
/// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
......@@ -491,10 +497,18 @@ class OperatorWithKernel : public OperatorBase {
const std::vector<std::string>& inplace_vars,
const Scope& exec_scope) const;
void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
const platform::Place& place) const;
protected:
mutable OpKernelConfigsMap kernel_configs_map_;
mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
mutable bool enable_cache_runtime_context = false;
mutable bool enable_cache_expected_kernel = false;
mutable bool all_kernels_must_compute_runtime_shape = false;
};
extern bool OpSupportGPU(const std::string& op_type);
......
......@@ -221,7 +221,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
PADDLE_ENFORCE(!member_->use_cuda_,
"gpu mode does not support async_mode_ now!");
graphs.push_back(graph);
for (int i = 1; i < places.size(); ++i) {
for (size_t i = 1; i < places.size(); ++i) {
auto *tmp_graph = new ir::Graph(graph->OriginProgram());
async_graphs_.emplace_back(tmp_graph);
graphs.push_back(tmp_graph);
......@@ -315,7 +315,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
for (int i = 1; i < member_->places_.size(); ++i) {
for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] =
build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1,
......
......@@ -76,7 +76,7 @@ message PullDenseWorkerParameter {
message TableParameter {
// dense table only
optional int64 table_id = 1;
optional uint64 table_id = 1;
repeated string dense_value_name = 2;
repeated string dense_grad_name = 3;
repeated int32 push_dense_wait_times = 5;
......
......@@ -45,12 +45,16 @@ class InferVarTypeContext {
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
auto& inputs = op_->Inputs();
auto input = inputs.find(name);
return input != inputs.end() && !input->second.empty();
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
auto& outputs = op_->Outputs();
auto output = outputs.find(name);
return output != outputs.end() && !output->second.empty();
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
......
......@@ -464,7 +464,11 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetType(type);
if (name == "kLookupTablePath") {
VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
} else {
var_set_[name]->SetType(type);
}
}
framework::proto::VarType::Type GetDataType(
......
......@@ -33,8 +33,7 @@ void NCCLParallelContext::RecvNCCLID(const std::string &ep,
// creating socket fd
if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0)
PADDLE_THROW("create server fd failed");
if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt,
sizeof(opt)))
if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)))
PADDLE_THROW("set socket opt failed");
address.sin_family = AF_INET;
......
......@@ -14,7 +14,7 @@
#pragma once
// network header files
#ifndef _WIN32
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdlib.h>
......@@ -26,7 +26,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/place.h"
......
......@@ -49,11 +49,6 @@ set(SHARED_INFERENCE_SRCS
${mkldnn_quantizer_src}
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)
# FIXME(gongwb): hidden libdgc.a
if(WITH_GPU AND NOT WIN32)
set(fluid_modules ${fluid_modules} dgc)
endif()
if(WIN32)
sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
......
cc_library(anakin_engine SRCS engine.cc DEPS framework_proto)
cc_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
cc_library(anakin_engine SRCS engine.cc DEPS framework_proto boost)
cc_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto boost)
target_link_libraries(anakin_engine anakin anakin_saber_common)
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
add_subdirectory(convert)
......@@ -231,6 +231,7 @@ void AnalysisConfig::Update() {
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
}
pass_builder()->DeletePass("runtime_context_cache_pass");
pass_builder()->DeletePass("expected_kernel_cache_pass");
}
if (use_mkldnn_) {
......
......@@ -259,6 +259,9 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
return false;
}
PADDLE_ENFORCE_NOT_NULL(input_ptr);
PADDLE_ENFORCE_NOT_NULL(inputs[i].data.data());
if (platform::is_cpu_place(place_)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), inputs[i].data.data(),
......@@ -829,6 +832,45 @@ std::string AnalysisPredictor::GetSerializedProgram() const {
return inference_program_->Proto()->SerializeAsString();
}
// Add SaveOptimModel
void AnalysisPredictor::SaveOptimModel(const std::string &dir) {
// save model
std::string model_name = dir + "/model";
std::ofstream outfile;
outfile.open(model_name, std::ios::out | std::ios::binary);
std::string inference_prog_desc = GetSerializedProgram();
outfile << inference_prog_desc;
// save params
framework::ProgramDesc save_program;
auto *save_block = save_program.MutableBlock(0);
const framework::ProgramDesc &main_program = program();
const framework::BlockDesc &global_block = main_program.Block(0);
std::vector<std::string> save_var_list;
for (framework::VarDesc *var : global_block.AllVars()) {
if (IsPersistable(var)) {
framework::VarDesc *new_var = save_block->Var(var->Name());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);
save_var_list.push_back(new_var->Name());
}
}
std::sort(save_var_list.begin(), save_var_list.end());
auto *op = save_block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", save_var_list);
op->SetAttr("file_path", dir + "/params");
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(save_program, scope(), 0, true, true);
}
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<AnalysisConfig>(
const AnalysisConfig &config) {
......
......@@ -86,6 +86,10 @@ class AnalysisPredictor : public PaddlePredictor {
bool MkldnnQuantize();
// save program to model
// save parameters to params
void SaveOptimModel(const std::string &dir);
protected:
// For memory optimization.
bool need_collect_var_shapes_for_memory_optim();
......
......@@ -196,6 +196,9 @@ TEST(AnalysisPredictor, Clone) {
}
}
// This function is not released yet, will fail on some machine.
// TODO(Superjomn) Turn on it latter.
/*
TEST(AnalysisPredictor, memory_optim) {
AnalysisConfig config(FLAGS_dirname);
config.DisableGpu();
......@@ -246,6 +249,7 @@ TEST(AnalysisPredictor, memory_optim) {
inference::CompareResult(output, output1);
}
*/
#ifdef PADDLE_WITH_MKLDNN
class MkldnnQuantizerTest : public testing::Test {
......
......@@ -54,6 +54,7 @@ PaddleBuf &PaddleBuf::operator=(const PaddleBuf &other) {
memory_owned_ = other.memory_owned_;
} else {
Resize(other.length());
PADDLE_ENFORCE(!(other.length() > 0 && other.data() == nullptr));
memcpy(data_, other.data(), other.length());
length_ = other.length();
memory_owned_ = true;
......
......@@ -169,6 +169,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
// Hot fix the bug that result diff in multi-thread.
// TODO(Superjomn) re-implement a real clone here.
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<NativePaddlePredictor *>(cls.get()));
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(nullptr)) {
LOG(ERROR) << "fail to call Init";
return nullptr;
......@@ -210,6 +211,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
return false;
}
PADDLE_ENFORCE_NOT_NULL(input_ptr);
PADDLE_ENFORCE_NOT_NULL(inputs[i].data.data());
if (platform::is_cpu_place(place_)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), inputs[i].data.data(),
......@@ -316,6 +319,8 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
}
std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<NativePaddlePredictor *>(predictor.get()));
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) {
return nullptr;
}
......
......@@ -99,6 +99,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass",
"expected_kernel_cache_pass", //
});
use_gpu_ = true;
......@@ -122,8 +123,8 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// will enhance this pass later.
"runtime_context_cache_pass", //
"attention_lstm_fuse_pass", //
"seqpool_concat_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
// "seqpool_concat_fuse_pass", //
// "embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
......@@ -136,6 +137,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
"expected_kernel_cache_pass", //
});
use_gpu_ = false;
......
nv_library(tensorrt_engine SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto)
nv_library(tensorrt_engine SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context boost)
nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto boost)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(plugin)
......
......@@ -30,6 +30,7 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
--infer_data=${data_dir}/data.bin
--warmup_batch_size=100
--batch_size=50
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2)
endfunction()
......@@ -120,7 +121,8 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8 SERIAL)
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI} SERIAL)
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
......
......@@ -170,6 +170,15 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->SwitchIrOptim(true);
}
void SetOptimConfig(AnalysisConfig *cfg) {
std::string optimModelPath =
FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) +
"/saved_optim_model";
cfg->SetModel(optimModelPath + "/model", optimModelPath + "/params");
cfg->SwitchIrOptim(true);
cfg->SwitchSpecifyInputNames();
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
std::vector<PaddleTensor> input_slots;
......@@ -315,5 +324,44 @@ TEST(Analyzer_dam, compare_determine) {
input_slots_all);
}
// Save optim model
TEST(Analyzer_dam, save_optim_model) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::string optimModelPath =
FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) +
"/saved_optim_model";
mkdir(optimModelPath.c_str(), 0777);
auto predictor = CreateTestPredictor(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
FLAGS_use_analysis);
(static_cast<AnalysisPredictor *>(predictor.get()))
->SaveOptimModel(optimModelPath);
}
void CompareOptimAndOrig(const PaddlePredictor::Config *orig_config,
const PaddlePredictor::Config *optim_config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
PrintConfig(orig_config, true);
PrintConfig(optim_config, true);
std::vector<std::vector<PaddleTensor>> orig_outputs, optim_outputs;
TestOneThreadPrediction(orig_config, inputs, &orig_outputs, false);
TestOneThreadPrediction(optim_config, inputs, &optim_outputs, false);
CompareResult(orig_outputs.back(), optim_outputs.back());
}
TEST(Analyzer_dam, compare_optim_orig) {
AnalysisConfig orig_cfg;
AnalysisConfig optim_cfg;
SetConfig(&orig_cfg);
SetOptimConfig(&optim_cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareOptimAndOrig(
reinterpret_cast<const PaddlePredictor::Config *>(&orig_cfg),
reinterpret_cast<const PaddlePredictor::Config *>(&optim_cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -32,6 +32,17 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
SetFakeImageInput(inputs, FLAGS_infer_model);
}
void SetOptimConfig(AnalysisConfig *cfg) {
std::string optimModelPath =
FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) +
"/saved_optim_model";
cfg->SetModel(optimModelPath + "/model", optimModelPath + "/params");
cfg->DisableGpu();
cfg->SwitchIrOptim();
cfg->SwitchSpecifyInputNames();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
}
// Easy for profiling independently.
void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
......@@ -87,13 +98,51 @@ TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); }
TEST(Analyzer_resnet50, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
// Save optim model
TEST(Analyzer_resnet50, save_optim_model) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::string optimModelPath =
FLAGS_infer_model.substr(0, FLAGS_infer_model.find_last_of("/")) +
"/saved_optim_model";
mkdir(optimModelPath.c_str(), 0777);
auto predictor = CreateTestPredictor(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
FLAGS_use_analysis);
(static_cast<AnalysisPredictor *>(predictor.get()))
->SaveOptimModel(optimModelPath);
}
void CompareOptimAndOrig(const PaddlePredictor::Config *orig_config,
const PaddlePredictor::Config *optim_config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
PrintConfig(orig_config, true);
PrintConfig(optim_config, true);
std::vector<std::vector<PaddleTensor>> orig_outputs, optim_outputs;
TestOneThreadPrediction(orig_config, inputs, &orig_outputs, false);
TestOneThreadPrediction(optim_config, inputs, &optim_outputs, false);
CompareResult(orig_outputs.back(), optim_outputs.back());
}
TEST(Analyzer_resnet50, compare_optim_orig) {
AnalysisConfig orig_cfg;
AnalysisConfig optim_cfg;
SetConfig(&orig_cfg);
SetOptimConfig(&optim_cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareOptimAndOrig(
reinterpret_cast<const PaddlePredictor::Config *>(&orig_cfg),
reinterpret_cast<const PaddlePredictor::Config *>(&optim_cfg),
input_slots_all);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -47,6 +47,7 @@ struct DataRecord {
num_lines++;
std::vector<std::string> data;
split(line, '\t', &data);
PADDLE_ENFORCE(data.size() >= 4);
// load title1 data
std::vector<int64_t> title1_data;
split_to_int64(data[0], ' ', &title1_data);
......
......@@ -150,6 +150,9 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
if (use_mkldnn) {
cfg->EnableMKLDNN();
}
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time
cfg->pass_builder()->InsertPass(2, "seqpool_concat_fuse_pass");
}
void profile(bool use_mkldnn = false) {
......
......@@ -214,28 +214,23 @@ TEST(Analyzer_Transformer, fuse_statis) {
}
// Compare result of NativeConfig and AnalysisConfig
// void compare(bool use_mkldnn = false) {
// AnalysisConfig cfg;
// SetConfig(&cfg);
// if (use_mkldnn) {
// cfg.EnableMKLDNN();
// }
//
// std::vector<std::vector<PaddleTensor>> input_slots_all;
// SetInput(&input_slots_all);
// CompareNativeAndAnalysis(
// reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
// input_slots_all);
// }
// TODO(yihuaxu):
// Disable compare and compare_mkldnn temporary, see
// https://github.com/paddlePaddle/Paddle/issues/16316 for details.
// TEST(Analyzer_Transformer, compare) { compare(); }
// #ifdef PADDLE_WITH_MKLDNN
// TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */);
// }
// #endif
void compare(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
TEST(Analyzer_Transformer, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace inference
} // namespace paddle
......@@ -51,8 +51,6 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
<< "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n";
os << GenSpaces(num_spaces)
<< "specify_input_name: " << config.specify_input_name << "\n";
os << GenSpaces(num_spaces)
<< "cpu_num_threads: " << config.cpu_math_library_num_threads() << "\n";
num_spaces--;
os << GenSpaces(num_spaces) << "}\n";
return os;
......@@ -72,8 +70,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) {
}
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim()
<< "\n";
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim()
<< "\n";
os << GenSpaces(num_spaces)
<< "cpu_num_threads: " << config.cpu_math_library_num_threads() << "\n";
os << GenSpaces(num_spaces)
<< "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n";
os << GenSpaces(num_spaces)
......
......@@ -19,10 +19,11 @@ import sys
import random
import functools
import contextlib
from PIL import Image, ImageEnhance
from PIL import Image
import math
from paddle.dataset.common import download, md5file
from paddle.dataset.common import download
import tarfile
import StringIO
random.seed(0)
np.random.seed(0)
......@@ -32,9 +33,11 @@ SIZE_FLOAT32 = 4
SIZE_INT64 = 8
FULL_SIZE_BYTES = 30106000008
FULL_IMAGES = 50000
DATA_DIR_NAME = 'ILSVRC2012'
IMG_DIR_NAME = 'var'
TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2'
TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
FOLDER_NAME = "ILSVRC2012/"
VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
CHUNK_SIZE = 8192
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
......@@ -62,8 +65,7 @@ def crop_image(img, target_size, center):
return img
def process_image(img_path, mode, color_jitter, rotate):
img = Image.open(img_path)
def process_image(img):
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
......@@ -99,26 +101,11 @@ def download_concat(cache_folder, zip_path):
outfile.write(infile.read())
def extract(zip_path, extract_folder):
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
img_dir = os.path.join(data_dir, IMG_DIR_NAME)
print("Extracting...\n")
if not (os.path.exists(img_dir) and
len(os.listdir(img_dir)) == FULL_IMAGES):
tar = tarfile.open(zip_path)
tar.extractall(path=extract_folder)
tar.close()
print('Extracted. Full Imagenet Validation dataset is located at {0}\n'.
format(data_dir))
def print_processbar(done, total):
done_filled = done * '='
empty_filled = (total - done) * ' '
percentage_done = done * 100 / total
def print_processbar(done_percentage):
done_filled = done_percentage * '='
empty_filled = (100 - done_percentage) * ' '
sys.stdout.write("\r[%s%s]%d%%" %
(done_filled, empty_filled, percentage_done))
(done_filled, empty_filled, done_percentage))
sys.stdout.flush()
......@@ -126,15 +113,13 @@ def check_integrity(filename, target_hash):
print('\nThe binary file exists. Checking file integrity...\n')
md = hashlib.md5()
count = 0
total_parts = 50
chunk_size = 8192
onepart = FULL_SIZE_BYTES / chunk_size / total_parts
onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
with open(filename) as ifs:
while True:
buf = ifs.read(8192)
buf = ifs.read(CHUNK_SIZE)
if count % onepart == 0:
done = count / onepart
print_processbar(done, total_parts)
print_processbar(done)
count = count + 1
if not buf:
break
......@@ -146,54 +131,61 @@ def check_integrity(filename, target_hash):
return False
def convert(file_list, data_dir, output_file):
def convert(tar_file, output_file):
print('Converting 50000 images to binary file ...\n')
with open(file_list) as flist:
lines = [line.strip() for line in flist]
num_images = len(lines)
with open(output_file, "w+b") as ofs:
#save num_images(int64_t) to file
ofs.seek(0)
num = np.array(int(num_images)).astype('int64')
ofs.write(num.tobytes())
per_parts = 1000
full_parts = FULL_IMAGES / per_parts
print_processbar(0, full_parts)
for idx, line in enumerate(lines):
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
continue
#save image(float32) to file
img = process_image(
img_path, 'val', color_jitter=False, rotate=False)
np_img = np.array(img)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
idx)
ofs.write(np_img.astype('float32').tobytes())
ofs.flush()
#save label(int64_t) to file
label_int = (int)(label)
np_label = np.array(label_int)
ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 *
num_images + idx * SIZE_INT64)
ofs.write(np_label.astype('int64').tobytes())
ofs.flush()
if (idx + 1) % per_parts == 0:
done = (idx + 1) / per_parts
print_processbar(done, full_parts)
tar = tarfile.open(name=tar_file, mode='r:gz')
print_processbar(0)
dataset = {}
for tarInfo in tar:
if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
with open(output_file, "w+b") as ofs:
ofs.seek(0)
num = np.array(int(FULL_IMAGES)).astype('int64')
ofs.write(num.tobytes())
per_percentage = FULL_IMAGES / 100
idx = 0
for imagedata in dataset.values():
img = Image.open(StringIO.StringIO(imagedata))
img = process_image(img)
np_img = np.array(img)
ofs.write(np_img.astype('float32').tobytes())
if idx % per_percentage == 0:
print_processbar(idx / per_percentage)
idx = idx + 1
val_info = tar.getmember(VALLIST_TAR_NAME)
val_list = tar.extractfile(val_info).read()
lines = val_list.split('\n')
val_dict = {}
for line_idx, line in enumerate(lines):
if line_idx == FULL_IMAGES:
break
name, label = line.split()
val_dict[name] = label
for img_name in dataset.keys():
remove_len = (len(FOLDER_NAME))
img_name_prim = img_name[remove_len:]
label = val_dict[img_name_prim]
label_int = (int)(label)
np_label = np.array(label_int)
ofs.write(np_label.astype('int64').tobytes())
print_processbar(100)
tar.close()
print("Conversion finished.")
def run_convert():
print('Start to download and convert 50000 images to binary file...')
cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download')
extract_folder = os.path.join(cache_folder, 'full_data')
data_dir = os.path.join(extract_folder, DATA_DIR_NAME)
file_list = os.path.join(data_dir, 'val_list.txt')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz')
zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz.partaa')
output_file = os.path.join(cache_folder, 'int8_full_val.bin')
retry = 0
try_limit = 3
......@@ -213,8 +205,7 @@ def run_convert():
"Can not convert the dataset to binary file with try limit {0}".
format(try_limit))
download_concat(cache_folder, zip_path)
extract(zip_path, extract_folder)
convert(file_list, data_dir, output_file)
convert(zip_path, output_file)
print("\nSuccess! The binary file can be found at {0}".format(output_file))
......
......@@ -55,6 +55,9 @@ DEFINE_bool(record_benchmark, false,
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DEFINE_double(quantized_accuracy, 1e-2, "Result Quantized Accuracy.");
DEFINE_bool(zero_copy, false, "Use ZeroCopy to speedup Feed/Fetch.");
DEFINE_bool(warmup, false,
"Use warmup to calculate elapsed_time more accurately. "
"To reduce CI time, it sets false in default.");
DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
......@@ -367,7 +370,9 @@ void TestOneThreadPrediction(
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<std::vector<PaddleTensor>> *outputs, bool use_analysis = true) {
auto predictor = CreateTestPredictor(config, use_analysis);
PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
if (FLAGS_warmup) {
PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
}
PredictionRun(predictor.get(), inputs, outputs, 1, 0);
}
......@@ -395,7 +400,10 @@ void TestMultiThreadPrediction(
->SetMkldnnThreadID(static_cast<int>(tid) + 1);
}
#endif
PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads, tid);
if (FLAGS_warmup) {
PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads,
tid);
}
PredictionRun(predictor.get(), inputs, &outputs_tid, num_threads, tid);
});
}
......
......@@ -116,7 +116,7 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
reinterpret_cast<const PaddlePredictor::Config*>(&analysis_config);
auto native_pred = CreateTestPredictor(config, false);
auto analysis_pred = CreateTestPredictor(config, true);
for (int i = 0; i < 100; i++) {
for (int i = 0; i < 20; i++) {
std::vector<std::vector<PaddleTensor>> inputs_all;
if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
......@@ -133,11 +133,13 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
TEST(TensorRT_mobilenet, compare) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
compare(model_dir, /* use_tensorrt */ true);
// Open it when need.
// profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
}
TEST(TensorRT_resnet50, compare) {
TEST(resnet50, compare_continuous_input) {
std::string model_dir = FLAGS_infer_model + "/resnet50";
compare(model_dir, /* use_tensorrt */ true);
compare_continuous_input(model_dir, true);
}
TEST(TensorRT_resnext50, compare) {
......@@ -145,24 +147,6 @@ TEST(TensorRT_resnext50, compare) {
compare(model_dir, /* use_tensorrt */ true);
}
TEST(TensorRT_resnext50, profile) {
std::string model_dir = FLAGS_infer_model + "/resnext50";
// Set FLAGS_record_benchmark to true to record benchmark to file.
// FLAGS_record_benchmark=true;
FLAGS_model_name = "resnext50";
profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
}
TEST(resnext50, compare_analysis_native) {
std::string model_dir = FLAGS_infer_model + "/resnext50";
compare(model_dir, false /*use tensorrt*/);
}
TEST(TensorRT_mobilenet, analysis) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
compare(model_dir, false /* use_tensorrt */);
}
TEST(AnalysisPredictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
AnalysisConfig config;
......@@ -180,20 +164,5 @@ TEST(AnalysisPredictor, use_gpu) {
}
}
TEST(TensorRT_mobilenet, profile) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
profile(model_dir, true, false);
}
TEST(resnet50, compare_continuous_input) {
std::string model_dir = FLAGS_infer_model + "/resnet50";
compare_continuous_input(model_dir, true);
}
TEST(resnet50, compare_continuous_input_native) {
std::string model_dir = FLAGS_infer_model + "/resnet50";
compare_continuous_input(model_dir, false);
}
} // namespace inference
} // namespace paddle
......@@ -2,6 +2,7 @@ include(ExternalProject)
set(INFERENCE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(INFERENCE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
set(CPU_NUM_THREADS_ON_CI 4 CACHE STRING "Run multi-threads on CI to reduce CI time.")
function(inference_download INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
......
abs
acos
asin
atan
attention_lstm
brelu
conv_shift
cos
cos_sim
dequantize
elu
fc
flatten
fsp
......@@ -21,17 +14,10 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
gelu
gru
hard_shrink
hierarchical_sigmoid
leaky_relu
log
logsigmoid
lookup_table
lrn
lstm_unit
lstmp
max_pool2d_with_index
max_pool3d_with_index
maxout
......@@ -39,7 +25,6 @@ modified_huber_loss
nce
pool2d
pool3d
pow
prelu
quantize
rank_loss
......@@ -51,26 +36,10 @@ reduce_sum
requantize
reshape
rnn_memory_helper
round
sequence_softmax
sin
softplus
softshrink
softsign
space_to_depth
spp
square
squared_l2_distance
squared_l2_norm
squeeze
stanh
swish
tanh_shrink
teacher_student_sigmoid_loss
tensor_array_to_tensor
thresholded_relu
transpose
tree_conv
unpool
unsqueeze
warpctc
......@@ -72,7 +72,7 @@ endif()
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
if (WITH_GPU AND NOT WIN32)
if (WITH_DGC)
op_library(dgc_op DEPS dgc)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
......
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
......@@ -82,6 +85,8 @@ template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -105,6 +112,8 @@ template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -116,6 +125,8 @@ template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename Functor>
......@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
const framework::Tensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
framework::Tensor* dX = nullptr;
ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
......@@ -27,6 +29,25 @@ namespace operators {
using paddle::framework::Tensor;
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
std::unique_ptr<std::unordered_set<std::string>> ret(
new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
bwd_functor) \
if (CanInplaceAct<bwd_functor<float>>()) { \
ret->insert(#op_type); \
}
FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
return ret;
}
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
......@@ -50,26 +71,32 @@ using paddle::framework::Tensor;
} \
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
class OP_NAME##GradMaker \
: public ::paddle::framework::SingleGradOpDescMaker { \
public: \
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto* op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
OutputGrad("Out")); \
\
op->SetAttrMap(Attrs()); \
\
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
op->SetInput("X", Input("X"));
}
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", Output("Out"));
}
return op;
}
};
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
......@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("Out", framework::GradVarName("X"));
ctx->ShareLoD("Out", framework::GradVarName("X"));
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
};
......@@ -199,6 +227,15 @@ $out = \sqrt{x}$
)DOC";
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.
Please make sure input is legal in case of numeric errors.
$out = \frac{1}{\sqrt{x}}$
)DOC";
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.
......@@ -547,6 +584,7 @@ REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
......@@ -559,78 +597,27 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(Sigmoid, sigmoid); \
__macro(Relu, relu); \
__macro(Exp, exp); \
__macro(Tanh, tanh); \
__macro(Ceil, ceil); \
__macro(Floor, floor); \
__macro(Sqrt, sqrt); \
__macro(SoftRelu, soft_relu); \
__macro(Relu6, relu6); \
__macro(Reciprocal, reciprocal); \
__macro(HardSigmoid, hard_sigmoid);
#define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(LogSigmoid, logsigmoid); \
__macro(SoftShrink, softshrink); \
__macro(Abs, abs); \
__macro(Cos, cos); \
__macro(Acos, acos); \
__macro(Sin, sin); \
__macro(Asin, asin); \
__macro(Atan, atan); \
__macro(Round, round); \
__macro(Log, log); \
__macro(Square, square); \
__macro(Gelu, gelu); \
__macro(BRelu, brelu); \
__macro(Pow, pow); \
__macro(STanh, stanh); \
__macro(Softplus, softplus); \
__macro(Softsign, softsign); \
__macro(LeakyRelu, leaky_relu); \
__macro(TanhShrink, tanh_shrink); \
__macro(ELU, elu); \
__macro(HardShrink, hard_shrink); \
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker, \
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR( \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
......@@ -643,6 +630,5 @@ namespace ops = paddle::operators;
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -15,7 +15,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
......@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
......@@ -79,9 +79,13 @@ class AffineChannelOp : public framework::OperatorWithKernel {
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dims[0], C);
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(b_dims[0], C);
if (ctx->IsRuntime() || scale_dims[0] > 0) {
PADDLE_ENFORCE_EQ(scale_dims[0], C);
}
if (ctx->IsRuntime() || b_dims[0] > 0) {
PADDLE_ENFORCE_EQ(b_dims[0], C);
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out");
......
......@@ -121,9 +121,11 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2});
Tensor sliced_out = output->Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out,
T(0));
}
......@@ -161,8 +163,10 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
for (int i = 0; i < n; ++i) {
Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2});
Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize(
{static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1),
&sliced_theta_grad, T(0));
......
......@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
}
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
if (ctx->IsRuntime() ||
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
......@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
......
......@@ -65,11 +65,22 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C);
auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0], C);
PADDLE_ENFORCE_EQ(scale_dim[0], C);
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
......
......@@ -41,9 +41,11 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL,
"The input(Weight) must be a 3D tensor.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The first dimension(batch_size) of input(X) must be "
"equal to the first dimension of the input(Y).");
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) {
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The first dimension(batch_size) of input(X) must be "
"equal to the first dimension of the input(Y).");
}
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
"The second dimension of input(X) must be equal to "
"the second dimension of the input(Weight).");
......
......@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
framework::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
auto y_dims = x_dims;
y_dims[rank - 1] = 1;
......
......@@ -49,7 +49,15 @@ class ConcatOp : public framework::OperatorWithKernel {
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += ins[i][j];
if (ctx->IsRuntime()) {
out_dims[axis] += ins[i][j];
} else {
if (ins[i][j] == -1) {
out_dims[axis] = -1;
} else {
out_dims[axis] += ins[i][j];
}
}
} else {
if (ctx->IsRuntime()) {
// check all shape in run time
......
......@@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
"Input(Y) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same");
int product_x = framework::product(dim_x);
int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) {
PADDLE_ENFORCE_EQ(
product_x, product_y,
"The number of elements in X and Y should be same, %d != %d",
product_x, product_y);
}
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
......
......@@ -81,8 +81,10 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
if (context->IsRuntime()) {
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
}
if (!context->HasInput("X")) {
return;
}
......
......@@ -68,9 +68,14 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
......
......@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The 1st dimension of Input(X) and Input(Y) should "
"be equal.");
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be odd.");
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X).");
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The 1st dimension of Input(X) and Input(Y) should "
"be equal.");
if (ctx->IsRuntime() || y_dims[1] > 0)
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be odd.");
if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X).");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
......
......@@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()),
"All dimensions except the 1st of Input(X) and Input(Y) "
"must be equal.");
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
"The 1st dimension of Input(Y) must be equal to Input(X) or"
" just 1 (which will be broadcasted to match Input(X)).");
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()),
"All dimensions except the 1st of Input(X) and Input(Y) "
"must be equal.");
PADDLE_ENFORCE(
x_dims[0] == y_dims[0] || y_dims[0] == 1,
"The 1st dimension of Input(Y) must be equal to Input(X) or"
" just 1 (which will be broadcasted to match Input(X)).");
}
// resize tensor
ctx->SetOutputDim("Out", {x_dims[0], 1});
......
......@@ -95,20 +95,23 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
if (ctx->IsRuntime() || (emission_dims[1] > 0 && transition_dims[1] > 0)) {
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");
}
if (ctx->HasInput("Label")) {
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
if (ctx->IsRuntime() || (emission_dims[0] > 0 && label_dims[0] > 0)) {
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");
}
}
ctx->ShareLoD("Emission", /*->*/ "ViterbiPath");
......
此差异已折叠。
此差异已折叠。
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/data_norm_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -65,9 +66,11 @@ class DataNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Means", {C});
......
......@@ -24,6 +24,7 @@
**/
#include "paddle/fluid/operators/detection/gpc.h"
#include "paddle/fluid/platform/enforce.h"
namespace gpc {
......@@ -689,6 +690,7 @@ static bbox *create_contour_bboxes(gpc_polygon *p) {
gpc_malloc<bbox>(box, p->num_contours * sizeof(bbox),
const_cast<char *>("Bounding box creation"));
PADDLE_ENFORCE_NOT_NULL(box);
/* Construct contour bounding boxes */
for (c = 0; c < p->num_contours; c++) {
......@@ -852,6 +854,7 @@ void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) {
/* Create an extended hole array */
gpc_malloc<int>(extended_hole, (p->num_contours + 1) * sizeof(int),
const_cast<char *>("contour hole addition"));
PADDLE_ENFORCE_NOT_NULL(extended_hole);
/* Create an extended contour array */
gpc_malloc<gpc_vertex_list>(extended_contour,
......@@ -969,6 +972,7 @@ void gpc_polygon_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip,
/* Build scanbeam table from scanbeam tree */
gpc_malloc<double>(sbt, sbt_entries * sizeof(double),
const_cast<char *>("sbt creation"));
PADDLE_ENFORCE_NOT_NULL(sbt);
build_sbt(&scanbeam, sbt, sbtree);
scanbeam = 0;
free_sbtree(&sbtree);
......@@ -1604,6 +1608,7 @@ void gpc_tristrip_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip,
/* Build scanbeam table from scanbeam tree */
gpc_malloc<double>(sbt, sbt_entries * sizeof(double),
const_cast<char *>("sbt creation"));
PADDLE_ENFORCE_NOT_NULL(sbt);
build_sbt(&scanbeam, sbt, sbtree);
scanbeam = 0;
free_sbtree(&sbtree);
......
......@@ -171,8 +171,8 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors, In the second dimension(the channel
given number bounding boxes, this given number, which following will be represented as S,
is specified by the number of anchor clusters in each scale. In the second dimension(the channel
dimension), C should be equal to S * (class_num + 5), class_num is the object
category number of source dataset(such as 80 in coco dataset), so in the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
......@@ -203,7 +203,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
thresh, the confidence score loss of this anchor box will be ignored.
Therefore, the yolov3 loss consist of three major parts, box location loss,
confidence score loss, and classification loss. The L2 loss is used for
confidence score loss, and classification loss. The L1 loss is used for
box coordinates (w, h), and sigmoid cross entropy loss is used for box
coordinates (x, y), confidence score loss and classification loss.
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册