From 1d91a49d2f8c304115ba12fef6944c72cf5a5352 Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 26 Sep 2018 12:59:03 +0800 Subject: [PATCH] Some trivial optimization (#13530) * some trivial opt * remove the fix of lod_tensor and shrink_rnn_memory_op * refine ShrinkRNNMemoryOp test=develop --- paddle/fluid/framework/op_info.h | 17 +++++--- paddle/fluid/operators/read_op.cc | 2 + paddle/fluid/operators/sgd_op.cu | 41 ++++++++++--------- .../fluid/operators/shrink_rnn_memory_op.cc | 29 +++++++++---- paddle/fluid/platform/device_context.cc | 5 +++ paddle/fluid/platform/device_context.h | 5 +++ paddle/fluid/platform/for_range.h | 39 +++++++++++++----- paddle/fluid/platform/gpu_info.cc | 17 ++++++++ paddle/fluid/platform/gpu_info.h | 3 ++ python/paddle/fluid/layers/io.py | 2 + 10 files changed, 116 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index 19e5c2c73..06cf4a0f9 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -38,27 +38,31 @@ struct OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; + std::string op_type_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; } const proto::OpProto& Proto() const { - PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered"); + PADDLE_ENFORCE_NOT_NULL(proto_, "Operator %s Proto has not been registered", + op_type_); PADDLE_ENFORCE(proto_->IsInitialized(), - "Operator Proto must be initialized in op info"); + "Operator %s Proto must be initialized in op info", + op_type_); return *proto_; } const OpCreator& Creator() const { - PADDLE_ENFORCE_NOT_NULL(creator_, - "Operator Creator has not been registered"); + PADDLE_ENFORCE_NOT_NULL( + creator_, "Operator %s Creator has not been registered", op_type_); return creator_; } const GradOpMakerFN& GradOpMaker() const { PADDLE_ENFORCE_NOT_NULL(grad_op_maker_, - "Operator GradOpMaker has not been registered."); + "Operator %s GradOpMaker has not been registered.", + op_type_); return grad_op_maker_; } @@ -73,8 +77,9 @@ class OpInfoMap { return map_.find(op_type) != map_.end(); } - void Insert(const std::string& type, const OpInfo& info) { + void Insert(const std::string& type, OpInfo info) { PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type); + info.op_type_ = type; map_.insert({type, info}); } diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index a0d640b20..326c58ee1 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -45,10 +45,12 @@ class ReadInferVarType : public framework::VarTypeInference { framework::VarDesc* reader = block->FindVarRecursive(reader_name); auto dtypes = reader->GetDataTypes(); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); + auto lod_levels = reader->GetLoDLevels(); for (size_t i = 0; i < dtypes.size(); ++i) { framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); out.SetType(framework::proto::VarType::LOD_TENSOR); out.SetDataType(dtypes[i]); + out.SetLoDLevel(lod_levels[i]); } } }; diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu index 4722be7a6..9527e7ba3 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/sgd_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU +#include #include "paddle/fluid/operators/sgd_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate, } } -template +template __global__ void SparseSGDFunctorKernel(const T* selected_rows, const int64_t* rows, const T* learning_rate, T* tensor_out, - int64_t row_numel) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - - selected_rows += ty * row_numel; - tensor_out += rows[ty] * row_numel; - - for (int index = tid; index < row_numel; index += block_size) { - // Since index in rows of SelectedRows can be duplicate, we have to use - // Atomic Operation to avoid concurrent write error. - paddle::platform::CudaAtomicAdd( - tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]); + int64_t row_numel, int64_t limit) { + for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) { + const T* selected_rows_ptr = selected_rows + i * row_numel; + T* tensor_out_ptr = tensor_out + rows[i] * row_numel; + for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd( + tensor_out_ptr + index, + -1.0 * learning_rate[0] * selected_rows_ptr[index]); + } } } } // namespace @@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel { auto* in_data = in_value.data(); auto* out_data = param_out->data(); - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid(1, in_rows.size()); - SparseSGDFunctorKernel< - T, 256><<>>( + const int kThreadsPerBlock = 256; + int thread_x = kThreadsPerBlock; + int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + SparseSGDFunctorKernel<<>>( in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data(), - out_data, in_row_numel); + out_data, in_row_numel, in_rows.size()); } else { PADDLE_THROW("Unsupported Variable Type of Grad"); diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc index 29d2fb989..e008e130e 100644 --- a/paddle/fluid/operators/shrink_rnn_memory_op.cc +++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc @@ -52,16 +52,26 @@ class ShrinkRNNMemoryOp : public ArrayOp { size_t height = dst_num_rows; // do shrink for the top level LoD + if (x_tensor.lod().size() > 0 && x_tensor.lod()[0].size() > static_cast(dst_num_rows)) { - auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0, - dst_num_rows, 0); - height = lod_offset.second.second; - auto out_lod = out_tensor.mutable_lod(); - framework::AppendLoD(out_lod, lod_offset.first); + if (x_tensor.lod().size() > 1) { // MultiLevel LoD + auto lod_offset = framework::GetSubLoDAndAbsoluteOffset( + x_tensor.lod(), 0, dst_num_rows, 0); + height = lod_offset.second.second; + auto out_lod = out_tensor.mutable_lod(); + framework::AppendLoD(out_lod, lod_offset.first); + } else { + // Shrink LoD + auto lod_item = x_tensor.lod()[0]; + lod_item.resize(dst_num_rows + 1); + out_tensor.set_lod({lod_item}); + const auto &const_lod_item = lod_item; + height = const_lod_item.back(); + } } - if (dst_num_rows != 0) { + if (height != 0) { out_tensor.mutable_data(place, x_tensor.type()); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); framework::TensorCopy(x_tensor.Slice(0, height), place, *dev_ctx, @@ -134,8 +144,11 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { } else { auto &dout_tensor = dout_var->Get(); auto height = dout_tensor.dims()[0]; - auto slice = dx_tensor.Slice(0, static_cast(height)); - framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice); + if (height != 0) { + auto slice = dx_tensor.Slice(0, static_cast(height)); + framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, + &slice); + } if (dx_tensor.dims()[0] > height) { auto rest_tensor = dx_tensor.Slice( static_cast(height), static_cast(dx_tensor.dims()[0])); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index dfc079e98..1b283fc97 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -201,6 +201,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); + grid_max_dims_ = GpuMaxGridDim(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); @@ -239,6 +240,10 @@ int CUDADeviceContext::GetMaxPhysicalThreadCount() const { return multi_process * max_threads_per_mp; } +std::tuple CUDADeviceContext::GetMaxGridDims() const { + return grid_max_dims_; +} + Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 795391951..da32b0dad 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -13,6 +13,7 @@ limitations under the License. */ #include #include // NOLINT #include +#include #include #include @@ -91,6 +92,8 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return the max physical thread count in the device context */ int GetMaxPhysicalThreadCount() const; + std::tuple GetMaxGridDims() const; + /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; @@ -135,6 +138,8 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; cublasHandle_t cublas_handle_; + std::tuple grid_max_dims_; + int compute_capability; int multi_process; int max_threads_per_mp; diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index c153e80fe..2806d726d 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -48,35 +48,54 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) { } template -__global__ static void ForRangeElemwiseOp(Function func, int limit) { +__global__ static void ForRangeElemwiseOp(Function func, size_t limit) { size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx < limit) { func(idx); } } +template +__global__ static void ForRangeElemwiseOpGridLarge(Function func, size_t limit, + int grid_dim) { + size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + while (idx < limit) { + func(idx); + idx += grid_dim; + } +} + template <> struct ForRange { ForRange(const CUDADeviceContext& dev_ctx, size_t limit) - : dev_ctx_(dev_ctx), limit_(static_cast(limit)) {} + : dev_ctx_(dev_ctx), limit_(limit) {} template inline void operator()(Function func) const { constexpr int num_threads = 1024; int block_size = limit_ <= num_threads ? limit_ : num_threads; - int grid_size = (limit_ + num_threads - 1) / num_threads; - - if (grid_size == 1) { - ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>( - func); + size_t grid_size = (limit_ + num_threads - 1) / num_threads; + + int max_grid_dim = std::get<0>(dev_ctx_.GetMaxGridDims()); + + if (grid_size < max_grid_dim) { + int grid_size_int = static_cast(grid_size); + if (grid_size == 1) { + ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>( + func); + } else { + ForRangeElemwiseOp<<>>( + func, limit_); + } } else { - ForRangeElemwiseOp<<>>( - func, limit_); + ForRangeElemwiseOpGridLarge<<>>(func, limit_, + max_grid_dim); } } const CUDADeviceContext& dev_ctx_; - int limit_; + size_t limit_; }; #endif diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 126636d87..b88523728 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -152,5 +152,22 @@ void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) { PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream), "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync"); } + +std::tuple GpuMaxGridDim(int id) { + std::tuple result; + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&std::get<0>(result), cudaDevAttrMaxBlockDimX, id), + "cudaDeviceGetAttribute failed in " + "cudaDevAttrMaxBlockDim"); + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&std::get<1>(result), cudaDevAttrMaxBlockDimY, id), + "cudaDeviceGetAttribute failed in " + "cudaDevAttrMaxBlockDim"); + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&std::get<2>(result), cudaDevAttrMaxBlockDimZ, id), + "cudaDeviceGetAttribute failed in " + "cudaDevAttrMaxBlockDim"); + return result; +} } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index f4640d3ea..b748c6e8a 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include namespace paddle { namespace platform { @@ -72,6 +73,8 @@ void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src, //! Set memory dst with value count size asynchronously void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream); +std::tuple GpuMaxGridDim(int id); + } // namespace platform } // namespace paddle diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index d56fa7630..75c29b127 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -311,6 +311,7 @@ def _copy_reader_var_(block, var): new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_dtypes(var.desc.dtypes()) + new_var.desc.set_lod_levels(var.desc.lod_levels()) new_var.persistable = True return new_var @@ -632,6 +633,7 @@ def py_reader(capacity, }) startup_var.desc.set_dtypes(dtypes) + startup_var.desc.set_lod_levels(lod_levels) startup_var.persistable = True main_prog_var = _copy_reader_var_(default_main_program().current_block(), -- GitLab