diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index a7f275bae5748671acc689294ac31de0b59a1e73..77a3603eb63ff5d6a4ab5b96140af4a5cef555f4 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,18 +1,18 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) - nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) + nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor selected_rows) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) - cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() +cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor selected_rows) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 26bf0ec2f1e50b4cbf1065d61340efb8aeaace24..d31b223b2c8dc5460311118bf47cc1a8c7b45b08 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/math_function.h" +#include "paddle/platform/cuda_helper.h" namespace paddle { namespace operators { @@ -191,7 +192,7 @@ struct SelectedRowsAdd { auto in2_place = input2.place(); PADDLE_ENFORCE(platform::is_gpu_place(in2_place)); auto out_place = context.GetPlace(); - PADDLE_ENFORCE(platform::is_gpu_place(out_place)) + PADDLE_ENFORCE(platform::is_gpu_place(out_place)); memory::Copy( boost::get(out_place), out_data, @@ -211,22 +212,26 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; namespace { -template -__global__ void SelectedRowsAddTensorKernel(T* selected_rows, int64_t* rows, - T* tensor_in, T* tensor_out, - const int64_t row_numel) { - const ty = blockIdx.y; +template +__global__ void SelectedRowsAddTensorKernel(const T* selected_rows, + const int64_t* rows, + T* tensor_out, + int64_t row_numel, + int block_size) { + const int ty = blockIdx.y; int tid = threadIdx.x; selected_rows += ty * row_numel; - tensor_in += rows[ty] * row_numel; tensor_out += rows[ty] * row_numel; for (int index = tid; index < row_numel; index += block_size) { - tensor_out[index] = tensor_in[index] + selected_rows[index]; + // Since index in rows of SelectedRows can be duplicate, we can not use + // tensor_out[index] += selected_rows[index]; Instead, we have to use + // AtomicAdd to avoid concurrent write error. + paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]); } } -} +} // namespace template struct SelectedRowsAddTensor { @@ -250,13 +255,22 @@ struct SelectedRowsAddTensor { auto* in2_data = input2.data(); auto* out_data = output->data(); - const int block_size = 256; + SetConstant functor; + functor(context, output, 0.0); + + int block_size = 256; dim3 threads(block_size, 1); dim3 grid(1, in1_height); - SelectedRowsAddTensorKernel<<< + SelectedRowsAddTensorKernel<<< grid, threads, 0, - reinterpret_cast(ctx).stream()>>>( - in1_data, in1_rows.data(), in2_data, out_data, in1_row_numel); + reinterpret_cast(context).stream() + >>>(in1_data, in1_rows.data(), + out_data, in1_row_numel, block_size); + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.GetEigenDevice()) = + out_eigen + in2_eigen; } }; diff --git a/paddle/operators/math/math_function_test.cu b/paddle/operators/math/math_function_test.cu index e691078bb6b44695bd859f58f60ba135f86ee512..1acc5f66a69f88d4eb2cf47c70277ca843b02627 100644 --- a/paddle/operators/math/math_function_test.cu +++ b/paddle/operators/math/math_function_test.cu @@ -183,20 +183,21 @@ TEST(math_function, selected_rows_add) { using namespace paddle::platform; using namespace paddle::operators::math; - CPUPlace gpu_place(0); + GPUPlace gpu_place(0); + CPUPlace cpu_place; CUDADeviceContext ctx(gpu_place); SetConstant functor; int64_t height = 10; int64_t row_numel = 10; - Vector rows1{0, 4, 7}; + std::vector rows1{0, 4, 7}; std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; auto* in1_value = selected_rows1->mutable_value(); in1_value->mutable_data( make_ddim({static_cast(rows1.size()), row_numel}), gpu_place); functor(ctx, in1_value, 1.0); - Vector rows2{0, 5, 7, 9}; + std::vector rows2{0, 5, 7, 9}; std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; auto* in2_value = selected_rows2->mutable_value(); in2_value->mutable_data( @@ -228,7 +229,7 @@ TEST(math_function, selected_rows_add) { EXPECT_EQ(out_rows[6], 9); Tensor out_cpu; - out_cpu.CopyFrom(*out_value, platform::CPUPlace(), ctx); + out_cpu.CopyFrom(*out_value, cpu_place, ctx); ctx.Wait(); auto* out_cpu_data = out_cpu.data(); @@ -256,10 +257,10 @@ TEST(math_function, selected_rows_add) { add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); Tensor tensor2_cpu; - tensor2_cpu.CopyFrom(*tensor2, platform::CPUPlace(), ctx); + tensor2_cpu.CopyFrom(*tensor2, cpu_place, ctx); ctx.Wait(); - auto* tensor2_cpu_data = tensor2_cpu->data(); + auto* tensor2_cpu_data = tensor2_cpu.data(); // row0: 1.0 + 2.0 + 3.0 EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0); // row1: 3.0