diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h
index 06cf4a0f9f33af67343437baeb9623a35ddad183..19e5c2c73eac74dee030a4f7820531800f737e4e 100644
--- a/paddle/fluid/framework/op_info.h
+++ b/paddle/fluid/framework/op_info.h
@@ -38,31 +38,27 @@ struct OpInfo {
   OpAttrChecker* checker_{nullptr};
   InferVarTypeFN infer_var_type_;
   InferShapeFN infer_shape_;
-  std::string op_type_;
 
   bool HasOpProtoAndChecker() const {
     return proto_ != nullptr && checker_ != nullptr;
   }
 
   const proto::OpProto& Proto() const {
-    PADDLE_ENFORCE_NOT_NULL(proto_, "Operator %s Proto has not been registered",
-                            op_type_);
+    PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
     PADDLE_ENFORCE(proto_->IsInitialized(),
-                   "Operator %s Proto must be initialized in op info",
-                   op_type_);
+                   "Operator Proto must be initialized in op info");
     return *proto_;
   }
 
   const OpCreator& Creator() const {
-    PADDLE_ENFORCE_NOT_NULL(
-        creator_, "Operator %s Creator has not been registered", op_type_);
+    PADDLE_ENFORCE_NOT_NULL(creator_,
+                            "Operator Creator has not been registered");
     return creator_;
   }
 
   const GradOpMakerFN& GradOpMaker() const {
     PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
-                            "Operator %s GradOpMaker has not been registered.",
-                            op_type_);
+                            "Operator GradOpMaker has not been registered.");
     return grad_op_maker_;
   }
@@ -77,9 +73,8 @@ class OpInfoMap {
     return map_.find(op_type) != map_.end();
   }
 
-  void Insert(const std::string& type, OpInfo info) {
+  void Insert(const std::string& type, const OpInfo& info) {
     PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
-    info.op_type_ = type;
     map_.insert({type, info});
   }
diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc
index 326c58ee1c09d6f745e6c8abfb92030d11d8c1c6..a0d640b2020958af53a4405ae886eadb2a1e117e 100644
--- a/paddle/fluid/operators/read_op.cc
+++ b/paddle/fluid/operators/read_op.cc
@@ -45,12 +45,10 @@ class ReadInferVarType : public framework::VarTypeInference {
     framework::VarDesc* reader = block->FindVarRecursive(reader_name);
     auto dtypes = reader->GetDataTypes();
     PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
-    auto lod_levels = reader->GetLoDLevels();
     for (size_t i = 0; i < dtypes.size(); ++i) {
       framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
       out.SetType(framework::proto::VarType::LOD_TENSOR);
       out.SetDataType(dtypes[i]);
-      out.SetLoDLevel(lod_levels[i]);
     }
   }
 };
diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu
index 9527e7ba300e10a6af1a0dd4b312c0323115256e..4722be7a666d3e8f3c25c9499f88ddda835f60e3 100644
--- a/paddle/fluid/operators/sgd_op.cu
+++ b/paddle/fluid/operators/sgd_op.cu
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <algorithm>
+#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/sgd_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -33,21 +33,22 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
   }
 }
 
-template <typename T>
+template <typename T, int block_size>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
                                        const T* learning_rate, T* tensor_out,
-                                       int64_t row_numel, int64_t limit) {
-  for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
-    const T* selected_rows_ptr = selected_rows + i * row_numel;
-    T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
-    for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
-      // Since index in rows of SelectedRows can be duplicate, we have to use
-      // Atomic Operation to avoid concurrent write error.
-      paddle::platform::CudaAtomicAdd(
-          tensor_out_ptr + index,
-          -1.0 * learning_rate[0] * selected_rows_ptr[index]);
-    }
+                                       int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  selected_rows += ty * row_numel;
+  tensor_out += rows[ty] * row_numel;
+
+  for (int index = tid; index < row_numel; index += block_size) {
+    // Since index in rows of SelectedRows can be duplicate, we have to use
+    // Atomic Operation to avoid concurrent write error.
+    paddle::platform::CudaAtomicAdd(
+        tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
   }
 }
 
 }  // namespace
@@ -96,15 +97,13 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
       auto* in_data = in_value.data<T>();
       auto* out_data = param_out->data<T>();
 
-      const int kThreadsPerBlock = 256;
-      int thread_x = kThreadsPerBlock;
-      int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
-      int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-      SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
-                               ctx.cuda_device_context().stream()>>>(
+      const int block_size = 256;
+      dim3 threads(block_size, 1);
+      dim3 grid(1, in_rows.size());
+      SparseSGDFunctorKernel<
+          T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
           in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
-          out_data, in_row_numel, in_rows.size());
+          out_data, in_row_numel);
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc
index e008e130e34f60a78bf44e211c42c4b7786d1721..29d2fb989754f5621222768a279a1c898ea1c355 100644
--- a/paddle/fluid/operators/shrink_rnn_memory_op.cc
+++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc
@@ -52,26 +52,16 @@ class ShrinkRNNMemoryOp : public ArrayOp {
     size_t height = dst_num_rows;
 
     // do shrink for the top level LoD
-
     if (x_tensor.lod().size() > 0 &&
         x_tensor.lod()[0].size() > static_cast<size_t>(dst_num_rows)) {
-      if (x_tensor.lod().size() > 1) {  // MultiLevel LoD
-        auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
-            x_tensor.lod(), 0, dst_num_rows, 0);
-        height = lod_offset.second.second;
-        auto out_lod = out_tensor.mutable_lod();
-        framework::AppendLoD(out_lod, lod_offset.first);
-      } else {
-        // Shrink LoD
-        auto lod_item = x_tensor.lod()[0];
-        lod_item.resize(dst_num_rows + 1);
-        out_tensor.set_lod({lod_item});
-        const auto &const_lod_item = lod_item;
-        height = const_lod_item.back();
-      }
+      auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0,
+                                                              dst_num_rows, 0);
+      height = lod_offset.second.second;
+      auto out_lod = out_tensor.mutable_lod();
+      framework::AppendLoD(out_lod, lod_offset.first);
     }
 
-    if (height != 0) {
+    if (dst_num_rows != 0) {
       out_tensor.mutable_data(place, x_tensor.type());
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       framework::TensorCopy(x_tensor.Slice(0, height), place, *dev_ctx,
@@ -144,11 +134,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
     } else {
       auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
       auto height = dout_tensor.dims()[0];
-      if (height != 0) {
-        auto slice = dx_tensor.Slice(0, static_cast<int>(height));
-        framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx,
-                              &slice);
-      }
+      auto slice = dx_tensor.Slice(0, static_cast<int>(height));
+      framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
       if (dx_tensor.dims()[0] > height) {
         auto rest_tensor = dx_tensor.Slice(
             static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 1b283fc9725fb8d01da913312844d0faea29daf6..dfc079e986e93c7f02f17b299e5d6293edbedd05 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -201,7 +201,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   compute_capability = GetCUDAComputeCapability(place_.device);
   multi_process = GetCUDAMultiProcessors(place_.device);
   max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
-  grid_max_dims_ = GpuMaxGridDim(place_.device);
   PADDLE_ENFORCE(cudaStreamCreate(&stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
@@ -240,10 +239,6 @@ int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
   return multi_process * max_threads_per_mp;
 }
 
-std::tuple<int, int, int> CUDADeviceContext::GetMaxGridDims() const {
-  return grid_max_dims_;
-}
-
 Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index da32b0dad4b8cfe75bf82f59ec58db8136b899f2..79539195157d74d4d757edee5e008cbb76c93ee2 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -13,7 +13,6 @@ limitations under the License. */
 #include <memory>
 #include <mutex>  // NOLINT
 #include <string>
-#include <tuple>
 #include <unordered_map>
 #include <vector>
@@ -92,8 +91,6 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return the max physical thread count in the device context */
   int GetMaxPhysicalThreadCount() const;
 
-  std::tuple<int, int, int> GetMaxGridDims() const;
-
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
@@ -138,8 +135,6 @@ class CUDADeviceContext : public DeviceContext {
   cudaStream_t stream_;
   cublasHandle_t cublas_handle_;
 
-  std::tuple<int, int, int> grid_max_dims_;
-
   int compute_capability;
   int multi_process;
   int max_threads_per_mp;
diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h
index 2806d726d2b1ac6b717a9041af19e7ee62be6883..c153e80fe42aecb33d3aa97874d2881bce9029be 100644
--- a/paddle/fluid/platform/for_range.h
+++ b/paddle/fluid/platform/for_range.h
@@ -48,54 +48,35 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
 }
 
 template <typename Function>
-__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
+__global__ static void ForRangeElemwiseOp(Function func, int limit) {
   size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
   if (idx < limit) {
     func(idx);
   }
 }
 
-template <typename Function>
-__global__ static void ForRangeElemwiseOpGridLarge(Function func, size_t limit,
-                                                   int grid_dim) {
-  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
-  while (idx < limit) {
-    func(idx);
-    idx += grid_dim;
-  }
-}
-
 template <>
 struct ForRange<CUDADeviceContext> {
   ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
-      : dev_ctx_(dev_ctx), limit_(limit) {}
+      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
 
   template <typename Function>
   inline void operator()(Function func) const {
     constexpr int num_threads = 1024;
     int block_size = limit_ <= num_threads ? limit_ : num_threads;
-    size_t grid_size = (limit_ + num_threads - 1) / num_threads;
-
-    int max_grid_dim = std::get<0>(dev_ctx_.GetMaxGridDims());
-
-    if (grid_size < max_grid_dim) {
-      int grid_size_int = static_cast<int>(grid_size);
-      if (grid_size == 1) {
-        ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
-            func);
-      } else {
-        ForRangeElemwiseOp<<<grid_size_int, block_size, 0, dev_ctx_.stream()>>>(
-            func, limit_);
-      }
+    int grid_size = (limit_ + num_threads - 1) / num_threads;
+
+    if (grid_size == 1) {
+      ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
+          func);
     } else {
-      ForRangeElemwiseOpGridLarge<<<max_grid_dim, block_size, 0,
-                                    dev_ctx_.stream()>>>(func, limit_,
-                                                         max_grid_dim);
+      ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
+          func, limit_);
     }
   }
 
   const CUDADeviceContext& dev_ctx_;
-  size_t limit_;
+  int limit_;
 };
 
 #endif
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index b88523728407803a1ea9d343dc2d33c6a38d5de9..126636d879213b1c8f242db8fbdf6a358a1d2da9 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -152,22 +152,5 @@ void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
   PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
                  "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
 }
-
-std::tuple<int, int, int> GpuMaxGridDim(int id) {
-  std::tuple<int, int, int> result;
-  PADDLE_ENFORCE(
-      cudaDeviceGetAttribute(&std::get<0>(result), cudaDevAttrMaxBlockDimX, id),
-      "cudaDeviceGetAttribute failed in "
-      "cudaDevAttrMaxBlockDim");
-  PADDLE_ENFORCE(
-      cudaDeviceGetAttribute(&std::get<1>(result), cudaDevAttrMaxBlockDimY, id),
-      "cudaDeviceGetAttribute failed in "
-      "cudaDevAttrMaxBlockDim");
-  PADDLE_ENFORCE(
-      cudaDeviceGetAttribute(&std::get<2>(result), cudaDevAttrMaxBlockDimZ, id),
-      "cudaDeviceGetAttribute failed in "
-      "cudaDevAttrMaxBlockDim");
-  return result;
-}
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h
index b748c6e8a519d27acd211f815a210c7a74ff32c8..f4640d3eaa2165c35e8e14690d83e9e7e7168c0b 100644
--- a/paddle/fluid/platform/gpu_info.h
+++ b/paddle/fluid/platform/gpu_info.h
@@ -19,7 +19,6 @@ limitations under the License. */
 #include <cuda_runtime.h>
 #include <stddef.h>
 #include <string>
-#include <tuple>
 
 namespace paddle {
 namespace platform {
@@ -73,8 +72,6 @@ void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
 
 //! Set memory dst with value count size asynchronously
 void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
 
-std::tuple<int, int, int> GpuMaxGridDim(int id);
-
 }  // namespace platform
 }  // namespace paddle
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 75c29b12724d53783b9748d6df066c52bd232482..d56fa76300e7054ef71a7729483a579fa35f1dac 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -311,7 +311,6 @@ def _copy_reader_var_(block, var):
     new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
     new_var.desc.set_shapes(var.desc.shapes())
     new_var.desc.set_dtypes(var.desc.dtypes())
-    new_var.desc.set_lod_levels(var.desc.lod_levels())
     new_var.persistable = True
     return new_var
 
@@ -633,7 +632,6 @@ def py_reader(capacity,
         })
 
     startup_var.desc.set_dtypes(dtypes)
-    startup_var.desc.set_lod_levels(lod_levels)
    startup_var.persistable = True
 
     main_prog_var = _copy_reader_var_(default_main_program().current_block(),
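
Reviewer note (not part of the patch): the `for_range.h` and `sgd_op.cu` hunks amount to dropping grid-stride-style kernels (a capped grid whose threads loop over the range) in favor of the older fixed launches (one thread per element, or one block per SelectedRows row). Below is a minimal standalone sketch of the two schemes under those assumptions; the kernel names are illustrative only and appear nowhere in the patch.

// Standalone sketch, not from the patch. Compile with: nvcc launch_schemes.cu
#include <cuda_runtime.h>

// Grid-stride scheme (the style being removed): the grid is capped at a
// fixed size and every thread strides over the whole range, so any limit
// is covered without growing the grid.
__global__ void GridStrideFill(float* out, size_t limit) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < limit;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    out[i] = static_cast<float>(i);
  }
}

// One-thread-per-element scheme (the style being restored): the launch
// supplies ceil(limit / block) blocks and each thread handles one index.
__global__ void PerElementFill(float* out, int limit) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < limit) {
    out[i] = static_cast<float>(i);
  }
}

int main() {
  const int limit = 1 << 20;
  const int block = 1024;
  float* out = nullptr;
  cudaMalloc(&out, limit * sizeof(float));

  // Capped grid: correctness does not depend on the grid covering limit.
  GridStrideFill<<<256, block>>>(out, limit);
  // Grid derived from limit, as the reverted ForRange does again.
  PerElementFill<<<(limit + block - 1) / block, block>>>(out, limit);

  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}

The trade-off is visible in the deleted code: the grid-stride path kept launches within the device's grid-dimension limit for any range (which is what the removed `GpuMaxGridDim` plumbing queried), while the restored `ForRange` computes `grid_size` directly from `limit_` and accepts the narrower `int` range in exchange for simpler launch logic.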