Replace the ForRange-driven CosSimDyFunctor with a per-device functor that is
called once with the row and column counts. The CPU specialization loops over
the rows itself; the CUDA specialization launches CosSimDyKernel, which
accumulates into dy with platform::CudaAtomicAdd.

Review changes relative to the submitted patch (cos_sim_op.cu only):
  * The launch grid is laid out along x. CosSimDyKernel derives its row index
    from blockIdx.x and gridDim.x only, so the former dim3 grid(1, n) made each
    of the n blocks walk every row and add every contribution to dy n times
    whenever rows > 512.
  * Row index and stride inside the kernel are size_t, matching `rows`.

NOTE(review): this copy was rebuilt from a whitespace-collapsed paste in which
template argument lists and the kernel launch configuration had been stripped.
Those tokens, the blank context lines and the indentation are reconstructed;
verify them against the tree before applying. The index lines still carry the
blob ids of the submitted patch.

diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index 80e07800305c8538f37f479d205527d8dd7ca691..77492e60f2927fe70a9707e1da7414df4d3448a0 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -151,42 +151,26 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
 
 template <typename T>
 struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
-  CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
-                  const T* z, const T* dz, T* dy, int cols)
-      : x_norm_(x_norm),
-        y_norm_(y_norm),
-        x_(x),
-        y_(y),
-        z_(z),
-        dz_(dz),
-        dy_(dy),
-        cols_(static_cast<size_t>(cols)) {}
-
-  inline HOSTDEVICE void operator()(size_t offset) const {
-    auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
-    auto dz = dz_[offset];
-    auto z = z_[offset];
-    auto* x = x_ + cols_ * offset;
-    auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
-
-    auto y_norm_square = y_norm_[0] * y_norm_[0];
-    auto reciprocal_y_norm_square = 1 / y_norm_square;
-    for (size_t i = 0; i < cols_; ++i) {
-      dy_[i] += dz * (x[i] * reciprocal_xy_norm_prod -
-                      z * y_[i] * reciprocal_y_norm_square);
+  inline void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
+                         const T* y_norm, const T* x, const T* y, const T* z,
+                         const T* dz, const size_t rows, const size_t cols,
+                         T* dy) const {
+    for (size_t offset = 0; offset < rows; ++offset) {
+      auto xy_norm_prod = x_norm[offset] * y_norm[0];
+      auto dz_data = dz[offset];
+      auto z_data = z[offset];
+      auto* x_data = x + cols * offset;
+      auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
+
+      auto y_norm_square = y_norm[0] * y_norm[0];
+      auto reciprocal_y_norm_square = 1 / y_norm_square;
+      for (size_t i = 0; i < cols; ++i) {
+        dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
+                            z_data * y[i] * reciprocal_y_norm_square);
+      }
     }
   }
-
-  const T* x_norm_;
-  const T* y_norm_;
-  const T* x_;
-  const T* y_;
-  const T* z_;
-  const T* dz_;
-  T* dy_;
-  const size_t cols_;
 };
-
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/operators/cos_sim_op.cu b/paddle/operators/cos_sim_op.cu
index 88f49c1b141a1197ecfbf4b14ccb5d5ec6c4a2a1..42194d7a0576c32a37b4196d82fe0bfbcb3b29ee 100644
--- a/paddle/operators/cos_sim_op.cu
+++ b/paddle/operators/cos_sim_op.cu
@@ -20,45 +20,51 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
-  CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
-                  const T* z, const T* dz, T* dy, int cols)
-      : x_norm_(x_norm),
-        y_norm_(y_norm),
-        x_(x),
-        y_(y),
-        z_(z),
-        dz_(dz),
-        dy_(dy),
-        cols_(static_cast<size_t>(cols)) {}
-
-  inline HOSTDEVICE void operator()(size_t offset) const {
-    auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
-    auto dz = dz_[offset];
-    auto z = z_[offset];
-    auto* x = x_ + cols_ * offset;
-    auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
+__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
+                               const T* y, const T* z, const T* dz,
+                               const size_t rows, const size_t cols, T* dy) {
+  // Each thread walks rows first, first + grid_size, ...; the indices are
+  // size_t so they match `rows` and avoid a signed/unsigned comparison.
+  const size_t grid_size = static_cast<size_t>(blockDim.x) * gridDim.x;
+  const size_t first =
+      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  T y_norm_data = y_norm[0];
+  for (size_t offset = first; offset < rows; offset += grid_size) {
+    T xy_norm_prod = x_norm[offset] * y_norm_data;
+    T dz_data = dz[offset];
+    T z_data = z[offset];
+    const T* x_data = x + cols * offset;
+    T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
 
-    auto y_norm_square = y_norm_[0] * y_norm_[0];
-    auto reciprocal_y_norm_square = 1 / y_norm_square;
-    for (size_t i = 0; i < cols_; ++i) {
-      T dy = dz * (x[i] * reciprocal_xy_norm_prod -
-                   z * y_[i] * reciprocal_y_norm_square);
-      // platform::CudaAtomicAdd(dy_ + i, dy);
-      dy_[i] += dy;
+    T y_norm_square = y_norm_data * y_norm_data;
+    T reciprocal_y_norm_square = 1 / y_norm_square;
+    for (size_t i = 0; i < cols; ++i) {
+      T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
+                             z_data * y[i] * reciprocal_y_norm_square);
+      // Every row contributes to the same dy[i], so the add must be atomic.
+      platform::CudaAtomicAdd(dy + i, dy_data);
     }
   }
+}
 
-  const T* x_norm_;
-  const T* y_norm_;
-  const T* x_;
-  const T* y_;
-  const T* z_;
-  const T* dz_;
-  T* dy_;
-  const size_t cols_;
+template <typename T>
+struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
+  inline void operator()(const platform::CUDADeviceContext& ctx,
+                         const T* x_norm, const T* y_norm, const T* x,
+                         const T* y, const T* z, const T* dz, const size_t rows,
+                         const size_t cols, T* dy) const {
+    const int block_size = 512;
+    dim3 threads(block_size, 1);
+    // Blocks are laid out along x because the kernel indexes rows with
+    // blockIdx.x / gridDim.x; blocks along y would all repeat the same rows.
+    dim3 grid((rows + block_size - 1) / block_size, 1);
+    CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
+        x_norm, y_norm, x, y, z, dz, rows, cols, dy);
+  }
 };
 
+template struct CosSimDyFunctor<platform::CUDADeviceContext, float>;
+
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
index bb7c893a29038412ed9a5e4747e3580cb6e03743..a913e576f9702d5eed464be3bdf4cc2ecd56d1fc 100644
--- a/paddle/operators/cos_sim_op.h
+++ b/paddle/operators/cos_sim_op.h
@@ -193,9 +193,10 @@ struct CosSimDxFunctor {
 
 template <typename DeviceContext, typename T>
 struct CosSimDyFunctor {
-  CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
-                  const T* z, const T* dz, T* dy, int cols);
-  inline HOSTDEVICE void operator()(size_t) const;
+  inline void operator()(const DeviceContext& ctx, const T* x_norm,
+                         const T* y_norm, const T* x, const T* y, const T* z,
+                         const T* dz, const size_t rows, const size_t cols,
+                         T* dy) const;
 };
 
 template <typename DeviceContext, typename T>
@@ -255,14 +256,11 @@ class CosSimGradKernel : public framework::OpKernel<T> {
         auto& dev_ctx = context.template device_context<DeviceContext>();
         set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
 
-        CosSimDyFunctor<DeviceContext, T> functor(
-            in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
-            in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
-            out_grad_y->data<T>(), cols);
-        platform::ForRange<DeviceContext> for_range(
-            static_cast<const DeviceContext&>(context.device_context()),
-            rows_x);
-        for_range(functor);
+        CosSimDyFunctor<DeviceContext, T> functor;
+        functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
+                in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
+                in_grad_z->data<T>(), static_cast<size_t>(rows_x),
+                static_cast<size_t>(cols), out_grad_y->data<T>());
       }
     }
   }