提交 09e2e662 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into imperative_mnist

......@@ -139,10 +139,12 @@ endfunction()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 7.0)
set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
add_definitions("-DPADDLE_CUDA_BINVER=\"60\"")
elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"70\"")
elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
......@@ -150,6 +152,7 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
# CUDA 8 may complain that sm_20 is no longer supported. Suppress the
# warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
endif()
include_directories(${CUDA_INCLUDE_DIRS})
......
......@@ -89,6 +89,7 @@ if(CUDNN_FOUND)
if(NOT CUDNN_MAJOR_VERSION)
set(CUDNN_VERSION "???")
else()
add_definitions("-DPADDLE_CUDNN_BINVER=\"${CUDNN_MAJOR_VERSION}\"")
math(EXPR CUDNN_VERSION
"${CUDNN_MAJOR_VERSION} * 1000 +
${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
......
......@@ -32,4 +32,4 @@ endif()
add_dependencies(cub extern_cub)
LIST(APPEND externl_project_dependencies cub)
LIST(APPEND external_project_dependencies cub)
......@@ -28,4 +28,4 @@ endif()
add_dependencies(dlpack extern_dlpack)
LIST(APPEND externl_project_dependencies dlpack)
LIST(APPEND external_project_dependencies dlpack)
......@@ -37,13 +37,12 @@ INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_VERSION "0.9")
SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
SET(NGRAPH_GIT_TAG "v0.10.1")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
SET(NGRAPH_SHARED_LIB_NAME libngraph.so)
SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
......
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -27,9 +27,10 @@ add_subdirectory(details)
proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
cc_test(data_type_test SRCS data_type_test.cc DEPS data_type place tensor)
if(WITH_GPU)
......@@ -71,13 +72,13 @@ cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
cc_library(var_type_traits SRCS var_type_traits DEPS lod_tensor selected_rows framework_proto)
if (WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool var_type_traits)
cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits)
cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
......@@ -129,11 +130,9 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
if(WITH_NGRAPH)
if(NOT WIN32)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler ngraph)
endif(NOT WIN32)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler)
endif(WITH_NGRAPH)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
......@@ -175,11 +174,7 @@ if(WITH_DISTRIBUTE)
else()
if(WITH_NGRAPH)
if(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
else(WITH_NGRAPH)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(WITH_NGRAPH)
......@@ -194,9 +189,9 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor variable_helper)
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer)
endif(WITH_PSLIB)
......
......@@ -15,34 +15,123 @@
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/framework/unroll_array_ops.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T, size_t N>
class Array {
static_assert(N > 0, "The size of array must be larger than 0");
public:
HOSTDEVICE Array() {}
static constexpr size_t kSize = N;
HOSTDEVICE inline Array() {}
HOSTDEVICE explicit Array(const T &val) {
for (size_t i = 0; i < N; ++i) data_[i] = val;
template <typename... Args>
HOSTDEVICE inline explicit Array(const T &val, Args... args) {
static_assert(N == sizeof...(Args) + 1, "Invalid argument");
UnrollVarArgsAssign<T>::Run(data_, val, args...);
}
HOSTDEVICE const T *Get() const { return data_; }
HOSTDEVICE inline void Fill(const T &val) {
UnrollFillConstant<N>::Run(data_, val);
}
HOSTDEVICE T *GetMutable() { return data_; }
HOSTDEVICE inline const T *Get() const { return data_; }
HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
HOSTDEVICE inline T *GetMutable() { return data_; }
HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
HOSTDEVICE inline T &operator[](size_t i) { return *advance(data_, i); }
// Writing "return data_[i]" would cause compilation warning/error:
// "array subscript is above array bound" in Python 35 CI.
// It seems that it is a false warning of GCC if we do not check the bounds
// of array index. But for better performance, we do not check in operator[]
// like what is in STL. If users want to check the bounds, use at() instead
HOSTDEVICE inline const T &operator[](size_t i) const {
return *advance(data_, i);
}
HOSTDEVICE inline T &at(size_t i) {
#ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds");
#endif
return (*this)[i];
}
HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds");
#endif
return (*this)[i];
}
HOSTDEVICE constexpr size_t size() const { return N; }
HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
return UnrollCompare<N>::Run(data_, other.data_);
}
HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
return !(*this == other);
}
private:
template <typename U>
HOSTDEVICE static inline U *advance(U *ptr, size_t i) {
return ptr + i;
}
T data_[N];
};
template <typename T>
class Array<T, 0> {
public:
static constexpr size_t kSize = 0;
HOSTDEVICE inline Array() {}
HOSTDEVICE inline void Fill(const T &val) {}
HOSTDEVICE inline constexpr T *Get() const { return nullptr; }
// Add constexpr to GetMutable() cause warning in MAC
HOSTDEVICE inline T *GetMutable() { return nullptr; }
HOSTDEVICE inline T &operator[](size_t) {
#ifdef __CUDA_ARCH__
static T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
#endif
}
HOSTDEVICE inline const T &operator[](size_t) const {
#ifdef __CUDA_ARCH__
static const T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
#endif
}
HOSTDEVICE inline T &at(size_t i) { return (*this)[i]; }
HOSTDEVICE inline const T &at(size_t i) const { return (*this)[i]; }
HOSTDEVICE constexpr size_t size() const { return 0; }
HOSTDEVICE constexpr bool operator==(const Array<T, 0> &other) const {
return true;
}
HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &other) const {
return false;
}
};
} // namespace framework
} // namespace paddle
......@@ -304,8 +304,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
threads.push_back(
std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
if (debug) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFilesWithTimer,
workers[thidx].get()));
} else {
threads.push_back(
std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
}
}
for (auto& th : threads) {
......
......@@ -18,312 +18,159 @@ limitations under the License. */
namespace paddle {
namespace framework {
/// @cond HIDDEN
template <int i>
Dim<i> make_dim(const int64_t* d) {
return Dim<i>(*d, make_dim<i - 1>(d + 1));
}
template <>
Dim<0> make_dim<0>(const int64_t* d) {
return Dim<0>(*d);
}
void make_ddim(DDim& ddim, const int64_t* dims, int n) {
switch (n) {
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
PADDLE_THROW("Dynamic dimensions must have between [1, 9] dimensions.");
}
}
/// @endcond
DDim make_ddim(std::initializer_list<int64_t> dims) {
DDim result(make_dim(0));
make_ddim(result, dims.begin(), dims.size());
return result;
return DDim(dims.begin(), dims.size());
}
DDim make_ddim(const std::vector<int64_t>& dims) {
DDim result(make_dim(0));
make_ddim(result, &dims[0], dims.size());
return result;
return DDim(dims.data(), dims.size());
}
DDim make_ddim(const std::vector<int>& dims) {
std::vector<int64_t> res(dims.size());
std::transform(dims.begin(), dims.end(), res.begin(),
[](int d) { return static_cast<int64_t>(d); });
return make_ddim(res);
return DDim(dims.data(), dims.size());
}
/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors
class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {
public:
explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
struct DDimEqualityVisitor {
explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
template <int D>
int64_t& operator()(Dim<D>& dim) const {
return dim[idx_];
inline bool operator()(const Dim<D>& self) const {
return UnrollCompare<D>::Run(self.Get(), d_);
}
private:
int idx_;
const int64_t* d_;
};
class DynamicConstIndexer : public boost::static_visitor<int64_t> {
public:
explicit DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D>
int64_t operator()(const Dim<D>& dim) const {
return dim[idx_];
}
private:
int idx_;
};
/// @endcond
int64_t& DDim::operator[](int idx) {
return boost::apply_visitor(DynamicMutableIndexer(idx), var);
bool DDim::operator==(const DDim& d) const {
return size() == d.size() &&
this->apply_visitor(DDimEqualityVisitor(d.Get()));
}
int64_t DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var);
}
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
int DDim::size() const { return arity(*this); }
struct DDimPlusVisitor {
explicit DDimPlusVisitor(const int64_t* d1, const int64_t* d2)
: d1_(d1), d2_(d2) {}
bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) {
return false;
} else {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
for (unsigned int i = 0; i < v1.size(); i++) {
if (v1[i] != v2[i]) {
return false;
}
}
return true;
template <int D>
inline void operator()(Dim<D>& self) const {
UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
}
}
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] + v2[i]);
}
const int64_t* d1_;
const int64_t* d2_;
};
return make_ddim(v3);
DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(size() == d.size());
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret;
}
DDim DDim::operator*(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
struct DDimMulVisitor {
explicit DDimMulVisitor(const int64_t* d1, const int64_t* d2)
: d1_(d1), d2_(d2) {}
std::vector<int64_t> v3;
assert(v1.size() == v2.size());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
template <int D>
inline void operator()(Dim<D>& self) const {
UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
}
return make_ddim(v3);
const int64_t* d1_;
const int64_t* d2_;
};
DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(size() == d.size());
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret;
}
int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
/// @cond HIDDEN
struct VectorizeVisitor : public boost::static_visitor<> {
std::vector<int64_t>& vector;
explicit VectorizeVisitor(std::vector<int64_t>& v) : vector(v) {}
template <typename T>
void operator()(const T& t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
void operator()(const Dim<0>& t) {}
};
/// @endcond
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT
std::vector<int64_t> vectorize(const DDim& ddim) {
std::vector<int64_t> result;
VectorizeVisitor visitor(result);
boost::apply_visitor(visitor, ddim);
std::vector<int64_t> result(DDim::kMaxRank);
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result.resize(ddim.size());
return result;
}
// NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) {
std::vector<int64_t> temp = vectorize(ddim);
std::vector<int> result(temp.begin(), temp.end());
std::vector<int> result(DDim::kMaxRank);
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result.resize(ddim.size());
return result;
}
struct ProductVisitor : public boost::static_visitor<int64_t> {
struct ProductVisitor {
template <int D>
int64_t operator()(const Dim<D>& dim) {
inline int64_t operator()(const Dim<D>& dim) {
return product(dim);
}
};
int64_t product(const DDim& ddim) {
ProductVisitor visitor;
return boost::apply_visitor(visitor, ddim);
return ddim.apply_visitor(ProductVisitor());
}
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int64_t>& vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
: vector(v), begin(b), end(e) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
}
template <int S>
void operator()(const Dim<S>& dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
}
--end;
if (end > 0) {
this->operator()(dim.tail);
}
}
void operator()(const Dim<0>& dim) {
PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
}
};
DDim slice_ddim(const DDim& dim, int begin, int end) {
std::vector<int64_t> vec;
vec.reserve(end - begin);
SliceVectorizeVisitor visitor(vec, begin, end);
boost::apply_visitor(visitor, dim);
return make_ddim(vec);
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
begin, end, dim.size());
// Constructor of DDim would check whether end - begin is valid
return DDim(dim.Get() + begin, end - begin);
}
/// \cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> {
template <int D>
int operator()(Dim<D>) const {
return D;
}
};
/// \endcond
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); }
int arity(const DDim& d) { return d.size(); }
/// \cond HIDDEN
struct DDimPrinter : boost::static_visitor<void> {
struct DDimPrinter {
std::ostream& os;
explicit DDimPrinter(std::ostream& os_) : os(os_) {}
template <typename T>
void operator()(const T& t) {
template <int D>
void operator()(const Dim<D>& t) {
os << t;
}
};
/// \endcond
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDimPrinter printer(os);
boost::apply_visitor(printer, ddim);
ddim.apply_visitor(DDimPrinter(os));
return os;
}
DDim::DDim(std::initializer_list<int64_t> init_list) {
*this = make_ddim(init_list);
}
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
int rank = src.size();
return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, rank))});
return DDim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
}
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
DDim stride(const DDim& ddim) {
std::vector<int64_t> strides(ddim.size());
DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
return framework::make_ddim(strides);
return strides;
}
DDim stride_numel(const framework::DDim& ddim) {
std::vector<int64_t> strides(ddim.size());
DDim stride_numel(const DDim& ddim) {
DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
return framework::make_ddim(strides);
return strides;
}
} // namespace framework
......
......@@ -18,62 +18,145 @@ limitations under the License. */
#include <stdexcept>
#include <vector>
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
case (rank): { \
constexpr auto kRank = (rank); \
return (callback); \
}
#define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \
PADDLE_THROW("Invalid rank %d", rank); \
}
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
}
/**
* \brief A dynamically sized dimension.
*
* The number of dimensions must be between [1, 9].
*/
struct DDim {
typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
Dim<7>, Dim<8>, Dim<9>>
DDimVar;
DDimVar var;
class DDim {
public:
constexpr static int kMaxRank = 9;
DDim() : rank_(1) { dim_[0] = 0; }
DDim() : var(Dim<1>()) {}
DDim(const DDim& ddim) : dim_() { CopyFrom(ddim); }
DDim(const int* d, int n) : rank_(n) {
dynamic_dim_assign(d, dim_.GetMutable(), n);
}
DDim(const int64_t* d, int n) : rank_(n) {
dynamic_dim_assign(d, dim_.GetMutable(), n);
}
template <int D>
explicit DDim(const Dim<D>& in) : var(in) {}
/*implicit*/ DDim(const Dim<D>& in) : rank_(D) { // NOLINT
UnsafeCast<D>() = in;
}
/*implicit*/ DDim(std::initializer_list<int64_t> init_list)
: DDim(init_list.begin(), init_list.size()) {}
/*implicit*/ DDim(std::initializer_list<int64_t> init_list);
inline DDim& operator=(const DDim& ddim) { return CopyFrom(ddim); }
template <int D>
DDim& operator=(const Dim<D>& in) {
var = in;
inline DDim& operator=(const Dim<D>& dim) {
rank_ = D;
UnsafeCast<D>() = dim;
return *this;
}
int64_t& operator[](int idx);
int64_t operator[](int idx) const;
inline int64_t& operator[](int idx) { return dim_[idx]; }
inline int64_t operator[](int idx) const { return dim_[idx]; }
inline int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx];
}
inline int64_t at(int idx) const {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx];
}
template <typename Visitor>
typename Visitor::result_type apply_visitor(Visitor& visitor) {
return var.apply_visitor(visitor);
typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
Visitor&& visitor) {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
template <typename Visitor>
typename Visitor::result_type apply_visitor(Visitor& visitor) const {
return var.apply_visitor(visitor);
typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
Visitor&& visitor) const {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
DDimVar getVar() { return var; }
bool operator==(const DDim& d) const;
bool operator!=(const DDim& d) const;
DDim operator+(const DDim& d) const;
bool operator==(DDim d) const;
DDim operator*(const DDim& d) const;
bool operator!=(DDim d) const;
inline const int64_t* Get() const { return dim_.Get(); }
DDim operator+(DDim d) const;
inline int64_t* GetMutable() { return dim_.GetMutable(); }
DDim operator*(DDim d) const;
inline int size() const { return rank_; }
private:
template <int D>
inline Dim<D>& UnsafeCast() {
static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
auto* p = static_cast<void*>(&dim_);
return *reinterpret_cast<Dim<D>*>(p);
}
template <int D>
inline const Dim<D>& UnsafeCast() const {
static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
auto* p = static_cast<const void*>(&dim_);
return *reinterpret_cast<const Dim<D>*>(p);
}
int size() const;
inline DDim& CopyFrom(const DDim& ddim) {
PADDLE_VISIT_DDIM(ddim.rank_, (*this = ddim.UnsafeCast<kRank>()));
}
friend DDim stride(const DDim& ddim);
friend DDim stride_numel(const DDim& ddim);
private:
Dim<kMaxRank> dim_;
int rank_;
};
#undef PADDLE_VISIT_DDIM_BASE
#undef PADDLE_VISIT_DDIM
/**
* \brief Make a DDim from std::vector<int64_t>
*
......@@ -92,7 +175,7 @@ DDim make_ddim(const std::vector<int>& dims);
DDim make_ddim(std::initializer_list<int64_t> dims);
int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);
void set(DDim& dim, int idx, int val); // NOLINT
std::vector<int64_t> vectorize(const DDim& ddim);
std::vector<int> vectorize2int(const DDim& ddim);
......@@ -129,12 +212,3 @@ DDim stride(const DDim& ddim);
DDim stride_numel(const DDim& ddim);
} // namespace framework
} // namespace paddle
namespace boost {
template <typename T>
T get(const paddle::framework::DDim& in) {
return boost::get<T>(in.var);
}
} // namespace boost
......@@ -50,7 +50,7 @@ void AllReduceOpHandle::RunImpl() {
// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
// this is a distributed or inter-process call, find a better way.
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (NoDummyInputSize() == 1 &&
local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) {
#else
......
......@@ -25,7 +25,7 @@ struct ExecutionStrategy {
size_t num_threads_{0};
bool use_cuda_{true};
bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100};
size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault};
bool dry_run_{false};
};
......
......@@ -64,20 +64,26 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
++drop_scope_counter_;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
bool stream_end = false;
if (!fetch_tensors.empty()) {
WaitComputationalStreams();
stream_end = true;
}
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
if (!stream_end) {
WaitComputationalStreams();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
drop_scope_counter_ = 0;
}
if (eptr) {
std::rethrow_exception(eptr);
......
......@@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
private:
size_t drop_scope_counter_{0};
......
......@@ -16,332 +16,184 @@
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace framework {
// Statically sized, statically indexed dimension
template <int i>
struct Dim {
static constexpr int dimensions = i;
template <int D>
class Dim : public Array<int64_t, D> {
public:
static_assert(D >= 0, "D must be not less than 0");
template <typename... Args>
HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
static_assert(sizeof...(_tail) == i - 1,
"Dim initialized with the wrong number of parameters");
}
static constexpr int kRank = D;
using BaseClass = Array<int64_t, D>;
HOSTDEVICE
Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
inline Dim(int64_t head, const Dim<D - 1>& tail) {
(*this)[0] = head;
new (this->GetMutable() + 1) Dim<D - 1>(tail);
}
HOSTDEVICE
Dim() : head(0), tail() {}
template <typename... Args>
HOSTDEVICE explicit Dim(int64_t head, Args... args)
: BaseClass(head, args...) {}
/** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */
HOSTDEVICE
Dim(int64_t idx, const Dim<i>& size)
: head(idx % size.head), tail(idx / size.head, size.tail) {}
HOSTDEVICE Dim(int64_t idx, const Dim<D>& size);
/** Construct a Dim with each dimension set to the given index */
HOSTDEVICE
Dim(int64_t idx) : head(idx), tail(idx) {}
HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); }
HOSTDEVICE
bool operator==(const Dim<i>& o) const {
return (head == o.head) && (tail == o.tail);
}
HOSTDEVICE
bool operator!=(const Dim<i>& o) const { return !(*this == o); }
HOSTDEVICE
int64_t& operator[](int idx);
HOSTDEVICE
int64_t operator[](int idx) const;
HOSTDEVICE Dim() = default;
HOST std::string to_string() const;
int64_t head;
Dim<i - 1> tail;
};
// Base case specialization
template <>
struct Dim<0> {
static constexpr int dimensions = 0;
HOSTDEVICE
Dim(int64_t _head) {}
HOSTDEVICE
Dim() {}
HOSTDEVICE
Dim(int idx, const Dim<0>& size) {
#ifndef __CUDA_ARCH__
if (idx > 0) {
throw std::invalid_argument("Index out of range.");
}
#else
PADDLE_ASSERT(idx == 0);
#endif
}
HOSTDEVICE
bool operator==(const Dim<0>& o) const { return true; }
HOSTDEVICE
bool operator!=(const Dim<0>& o) const { return false; }
HOSTDEVICE
int64_t& operator[](int idx);
HOSTDEVICE
int64_t operator[](int idx) const;
};
namespace {
// Helper for accessing Dim classes
template <int i>
struct DimGetter {
// Return a copy if Dim is const
template <typename D>
HOSTDEVICE static int64_t impl(const D& d) {
return DimGetter<i - 1>::impl(d.tail);
}
// Return a reference if Dim is mutable
template <typename D>
HOSTDEVICE static int64_t& impl(D& d) {
return DimGetter<i - 1>::impl(d.tail);
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct FortranOrderIndexingConstructorFunctor {
HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx,
int64_t* out) {
out[kStart] = (*idx) % in[kStart];
(*idx) /= in[kStart];
FortranOrderIndexingConstructorFunctor<kStart + 1, kEnd,
kStart + 1 == kEnd>::Run(in, idx,
out);
}
};
// Eureka! We found the element!
template <>
struct DimGetter<0> {
// Return a copy if Dim is const
template <typename D>
HOSTDEVICE static int64_t impl(const D& d) {
return d.head;
}
// Return a reference if Dim is mutable
template <typename D>
HOSTDEVICE static int64_t& impl(D& d) {
return d.head;
}
template <int kStart, int kEnd>
struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx,
int64_t* out) {}
};
} // namespace detail
template <int D>
HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx < 0) {
throw std::invalid_argument("Tried to access a negative dimension");
}
#else
PADDLE_ASSERT(idx >= 0);
#endif
if (idx == 0) {
return dim.head;
}
return indexer(dim.tail, idx - 1);
}
template <>
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
int64_t head = 0;
#else
static int64_t head = 0;
#endif
return head;
#endif
}
template <int D>
HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx < 0) {
throw std::invalid_argument("Tried to access a negative dimension");
}
#else
PADDLE_ASSERT(idx >= 0);
#endif
if (idx == 0) {
return dim.head;
}
return indexer(dim.tail, idx - 1);
}
template <>
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
int64_t head = 0;
#else
static int64_t head = 0;
#endif
return head;
#endif
}
} // namespace
// Static access to constant Dim
template <int i, int l>
HOSTDEVICE int64_t get(const Dim<l>& d) {
return DimGetter<i>::impl(d);
HOSTDEVICE Dim<D>::Dim(int64_t idx, const Dim<D>& size) {
detail::FortranOrderIndexingConstructorFunctor<0, D, D == 0>::Run(
size.Get(), &idx, this->GetMutable());
}
// Static access to mutable Dim
template <int i, int l>
HOSTDEVICE int64_t& get(Dim<l>& d) {
return DimGetter<i>::impl(d);
template <int idx, int D>
HOSTDEVICE inline int64_t get(const Dim<D>& dim) {
return dim[idx];
}
// Dynamic access to constant Dim
template <int l>
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
return indexer(*this, i);
template <int idx, int D>
HOSTDEVICE inline int64_t& get(Dim<D>& dim) { // NOLINT
return dim[idx];
}
// Dynamic access to mutable Dim
template <int l>
HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
return indexer(*this, i);
}
// Dynamic access to constant Dim
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
return indexer(*this, i);
}
// Dynamic access to mutable Dim
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
return indexer(*this, i);
}
// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
int i) {
return d[i];
template <int D>
HOSTDEVICE inline int64_t get(const Dim<D>& dim, int idx) {
return dim[idx];
}
// Dynamic access to mutable Dim
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
int i) {
return d[i];
template <int D>
HOSTDEVICE inline int64_t& get(Dim<D>& dim, int idx) { // NOLINT
return dim[idx];
}
// Dot product of two dims
template <int i>
HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
return a.head * b.head + linearize(a.tail, b.tail);
}
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
return 0;
template <int D>
HOSTDEVICE inline int64_t linearize(const Dim<D>& a, const Dim<D>& b) {
return UnrollProduct<D>::Run(a.Get(), b.Get());
}
// Product of a Dim
template <int i>
HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
return prod * a.head * product(a.tail);
}
// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
return prod;
template <int D>
HOSTDEVICE inline int64_t product(const Dim<D>& a) {
return UnrollProduct<D>::Run(a.Get());
}
// Is 0 <= idx_i < size_i for all i?
template <int i>
HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
return ((0 <= idx.head) && (idx.head < size.head) &&
contained(idx.tail, size.tail));
}
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct ContainedFunctor {
HOSTDEVICE static inline bool Run(const int64_t* idx, const int64_t* size) {
return (idx[kStart] >= 0 && idx[kStart] < size[kStart]) &&
ContainedFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(idx,
size);
}
};
// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
return true;
template <int kStart, int kEnd>
struct ContainedFunctor<kStart, kEnd, true> {
HOSTDEVICE static constexpr inline bool Run(const int64_t* idx,
const int64_t* size) {
return true;
}
};
} // namespace detail
template <int D>
HOSTDEVICE inline bool contained(const Dim<D>& idx, const Dim<D>& size) {
return detail::ContainedFunctor<0, D, D == 0>::Run(idx.Get(), size.Get());
}
/**
* \brief Compute exclusive prefix-multiply of a Dim.
*/
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
}
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct ExPrefixMulFunctor {
HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) {
kStart == 0 ? out[kStart] = 1 : out[kStart] =
out[kStart - 1] * in[kStart - 1];
detail::ExPrefixMulFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(in,
out);
}
};
template <int kStart, int kEnd>
struct ExPrefixMulFunctor<kStart, kEnd, true> {
HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) {}
};
} // namespace detail
///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
return Dim<0>();
template <int D>
HOSTDEVICE inline Dim<D> ex_prefix_mul(const Dim<D>& src) {
Dim<D> ret;
detail::ExPrefixMulFunctor<0, D, D == 0>::Run(src.Get(), ret.GetMutable());
return ret;
}
///\endcond
/**
* Add two dimensions together
*/
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
}
// Base case
template <>
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
template <int D>
HOSTDEVICE inline Dim<D> dim_plus(const Dim<D>& a, const Dim<D>& b) {
Dim<D> ret;
UnrollAdd<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
}
template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i>& lhs, const Dim<i>& rhs) {
template <int D>
HOSTDEVICE inline Dim<D> operator+(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_plus(lhs, rhs);
}
/**
* Multiply two dimensions together
*/
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
}
// Base case
template <>
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
template <int D>
HOSTDEVICE inline Dim<D> dim_mult(const Dim<D>& a, const Dim<D>& b) {
Dim<D> ret;
UnrollMul<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
}
template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) {
template <int D>
HOSTDEVICE Dim<D> operator*(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_mult(lhs, rhs);
}
......@@ -354,23 +206,32 @@ HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) {
* \return Dim object the same size as \p size with normalized strides
*
*/
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct NormalizeStridesFunctor {
HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride,
int64_t* ret) {
ret[kStart] = (size[kStart] == 1 ? 0 : stride[kStart]);
NormalizeStridesFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(
size, stride, ret);
}
};
template <int i>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
int norm_stride = size.head == 1 ? 0 : stride.head;
return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
}
///\cond HIDDEN
template <int kStart, int kEnd>
struct NormalizeStridesFunctor<kStart, kEnd, true> {
HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride,
int64_t* ret) {}
};
} // namespace detail
template <>
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
const Dim<0>& stride) {
return Dim<0>();
template <int D>
HOSTDEVICE Dim<D> normalize_strides(const Dim<D>& size, const Dim<D>& stride) {
Dim<D> ret;
detail::NormalizeStridesFunctor<0, D, D == 0>::Run(size.Get(), stride.Get(),
ret.GetMutable());
return ret;
}
///\endcond
/**
* Helper function to create a Dim
*
......@@ -379,25 +240,17 @@ HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
*/
template <typename... Args>
HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
return Dim<sizeof...(Args)>(idxes...);
}
// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
typename std::enable_if<(i > 1), std::ostream&>::type operator<<(
std::ostream& os, const Dim<i>& d) {
os << d.head << ", " << d.tail;
return os;
}
// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
std::ostream& os, const Dim<i>& d) {
os << d.head;
template <int D>
inline std::ostream& operator<<(std::ostream& os, const Dim<D>& d) {
os << d[0];
for (int i = 1; i < D; ++i) {
os << ", " << d[i];
}
return os;
}
......@@ -405,17 +258,15 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os;
}
template <int i>
HOST std::string Dim<i>::to_string() const {
template <int D>
HOST std::string Dim<D>::to_string() const {
std::stringstream stream;
stream << *this;
return stream.str();
}
template <int D>
HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, const Dim<D>& extents) {
Dim<D> result;
for (int i = 0; i < D - 1; ++i) {
......@@ -428,5 +279,10 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result;
}
template <int D, typename T1, typename T2>
inline void static_dim_assign(const T1* in, T2* out) {
UnrollAssign<D>::Run(in, out);
}
} // namespace framework
} // namespace paddle
......@@ -59,7 +59,7 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
struct DLContextVisitor : public boost::static_visitor<::DLContext> {
inline ::DLContext operator()(const platform::CPUPlace &place) const {
DLContext ctx;
::DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
return ctx;
......@@ -67,7 +67,7 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
inline ::DLContext operator()(const platform::CUDAPlace &place) const {
#ifdef PADDLE_WITH_CUDA
DLContext ctx;
::DLContext ctx;
ctx.device_type = kDLGPU;
ctx.device_id = place.device;
return ctx;
......@@ -78,7 +78,7 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
inline ::DLContext operator()(const platform::CUDAPinnedPlace &place) const {
#ifdef PADDLE_WITH_CUDA
DLContext ctx;
::DLContext ctx;
ctx.device_type = kDLCPUPinned;
ctx.device_id = 0;
return ctx;
......
......@@ -38,7 +38,7 @@ class DLPackTensor {
// The shape in DLTensor is defined as int64_t*
// Add this member to make TVMTensor init without heap allocation
ShapeType shape_[9];
ShapeType shape_[DDim::kMaxRank];
};
} // namespace framework
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {
......@@ -180,6 +181,7 @@ void ExecutorThreadWorker::SetDevice() {
return;
#else
static unsigned concurrency_cap = std::thread::hardware_concurrency();
LOG(WARNING) << "concurrency capacity " << concurrency_cap;
int thread_id = this->thread_id_;
if (static_cast<unsigned>(thread_id) < concurrency_cap) {
......@@ -238,6 +240,55 @@ static void print_fetch_var(Scope* scope, const std::string& var_name) {
VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
}
void ExecutorThreadWorker::TrainFilesWithTimer() {
platform::SetNumThreads(1);
SetDevice();
thread_reader_->Start();
std::vector<double> op_total_time;
std::vector<std::string> op_name;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (size_t i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
platform::Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
int cur_batch;
int batch_cnt = 0;
timeline.Start();
while ((cur_batch = thread_reader_->Next()) > 0) {
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) {
timeline.Start();
ops_[i]->Run(*thread_scope_, place_);
timeline.Pause();
op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
++batch_cnt;
thread_scope_->DropKids();
if (thread_id_ == 0) {
if (batch_cnt > 0 && batch_cnt % 1000 == 0) {
for (size_t i = 0; i < ops_.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt);
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
int fetch_var_num = fetch_var_names_.size();
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
}
}
}
timeline.Start();
}
}
void ExecutorThreadWorker::TrainFiles() {
platform::SetNumThreads(1);
......@@ -320,10 +371,12 @@ void AsyncExecutorThreadWorker::SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
_pslib_ptr = pslib_ptr;
}
void AsyncExecutorThreadWorker::SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {
_pull_dense_thread = dpt;
}
void AsyncExecutorThreadWorker::TrainOneNetwork() {
PrepareParams();
......
......@@ -155,6 +155,8 @@ class ExecutorThreadWorker {
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
// A multi-thread training function
virtual void TrainFiles();
// with timer log
virtual void TrainFilesWithTimer();
// set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
#ifdef PADDLE_WITH_PSLIB
......
......@@ -399,7 +399,7 @@ void NgraphEngine::BuildNgFunction() {
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::op::ParameterVector func_inputs;
ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo));
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#if !defined(_WIN32)
#include <pthread.h>
#endif // !_WIN32
#else
#include <mutex> // NOLINT
#endif // !_WIN32
#include "paddle/fluid/platform/enforce.h"
......@@ -29,17 +31,17 @@ struct RWLock {
~RWLock() { pthread_rwlock_destroy(&lock_); }
void RDLock() {
inline void RDLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0,
"acquire read lock failed");
}
void WRLock() {
inline void WRLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
"acquire write lock failed");
}
void UNLock() {
inline void UNLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
}
......@@ -51,81 +53,46 @@ struct RWLock {
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// In windows, rw_lock seems like a hack. Use empty object and do nothing.
struct RWLock {
void RDLock() {}
void WRLock() {}
void UNLock() {}
// FIXME(minqiyang): use mutex here to do fake lock
inline void RDLock() { mutex_.lock(); }
inline void WRLock() { mutex_.lock(); }
inline void UNLock() { mutex_.unlock(); }
private:
std::mutex mutex_;
};
#endif
class RWLockGuard {
class AutoWRLock {
public:
enum Status { kUnLock, kWRLock, kRDLock };
RWLockGuard(RWLock* rw_lock, Status init_status)
: lock_(rw_lock), status_(Status::kUnLock) {
switch (init_status) {
case Status::kRDLock: {
RDLock();
break;
}
case Status::kWRLock: {
WRLock();
break;
}
case Status::kUnLock: {
break;
}
}
}
explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
void WRLock() {
switch (status_) {
case Status::kUnLock: {
lock_->WRLock();
status_ = Status::kWRLock;
break;
}
case Status::kWRLock: {
break;
}
case Status::kRDLock: {
PADDLE_THROW(
"Please unlock read lock first before invoking write lock.");
break;
}
}
}
~AutoWRLock() { UnLock(); }
void RDLock() {
switch (status_) {
case Status::kUnLock: {
lock_->RDLock();
status_ = Status::kRDLock;
break;
}
case Status::kRDLock: {
break;
}
case Status::kWRLock: {
PADDLE_THROW(
"Please unlock write lock first before invoking read lock.");
break;
}
}
}
private:
inline void Lock() { lock_->WRLock(); }
void UnLock() {
if (status_ != Status::kUnLock) {
lock_->UNLock();
status_ = Status::kUnLock;
}
}
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
};
class AutoRDLock {
public:
explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
~AutoRDLock() { UnLock(); }
private:
inline void Lock() { lock_->RDLock(); }
~RWLockGuard() { UnLock(); }
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
Status status_;
};
} // namespace framework
......
......@@ -47,9 +47,15 @@ DEFINE_bool(fast_eager_deletion_mode, false,
// the mutex will cause serious performance issue.
// So the mutex is disabled when `ON_INFER`.
#ifdef PADDLE_ON_INFERENCE
#define SCOPE_LOCK_GUARD
#define SCOPE_KIDS_READER_LOCK
#define SCOPE_KIDS_WRITER_LOCK
#define SCOPE_VARS_READER_LOCK
#define SCOPE_VARS_WRITER_LOCK
#else
#define SCOPE_LOCK_GUARD std::lock_guard<std::mutex> lock(mutex_);
#define SCOPE_KIDS_READER_LOCK AutoRDLock auto_lock(&kids_lock_);
#define SCOPE_KIDS_WRITER_LOCK AutoWRLock auto_lock(&kids_lock_);
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_);
#endif
namespace paddle {
......@@ -67,64 +73,69 @@ bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
SCOPE_LOCK_GUARD
kids_.push_back(new Scope(this));
return *kids_.back();
Scope* child = new Scope(this);
{
SCOPE_KIDS_WRITER_LOCK
kids_.push_back(child);
}
return *child;
}
Variable* Scope::Var(const std::string& name) {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
return VarInternal(name);
}
Variable* Scope::Var(std::string* name) {
SCOPE_LOCK_GUARD
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = new_name;
}
SCOPE_VARS_WRITER_LOCK
return VarInternal(new_name);
}
Variable* Scope::FindVar(const std::string& name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindVarInternal(name);
}
Variable* Scope::FindLocalVar(const std::string& name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindVarLocally(name);
}
const Scope* Scope::FindScope(const Variable* var) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_READER_LOCK
return FindScopeInternal(var);
}
void Scope::DropKids() {
SCOPE_LOCK_GUARD
SCOPE_KIDS_WRITER_LOCK
for (Scope* s : kids_) delete s;
kids_.clear();
}
bool Scope::HasKid(const Scope* scope) const {
SCOPE_LOCK_GUARD
SCOPE_KIDS_READER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end();
}
std::vector<std::string> Scope::LocalVarNames() const {
SCOPE_LOCK_GUARD
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
{
SCOPE_VARS_READER_LOCK
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}
}
return known_vars;
}
void Scope::DeleteScope(Scope* scope) const {
SCOPE_LOCK_GUARD
SCOPE_KIDS_WRITER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "%p Cannot find %p as kid scope",
this, scope);
......@@ -138,8 +149,8 @@ void Scope::DeleteScope(Scope* scope) const {
}
void Scope::EraseVars(const std::vector<std::string>& var_names) {
SCOPE_LOCK_GUARD
std::set<std::string> var_set(var_names.begin(), var_names.end());
SCOPE_VARS_WRITER_LOCK
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
it = vars_.erase(it);
......@@ -151,12 +162,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_LOCK_GUARD
SCOPE_VARS_WRITER_LOCK
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
......
......@@ -14,12 +14,18 @@ limitations under the License. */
#pragma once
extern "C" {
#include <xxhash.h>
}
#include <list>
#include <mutex> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h"
......@@ -95,7 +101,14 @@ class Scope {
std::string Rename(const std::string& origin_name) const;
protected:
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
struct KeyHasher {
std::size_t operator()(const std::string& key) const {
return XXH32(key.c_str(), key.size(), 1);
}
};
mutable std::unordered_map<std::string, std::unique_ptr<Variable>, KeyHasher>
vars_;
private:
// Call Scope::NewScope for a sub-scope.
......@@ -124,7 +137,8 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
mutable RWLock kids_lock_;
mutable RWLock vars_lock_;
};
// Generate some debug string about the inherience structure of scope, quite
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <type_traits>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace framework {
namespace detail {
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollFillConstant {
template <typename T>
HOSTDEVICE inline static void Run(T *data, T val) {
data[kStart] = val;
UnrollFillConstant<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(data, val);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollFillConstant<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(T *data, T val) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollAssign {
template <typename Tin, typename Tout>
HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) {
d2[kStart] = static_cast<Tout>(d1[kStart]);
UnrollAssign<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollAssign<kStart, kEnd, true> {
template <typename Tin, typename Tout>
HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) {}
};
template <typename T, size_t kStart, size_t kEnd, bool kStop>
struct UnrollVarArgsAssignImpl {
template <typename... Args>
HOSTDEVICE inline static void Run(T *d, T val, Args... args) {
static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument");
d[kStart] = val;
UnrollVarArgsAssignImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd>::Run(
d, args...);
}
};
template <typename T, size_t kStart, size_t kEnd>
struct UnrollVarArgsAssignImpl<T, kStart, kEnd, true> {
HOSTDEVICE inline static void Run(T *d) {}
};
template <typename T>
struct UnrollVarArgsAssign {
template <typename... Args>
HOSTDEVICE inline static void Run(T *d, Args... args) {
UnrollVarArgsAssignImpl<T, 0, sizeof...(Args), sizeof...(Args) == 0>::Run(
d, args...);
}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollCompare {
template <typename T>
HOSTDEVICE inline static bool Run(const T *d1, const T *d2) {
return d1[kStart] == d2[kStart] &&
UnrollCompare<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollCompare<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline constexpr static bool Run(const T *d1, const T *d2) {
return true;
}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollAdd {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {
d3[kStart] = d1[kStart] + d2[kStart];
UnrollAdd<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2, d3);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollAdd<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollMul {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {
d3[kStart] = d1[kStart] * d2[kStart];
UnrollMul<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2, d3);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollMul<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollProduct {
template <typename T>
HOSTDEVICE inline static T Run(const T *d) {
return d[kStart] *
UnrollProduct<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d);
}
template <typename T>
HOSTDEVICE inline static T Run(const T *d1, const T *d2) {
return d1[kStart] * d2[kStart] +
UnrollProduct<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollProduct<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline constexpr static T Run(const T *d) {
return 1;
}
template <typename T>
HOSTDEVICE inline constexpr static T Run(const T *d1, const T *d2) {
return 0;
}
};
} // namespace detail
template <size_t N>
using UnrollFillConstant = detail::UnrollFillConstant<0, N, N == 0>;
template <size_t N>
using UnrollAssign = detail::UnrollAssign<0, N, N == 0>;
template <typename T>
using UnrollVarArgsAssign = detail::UnrollVarArgsAssign<T>;
template <size_t N>
using UnrollCompare = detail::UnrollCompare<0, N, N == 0>;
template <size_t N>
using UnrollAdd = detail::UnrollAdd<0, N, N == 0>;
template <size_t N>
using UnrollMul = detail::UnrollMul<0, N, N == 0>;
template <size_t N>
using UnrollProduct = detail::UnrollProduct<0, N, N == 0>;
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/unroll_array_ops.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <array>
#include <cstdint>
namespace paddle {
namespace framework {
template <typename T>
bool CheckEquality(const T* p, size_t n, T val) {
return std::all_of(p, p + n, [val](const T& v) { return v == val; });
}
template <int D1, int D2>
bool FillConstantTestMain() {
static_assert(D1 >= D2, "");
std::array<int, D1> arr;
arr.fill(0);
UnrollFillConstant<D2>::Run(arr.data(), 1);
return CheckEquality(arr.data(), D2, 1) &&
CheckEquality(arr.data() + D2, arr.size() - D2, 0);
}
TEST(unroll_ops, fill_constant) {
EXPECT_TRUE((FillConstantTestMain<9, 0>()));
EXPECT_TRUE((FillConstantTestMain<9, 1>()));
EXPECT_TRUE((FillConstantTestMain<9, 4>()));
EXPECT_TRUE((FillConstantTestMain<9, 9>()));
}
TEST(unroll_ops, assign) {
const int a[] = {1, 2, 3, 4, 5};
int b[] = {0, 0, 0, 0, 0};
UnrollAssign<3>::Run(a, b);
EXPECT_EQ(b[0], 1);
EXPECT_EQ(b[1], 2);
EXPECT_EQ(b[2], 3);
EXPECT_EQ(b[3], 0);
EXPECT_EQ(b[4], 0);
}
TEST(unroll_ops, var_args_assign) {
int a[] = {0, 0, 0};
UnrollVarArgsAssign<int>::Run(a, 1, 2);
EXPECT_EQ(a[0], 1);
EXPECT_EQ(a[1], 2);
EXPECT_EQ(a[2], 0);
}
TEST(unroll_ops, compare) {
int a[] = {1, 2, 3};
int b[] = {1, 2, 4};
EXPECT_TRUE(UnrollCompare<2>::Run(a, b));
EXPECT_FALSE(UnrollCompare<3>::Run(a, b));
b[0] = -1;
EXPECT_TRUE(UnrollCompare<0>::Run(a, b));
EXPECT_FALSE(UnrollCompare<1>::Run(a, b));
}
TEST(unroll_ops, add) {
int a[] = {2, 3, 4};
int b[] = {5, 10, 102};
int c[] = {0, 0, 0};
UnrollAdd<2>::Run(a, b, c);
EXPECT_EQ(a[0] + b[0], c[0]);
EXPECT_EQ(a[1] + b[1], c[1]);
EXPECT_EQ(c[2], 0);
}
TEST(unroll_ops, mul) {
int a[] = {2, 3, 4};
int b[] = {5, 10, 102};
int c[] = {0, 0, 0};
UnrollMul<2>::Run(a, b, c);
EXPECT_EQ(a[0] * b[0], c[0]);
EXPECT_EQ(a[1] * b[1], c[1]);
EXPECT_EQ(c[2], 0);
}
TEST(unroll_ops, product) {
int a[] = {2, 3, 4};
int b[] = {5, 10, 102};
EXPECT_EQ(UnrollProduct<3>::Run(a), a[0] * a[1] * a[2]);
EXPECT_EQ(UnrollProduct<3>::Run(a, b),
a[0] * b[0] + a[1] * b[1] + a[2] * b[2]);
}
} // namespace framework
} // namespace paddle
......@@ -86,8 +86,6 @@ class UnaryLogicalOpInferShape : public framework::InferShapeBase {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
......
......@@ -19,6 +19,10 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
DECLARE_int64(cudnn_exhaustive_search_times);
namespace paddle {
namespace operators {
......@@ -45,6 +49,7 @@ static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
template <typename TAlgorithm>
class AlgorithmsCache {
public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
......@@ -54,9 +59,14 @@ class AlgorithmsCache {
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func);
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
std::mutex mutex_;
int search_times_;
};
template <typename TAlgorithm>
......@@ -107,5 +117,29 @@ TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
return hash_[seed];
}
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
if (hash_.find(area) != hash_.end()) {
return hash_[area];
}
if (search_times_ < search_times) {
auto algo = gen_func();
hash_[area] = algo;
++search_times_;
return algo;
}
TAlgorithm algo;
int64_t min = static_cast<uint64_t>(INT_MAX);
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
}
return algo;
}
} // namespace operators
} // namespace paddle
......@@ -28,6 +28,8 @@ namespace operators {
// x is Input,
// z is ResidualData,
// bias is Bias
// When `split_channels` is set, y will be splitted into multiple outputs,
// each output has split_channels[i] number of channels.
class Conv2DFusionOpMaker : public Conv2DOpMaker {
protected:
void Apply() override {
......@@ -36,8 +38,65 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
"The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' "
"'relux' , 'tanh', 'band_pass'")
.SetDefault("relu");
AddAttr<std::vector<int>>(
"split_channels",
"When `split_channels` are set, there will be multiple outputs, the "
"output size is equal to the number of `split_channels`.")
.SetDefault({});
AddOutput("Outputs",
"This Outputs is used when setting `split_channels`."
"Usually used to fuse conv with same input and same filter size, "
"padding, stride, dilation size.")
.AsDuplicable()
.AsDispensable();
AddInput("AlgoCache",
"The cache of convolution algorithm, a RAW type variable.")
.AsDispensable();
AddAttr<int>(
"search_times",
"The number of exhaustive search times for convolution algorithm.")
.SetDefault(-1);
}
};
class Conv2DFusionOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of ConvOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
std::vector<int64_t> oshape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
oshape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], strides[i]));
}
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of ConvOp should not be null.");
ctx->SetOutputDim("Output", framework::make_ddim(oshape));
std::vector<int> channels =
ctx->Attrs().Get<std::vector<int>>("split_channels");
if (channels.size()) {
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
"Output(Outputs) of ConvOp should not be null.");
std::vector<framework::DDim> oshapes;
oshapes.reserve(channels.size());
for (size_t i = 0; i < channels.size(); ++i) {
oshapes.push_back({oshape[0], channels[i], oshape[2], oshape[3]});
}
ctx->SetOutputsDim("Outputs", oshapes);
}
}
};
// TODO(qingqing): add gradient operator for conv2d_fusion
} // namespace operators
......@@ -45,4 +104,5 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker);
ops::Conv2DFusionOpInferShape, ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker);
......@@ -16,8 +16,9 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
DEFINE_int64(cudnn_exhaustive_search_times, -1,
"Exhaustive search times for cuDNN convolution, "
"defalut is 1, only search once.");
namespace paddle {
namespace operators {
......@@ -117,41 +118,60 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else {
auto search_func = [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " "
<< stat.memory;
}
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
int search_times = ctx.Attr<int>("search_times");
search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
if (search_times > 0) {
// The searched algo will be cached by `search_times` times for
// different input dimension. For other dimensions, select the algo
// of closest area.
auto var_name = ctx.Inputs("AlgoCache")[0];
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
.FindVar(var_name)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func);
} else {
algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
// Cache searched algo in Var(kCUDNNFwdAlgoCache).
// all conv ops use the same kCUDNNFwdAlgoCache variable.
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
// TODO(qingqing) remove const_cast
algo_cache =
const_cast<framework::Scope*>(ctx.scope().parent())
->Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, search_func);
}
algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return fwd_perf_stat[0].algo;
});
VLOG(3) << "choose algo " << algo;
}
......@@ -195,6 +215,27 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) {
auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
if (x_dims[0] == 1) {
// share data with Output
framework::Tensor t;
t.ShareDataWith(*output);
auto y_dims = output->dims();
t.Resize({y_dims[1], y_dims[2], y_dims[3]});
int s = 0;
for (size_t i = 0; i < channels.size(); ++i) {
int e = s + channels[i];
outs[i]->ShareDataWith(t.Slice(s, e));
outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]});
s = e;
}
} else {
// TODO(qingiqng): do copy when batch size large than 1
PADDLE_THROW("Batch size greater than 1 is Unsupported");
}
}
}
};
#endif
......
......@@ -68,7 +68,6 @@ void CropFunction(const framework::ExecutionContext& context) {
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims());
auto offsets = GetOffsets(context);
int64_t offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) {
......
......@@ -147,7 +147,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
->GetMutable<CudnnRNNCache>();
auto input_dims = input->dims();
auto weight_dims = weight->dims();
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();
in_grad->mutable_data<T>(ctx.GetPlace());
......
......@@ -27,8 +27,8 @@ struct StridedMemcpyFunctor;
template <typename T>
struct StridedMemcpyFunctor<T, 0> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
framework::Dim<0> dst_stride, T* dst) const {
const int64_t* src_stride, const int64_t* dst_dim,
const int64_t* dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
......@@ -50,18 +50,18 @@ struct StridedMemcpyFunctor<T, 0> {
template <typename T>
struct StridedMemcpyFunctor<T, 1> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<1> src_stride, framework::Dim<1> dst_dim,
framework::Dim<1> dst_stride, T* dst) const {
const int64_t* src_stride, const int64_t* dst_dim,
const int64_t* dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim[0]);
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim[0],
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
......@@ -73,19 +73,19 @@ struct StridedMemcpyFunctor<T, 1> {
template <typename T, int Rank>
struct StridedMemcpyFunctor {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<Rank> src_stride, framework::Dim<Rank> dst_dim,
framework::Dim<Rank> dst_stride, T* dst) const {
for (int64_t i = 0; i < dst_dim.head; ++i) {
const int64_t* src_stride, const int64_t* dst_dim,
const int64_t* dst_stride, T* dst) const {
for (int64_t i = 0; i < dst_dim[0]; ++i) {
StridedMemcpyFunctor<T, Rank - 1> func;
func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst);
src += src_stride.head;
dst += dst_stride.head;
func(dev_ctx, src, src_stride + 1, dst_dim + 1, dst_stride + 1, dst);
src += src_stride[0];
dst += dst_stride[0];
}
}
};
template <typename T>
struct StridedCopyDimVisitor : public boost::static_visitor<void> {
struct StridedCopyDimVisitor {
StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& src_stride,
const framework::DDim& dst_stride, T* dst)
......@@ -95,13 +95,11 @@ struct StridedCopyDimVisitor : public boost::static_visitor<void> {
dst_stride_(dst_stride),
dst_(dst) {}
template <typename Dim>
void operator()(Dim dst_dim) const {
Dim src_stride = boost::get<Dim>(src_stride_);
Dim dst_stride = boost::get<Dim>(dst_stride_);
constexpr int dim = Dim::dimensions;
StridedMemcpyFunctor<T, dim> functor;
functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_);
template <int D>
void operator()(const framework::Dim<D>& dst_dim) const {
StridedMemcpyFunctor<T, D> functor;
functor(dev_ctx_, src_, src_stride_.Get(), dst_dim.Get(), dst_stride_.Get(),
dst_);
}
const platform::DeviceContext& dev_ctx_;
......
......@@ -64,8 +64,6 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
"Output(BboxOutsideWeights) of RpnTargetAssignOp should not be null");
auto rpn_rois_dims = ctx->GetInputDim("RpnRois");
auto gt_classes_dims = ctx->GetInputDim("GtClasses");
auto is_crowd_dims = ctx->GetInputDim("IsCrowd");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
auto im_info_dims = ctx->GetInputDim("ImInfo");
......
......@@ -53,12 +53,6 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Variances"),
"Input(Variances) shouldn't be null.");
auto scores_dims = ctx->GetInputDim("Scores");
auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas");
auto im_info_dims = ctx->GetInputDim("ImInfo");
auto anchors_dims = ctx->GetInputDim("Anchors");
auto variances_dims = ctx->GetInputDim("Variances");
ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
}
......
......@@ -58,7 +58,6 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
auto is_crowd_dims = ctx->GetInputDim("IsCrowd");
auto im_info_dims = ctx->GetInputDim("ImInfo");
PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
"The rank of Input(Anchor) must be 2.");
......
......@@ -7,56 +7,52 @@ if(WITH_GRPC)
else()
set(cc_generic_services "true")
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
configure_file(send_recv.proto.in ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto @ONLY)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
set(GRPC_SRCS grpc/grpc_client.cc grpc/grpc_server.cc grpc/grpc_serde.cc grpc/grpc_bytebuffer_stream.cc grpc/grpc_variable_response.cc)
grpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto
DEPS lod_tensor selected_rows_functor memory)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL)
endif()
cc_test(grpc_serde_test SRCS grpc/grpc_serde_test.cc
DEPS ${RPC_DEPS} scope profiler math_function SERIAL)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
else()
set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc/server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
brpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${BRPC_SRCS}
PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto
DEPS lod_tensor selected_rows memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
proto_desc lookup_sparse_table_op snappystream snappy zlib)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)
set(RPC_DEPS sendrecvop_rpc brpc ssl crypto protobuf leveldb snappystream snappy zlib)
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op SERIAL)
endif()
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL)
endif()
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -31,10 +31,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
......
......@@ -14,7 +14,7 @@
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
......
......@@ -20,10 +20,10 @@ limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
......
......@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
......
......@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv {
......
......@@ -19,8 +19,8 @@ limitations under the License. */
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
......
......@@ -23,7 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
......
......@@ -22,7 +22,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline);
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
......
......@@ -21,9 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -52,12 +52,12 @@ std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
framework::Scope* scope = new framework::Scope();
framework::Variable* var = scope->Var("var1");
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(1000);
slr->set_height(20000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({3, 5}));
tensor->Resize(framework::make_ddim({20000, 1024}));
tensor->mutable_data<float>(place);
paddle::operators::math::set_constant(ctx, tensor, 32.7);
......@@ -83,6 +83,7 @@ void Gather(const std::vector<distributed::RemoteVar>& vars,
}
TEST(PREFETCH, GPU) {
setenv("FLAGS_max_body_size", "2147483647", 1);
platform::CUDAPlace place;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......
......@@ -18,15 +18,15 @@
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
#define RPCSERVER_T paddle::operators::distributed::AsyncGRPCServer
#define RPCCLIENT_T paddle::operators::distributed::GRPCClient
#else // PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#define RPCSERVER_T paddle::operators::distributed::AsyncBRPCServer
#define RPCCLIENT_T paddle::operators::distributed::BRPCClient
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#else // PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#endif // PADDLE_WITH_GRPC
#endif // PADDLE_WITH_DISTRIBUTE
......@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
namespace paddle {
namespace operators {
......
......@@ -17,8 +17,8 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"
......
......@@ -39,10 +39,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
......
......@@ -21,9 +21,9 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/port.h"
......
......@@ -27,8 +27,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
namespace paddle {
namespace operators {
......
......@@ -21,9 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#include <limits>
#include <string>
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter;
......
......@@ -29,11 +29,10 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/grpc_service.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_service.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
......
......@@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow
......
......@@ -19,7 +19,7 @@
#include <nccl.h>
#endif
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -22,13 +22,11 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
......
......@@ -23,7 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License.
......@@ -18,13 +17,8 @@ package sendrecv;
option cc_generic_services = @cc_generic_services@;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
......@@ -33,19 +27,12 @@ service SendRecvService {
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.
// It can be:
// LoDTensor
// SelectedRows
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
......@@ -62,21 +49,14 @@ message VariableMessage {
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h"
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
......
......@@ -25,7 +25,7 @@
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/distributed_pb.h"
DECLARE_string(rpc_server_profile_path);
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace paddle {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/profiler.h"
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
......
......@@ -178,7 +178,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
......
......@@ -77,7 +77,6 @@ class ExpandKernel : public framework::OpKernel<T> {
auto& expand_times = context.Attr<std::vector<int>>("expand_times");
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
auto x_dims = in0->dims();
for (size_t i = 0; i < expand_times.size(); ++i) {
bcast_dims[i] = expand_times[i];
}
......
......@@ -146,7 +146,6 @@ class FCOpKernel : public framework::OpKernel<T> {
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<Tensor>("Out");
auto in_dims = input->dims();
auto w_dims = w->dims();
auto out_dims = output->dims();
int M = framework::product(out_dims) / out_dims[out_dims.size() - 1];
......
include(operators)
register_operators(EXCLUDES fusion_transpose_flatten_concat_op)
register_operators(EXCLUDES fusion_transpose_flatten_concat_op fusion_conv_inception_op)
if (WITH_GPU)
op_library(fusion_transpose_flatten_concat_op)
op_library(fusion_conv_inception_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n")
endif()
......@@ -241,15 +241,15 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes");
#define INIT_BASE_SIZES \
auto ids_dims = ids->dims(); /* T x M*/ \
auto ids_numel = ids->numel(); /* T x 1*/ \
auto wh_dims = wh->dims(); /* D x 4D*/ \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D3 = D * 3; \
int64_t row_number = embeddings->dims()[0]; \
int64_t row_width = embeddings->dims()[1]; \
#define INIT_BASE_SIZES \
auto ids_dims = ids->dims(); /* T x M*/ \
auto ids_numel = framework::product(ids_dims); /* T x 1*/ \
auto wh_dims = wh->dims(); /* D x 4D*/ \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D3 = D * 3; \
int64_t row_number = embeddings->dims()[0]; \
int64_t row_width = embeddings->dims()[1]; \
const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
class ConvInceptionFusionOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
// 1 x
auto in_dims = ctx->GetInputDim("Input");
// 4 filters
auto w_dims = ctx->GetInputsDim("Filter");
PADDLE_ENFORCE(in_dims.size(), 4, "Conv intput should be 4-D tensor.");
PADDLE_ENFORCE_EQ(w_dims.size(), 4, "There should be 4 filters");
PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1]);
PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1]);
int n = in_dims[0];
// compute output channel
// 1st channel
int c = w_dims[0][0];
// add 2nd channel
c += (w_dims[1][0] - w_dims[2][1] * 2);
// add 3rd channel
c += (w_dims[2][0] - w_dims[3][1]);
// add 4-th channel
c += w_dims[3][0];
int h = in_dims[2];
int w = in_dims[3];
ctx->SetOutputDim("Output", {n, c, h, w});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
}
};
class ConvInceptionFusionOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
void Make() override {
AddInput("Input", "(Tensor) NCHW layout.");
AddInput("Filter", "(vector<Tensor>) 4 aggregated filters").AsDuplicable();
AddInput("Bias", "(vector<Tensor>) it's lenght is equal to Filter")
.AsDuplicable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
AddOutput("TempOutput", "").AsDuplicable();
AddAttr<std::string>(
"pooling_type",
"(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.")
.InEnum({"max", "avg"});
AddAttr<bool>(
"exclusive",
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True.")
.SetDefault(true);
AddAttr<std::string>(
"activation",
"The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' "
"'relux' , 'tanh', 'band_pass'")
.SetDefault("relu");
AddAttr<int>("workspace_size_MB",
"Only used in cudnn kernel. Need set use_cudnn to true."
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddComment(R"DOC(
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d_inception_fusion, ops::ConvInceptionFusionOp,
ops::ConvInceptionFusionOpMaker,
paddle::framework::EmptyGradOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 7001
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto* input = ctx.Input<Tensor>("Input");
auto filters = ctx.MultiInput<framework::Tensor>("Filter");
auto bias = ctx.MultiInput<framework::Tensor>("Bias");
auto* output = ctx.Output<Tensor>("Output");
auto temp_outs = ctx.MultiOutput<framework::Tensor>("TempOutput");
const std::string pool_type = ctx.Attr<std::string>("pooling_type");
const std::string activation = ctx.Attr<std::string>("activation");
const bool exclusive = ctx.Attr<bool>("exclusive");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
T* temp_data = temp_outs[0]->mutable_data<T>(input->dims(), ctx.GetPlace());
DataLayout layout = DataLayout::kNCHW;
std::vector<int> in_dim = framework::vectorize2int(input->dims());
// ------------------- cudnn descriptors ---------------------
PoolingMode pooling_mode;
if (pool_type == "max") {
pooling_mode = PoolingMode::kMaximum;
} else {
pooling_mode = exclusive ? PoolingMode::kAverageExclusive
: (PoolingMode::kAverageInclusive);
}
std::vector<int> k0x0 = {0, 0};
std::vector<int> k1x1 = {1, 1};
std::vector<int> k1x1_2 = {1, 1};
std::vector<int> k3x3 = {3, 3};
ScopedPoolingDescriptor pool_desc;
ScopedActivationDescriptor act_desc;
ScopedTensorDescriptor out_pool_desc;
ScopedTensorDescriptor input_desc;
cudnnPoolingDescriptor_t cudnn_pool_desc =
pool_desc.descriptor(pooling_mode, k3x3, k1x1, k1x1);
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t pool_out_desc = out_pool_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
cudnnTensorDescriptor_t* out_desc = new cudnnTensorDescriptor_t[4];
cudnnFilterDescriptor_t* filter_desc = new cudnnFilterDescriptor_t[4];
cudnnTensorDescriptor_t* bias_desc = new cudnnTensorDescriptor_t[4];
cudnnTensorDescriptor_t* in_desc = new cudnnTensorDescriptor_t[4];
cudnnConvolutionDescriptor_t* conv_desc =
new cudnnConvolutionDescriptor_t[4];
for (int i = 0; i < 4; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnCreateFilterDescriptor(&filter_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&bias_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&in_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&out_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateConvolutionDescriptor(&conv_desc[i]));
}
std::vector<std::vector<int>> filter_dims;
std::vector<std::vector<int>> bias_dims;
std::vector<std::vector<int>> in_dims;
std::vector<std::vector<int>> out_dims;
std::vector<std::vector<int>> in_strides;
std::vector<std::vector<int>> out_strides;
std::vector<std::vector<int>> bias_strides;
cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
int n = in_dim[0];
int h = in_dim[2];
int w = in_dim[3];
int oc = output->dims()[1];
cudnnDataType_t compute_type = (cudnn_dtype == CUDNN_DATA_DOUBLE)
? CUDNN_DATA_DOUBLE
: CUDNN_DATA_FLOAT;
for (int i = 0; i < 4; ++i) {
filter_dims.push_back(framework::vectorize2int(filters[i]->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
filter_desc[i], cudnn_dtype, format, 4, filter_dims[i].data()));
bias_dims.push_back({1, filter_dims[i][0], 1, 1});
bias_strides.push_back({filter_dims[i][0], 1, 1, 1});
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
bias_desc[i], cudnn_dtype, 4, bias_dims[i].data(),
bias_strides[i].data()));
in_dims.push_back({n, filter_dims[i][1], h, w});
out_dims.push_back({n, filter_dims[i][0], h, w});
in_strides.push_back({filter_dims[i][1] * h * w, h * w, w, 1});
out_strides.push_back({oc * h * w, h * w, w, 1});
if (i < 2) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionNdDescriptor(
conv_desc[i], 2, k0x0.data(), k1x1.data(), k1x1.data(),
CUDNN_CROSS_CORRELATION, compute_type));
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionNdDescriptor(
conv_desc[i], 2, k1x1.data(), k1x1.data(), k1x1.data(),
CUDNN_CROSS_CORRELATION, compute_type));
}
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
conv_desc[i], CUDNN_DEFAULT_MATH));
}
in_dims[2][1] *= 2;
in_strides[2][0] = oc * h * w;
out_strides[2][0] = filter_dims[2][0] * h * w; // this out is continuous.
in_strides[3][0] = filter_dims[2][0] * h * w;
CUDNN_ENFORCE(
platform::dynload::cudnnSetConvolutionGroupCount(conv_desc[2], 2));
cudnnConvolutionFwdAlgo_t algo[4];
auto handle = dev_ctx.cudnn_handle();
size_t workspace_size_in_bytes = 0; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
for (int i = 0; i < 4; ++i) {
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
in_desc[i], cudnn_dtype, 4, in_dims[i].data(), in_strides[i].data()));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
out_desc[i], cudnn_dtype, 4, out_dims[i].data(),
out_strides[i].data()));
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, in_desc[i], filter_desc[i], conv_desc[i], out_desc[i],
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
&algo[i]));
size_t tmp_size = 0;
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, in_desc[i], filter_desc[i], conv_desc[i], out_desc[i],
algo[i], &tmp_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
cudnnActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);
int oc0 = filter_dims[0][0];
int oc1 = filter_dims[1][0] - filter_dims[2][1] * 2;
int oc3 = filter_dims[3][0];
int oc2 = oc - oc0 - oc1 - oc3;
// branch1: pool + 1x1 conv
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
CUDNN_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
pool_out_desc, temp_data));
std::vector<const void*> in_datas;
in_datas.push_back(static_cast<const void*>(temp_data));
in_datas.push_back(static_cast<const void*>(input_data));
in_datas.push_back(
static_cast<const void*>(output_data + (oc0 + oc1) * h * w));
T* temp2_data = temp_outs[1]->mutable_data<T>(
framework::make_ddim(out_dims[2]), ctx.GetPlace());
in_datas.push_back(static_cast<const void*>(temp2_data + oc2 * h * w));
std::vector<void*> out_datas;
out_datas.push_back(static_cast<void*>(output_data));
out_datas.push_back(static_cast<void*>(output_data + oc0 * h * w));
out_datas.push_back(static_cast<void*>(temp2_data));
out_datas.push_back(
static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w));
for (int i = 0; i < 4; ++i) {
auto func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha, in_desc[i], in_datas[i], filter_desc[i],
static_cast<const void*>(filters[i]->data<T>()), conv_desc[i],
algo[i], cudnn_workspace, workspace_size_in_bytes, &beta,
out_desc[i], out_datas[i], bias_desc[i],
static_cast<const void*>(bias[i]->data<T>()), cudnn_act_desc,
out_desc[i], out_datas[i]));
};
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
workspace_handle.RunFunc(func, workspace_size_in_bytes);
}
cudnnTensorDescriptor_t x_desc;
cudnnTensorDescriptor_t y_desc;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&x_desc));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&y_desc));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
x_desc, cudnn_dtype, 4, out_dims[3].data(), out_strides[2].data()));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc, cudnn_dtype, 4, out_dims[3].data(), out_strides[3].data()));
CUDNN_ENFORCE(platform::dynload::cudnnTransformTensor(
handle, CudnnDataType<T>::kOne(), x_desc,
static_cast<const void*>(out_datas[2]), CudnnDataType<T>::kZero(),
y_desc, static_cast<void*>(output_data + (oc0 + oc1) * h * w)));
for (int i = 0; i < 4; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(in_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(out_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyFilterDescriptor(filter_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(bias_desc[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyConvolutionDescriptor(conv_desc[i]));
}
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(x_desc));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(y_desc));
}
};
#endif
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 7001
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_inception_fusion,
ops::CUDNNConvInceptionFusionOpKernel<float>,
ops::CUDNNConvInceptionFusionOpKernel<double>);
#endif
......@@ -88,7 +88,6 @@ class HingeLossGradOp : public framework::OperatorWithKernel {
"Input(Logits@GRAD) should not be null.");
auto pred_dims = ctx->GetInputDim("Logits");
auto lab_dims = ctx->GetInputDim("Labels");
auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
......
......@@ -92,7 +92,6 @@ class LogLossGradOp : public framework::OperatorWithKernel {
"Output(Predicted@GRAD) should not be null.");
auto pred_dims = ctx->GetInputDim("Predicted");
auto label_dims = ctx->GetInputDim("Labels");
auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
......
......@@ -37,9 +37,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out->dims();
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
auto* dev = context.eigen_device();
......
......@@ -76,7 +76,6 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
auto in_dims = X->dims();
auto out_dims = Y->dims();
const float* in_data = X->data<float>();
float* out_data = Y->data<float>();
const int kBatchDim = 0;
......
......@@ -87,7 +87,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
"Input(Out@Grad) must not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
......
......@@ -147,12 +147,6 @@ class MulGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_mat_dims = framework::flatten_to_2d(
x_dims, ctx->Attrs().Get<int>("x_num_col_dims"));
auto y_mat_dims = framework::flatten_to_2d(
y_dims, ctx->Attrs().Get<int>("y_num_col_dims"));
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
......
......@@ -36,7 +36,6 @@ class NCEOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label");
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
if (ctx->HasInput("Bias")) {
......
......@@ -43,7 +43,6 @@ class NormKernel : public framework::OpKernel<T> {
out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
......
......@@ -41,7 +41,6 @@ class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
int rois_num = rois->dims()[0];
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
const T* input_data = in->data<T>();
......
......@@ -143,8 +143,6 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
set_zero(ctx.template device_context<DeviceContext>(), x_grad,
static_cast<T>(0));
auto out_grad_stride = framework::stride(out_grad->dims());
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
......
......@@ -40,7 +40,7 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& dst_stride, T* dst) {
paddle::operators::detail::StridedCopyDimVisitor<T> func(
dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim);
dst_dim.apply_visitor(func);
}
// Strided numel memory copy from src to dst by the specified axis
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册