From 7ac00dd684b025a8b1ea6a34a4cdf39ce7fd792e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 27 Dec 2017 15:23:49 +0800 Subject: [PATCH] refine --- paddle/operators/cos_sim_op.cc | 38 +++++++++ paddle/operators/cos_sim_op.cu | 45 +++++++++++ paddle/operators/cos_sim_op.h | 137 ++++++++++++++------------------- 3 files changed, 142 insertions(+), 78 deletions(-) diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index 440c427cb..ab9cf745e 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -149,6 +149,44 @@ class CosSimOpGrad : public framework::OperatorWithKernel { } }; +template +struct CosSimDyFunctor { + CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dy, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dy_(dy), + cols_(static_cast(cols)) {} + + inline void operator()(size_t offset) const { + auto xy_norm_prod = x_norm_[offset] * y_norm_[0]; + auto dz = dz_[offset]; + auto z = z_[offset]; + auto* x = x_ + cols_ * offset; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + auto y_norm_square = y_norm_[0] * y_norm_[0]; + auto reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols_; ++i) { + dy_[i] += dz * (x[i] * reciprocal_xy_norm_prod - + z * y_[i] * reciprocal_y_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dy_; + const size_t cols_; +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/cos_sim_op.cu b/paddle/operators/cos_sim_op.cu index 1cb01f594..eacac68ba 100644 --- a/paddle/operators/cos_sim_op.cu +++ b/paddle/operators/cos_sim_op.cu @@ -15,6 +15,51 @@ #define EIGEN_USE_GPU #include "paddle/operators/cos_sim_op.h" +namespace paddle { +namespace operators { + +template +struct CosSimDyFunctor { + CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dy, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dy_(dy), + cols_(static_cast(cols)) {} + + inline void operator()(size_t offset) const { + auto xy_norm_prod = x_norm_[offset] * y_norm_[0]; + auto dz = dz_[offset]; + auto z = z_[offset]; + auto* x = x_ + cols_ * offset; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + auto y_norm_square = y_norm_[0] * y_norm_[0]; + auto reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols_; ++i) { + T dy = dz * (x[i] * reciprocal_xy_norm_prod - + z * y_[i] * reciprocal_y_norm_square); + paddle::paddleAtomicAdd(dy_ + i, dy) + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dy_; + const size_t cols_; +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( cos_sim, ops::CosSimKernel); diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index cd5c703c3..8b2a06a41 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -21,10 +21,17 @@ namespace operators { using Tensor = framework::Tensor; -template -static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) { - for (; begin1 < last1; ++begin1, ++begin2) { - callback(*begin1, *begin2); +template +struct CosSimDyFunctor { + CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dy, int cols); + inline void operator()(size_t) const; +}; + +template +static void ForEachZip(size_t num, Callback callback) { + for (size_t i = 0; i < num; ++i) { + callback(i); } } @@ -38,16 +45,11 @@ struct CosSimFunctor { z_(z), cols_(static_cast(cols)) {} - inline void operator()(T& x_norm, T& y_norm) const { - size_t x_offset = &x_norm - x_norm_; - size_t y_offset = &y_norm - y_norm_; - - auto* x = x_ + cols_ * x_offset; - - T xx = 0, xy = 0; - T yy = 0; + inline HOSTDEVICE void operator()(size_t offset) const { + auto* x = x_ + cols_ * offset; + T xx = 0, xy = 0, yy = 0; if (same_row) { - auto* y = y_ + cols_ * y_offset; + auto* y = y_ + cols_ * offset; for (size_t i = 0; i < cols_; ++i) { xx += x[i] * x[i]; yy += y[i] * y[i]; @@ -55,21 +57,20 @@ struct CosSimFunctor { } xx = sqrt(xx); yy = sqrt(yy); - x_norm_[x_offset] = xx; - y_norm_[y_offset] = yy; - z_[x_offset] = xy / (xx * yy); + y_norm_[offset] = yy; + x_norm_[offset] = xx; + z_[offset] = xy / (xx * yy); } else { // This can be wrote in a better way. - auto* y = y_; for (size_t i = 0; i < cols_; ++i) { xx += x[i] * x[i]; - yy += y[i] * y[i]; // only need - xy += x[i] * y[i]; + yy += y_[i] * y_[i]; // only need + xy += x[i] * y_[i]; } xx = sqrt(xx); yy = sqrt(yy); - x_norm_[x_offset] = xx; y_norm_[0] = yy; - z_[x_offset] = xy / (xx * yy); + x_norm_[offset] = xx; + z_[offset] = xy / (xx * yy); } } @@ -104,14 +105,12 @@ class CosSimKernel : public framework::OpKernel { CosSimFunctor functor( in_x->data(), in_y->data(), out_x_norm->data(), out_y_norm->data(), out_z->data(), cols); - ForEachZip(out_x_norm->data(), out_x_norm->data() + rows_x, - out_y_norm->data(), functor); + ForEachZip(rows_x, functor); } else { CosSimFunctor functor( in_x->data(), in_y->data(), out_x_norm->data(), out_y_norm->data(), out_z->data(), cols); - ForEachZip(out_x_norm->data(), out_x_norm->data() + rows_x, - out_y_norm->data(), functor); + ForEachZip(rows_x, functor); } } }; @@ -129,19 +128,15 @@ struct CosSimGradFunctor { dx_(dx), cols_(static_cast(cols)) {} - inline void operator()(const T& x_norm, const T& y_norm) const { - size_t x_offset = &x_norm - x_norm_; - size_t y_offset = &y_norm - y_norm_; + inline HOSTDEVICE void operator()(size_t offset) const { + auto x_norm_square = x_norm_[offset] * x_norm_[offset]; + auto xy_norm_prod = x_norm_[offset] * y_norm_[offset]; + auto dz = dz_[offset]; + auto z = z_[offset]; - auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset]; - auto xy_norm_prod = x_norm_[x_offset] * y_norm_[y_offset]; - auto dz = dz_[x_offset]; - auto z = z_[x_offset]; - - auto* dx = dx_ + cols_ * x_offset; - auto* x = x_ + cols_ * x_offset; - - auto* y = y_ + cols_ * y_offset; + auto* dx = dx_ + cols_ * offset; + auto* x = x_ + cols_ * offset; + auto* y = y_ + cols_ * offset; auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; auto reciprocal_x_norm_square = 1 / x_norm_square; @@ -161,10 +156,10 @@ struct CosSimGradFunctor { const size_t cols_; }; -template +template struct CosSimDxFunctor { CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, - const T* z, const T* dz, T* dx, T* dy, int cols) + const T* z, const T* dz, T* dx, int cols) : x_norm_(x_norm), y_norm_(y_norm), x_(x), @@ -172,37 +167,23 @@ struct CosSimDxFunctor { z_(z), dz_(dz), dx_(dx), - dy_(dy), cols_(static_cast(cols)) {} - inline void operator()(const T& x_norm, const T& y_norm) const { - size_t x_offset = &x_norm - x_norm_; - - auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0]; - auto dz = dz_[x_offset]; - auto z = z_[x_offset]; - auto* x = x_ + cols_ * x_offset; + inline HOSTDEVICE void operator()(size_t offset) const { + auto xy_norm_prod = x_norm_[offset] * y_norm_[0]; + auto dz = dz_[offset]; + auto z = z_[offset]; + auto* x = x_ + cols_ * offset; auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto x_norm_square = x_norm_[offset] * x_norm_[offset]; + auto* dx = dx_ + cols_ * offset; + auto reciprocal_x_norm_square = 1 / x_norm_square; - if (Dx) { - auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset]; - auto* dx = dx_ + cols_ * x_offset; - auto* x = x_ + cols_ * x_offset; - auto reciprocal_x_norm_square = 1 / x_norm_square; - for (size_t i = 0; i < cols_; ++i) { - dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - - z * x[i] * reciprocal_x_norm_square); - } - } else { - auto y_norm_square = y_norm_[0] * y_norm_[0]; - auto reciprocal_y_norm_square = 1 / y_norm_square; - for (size_t i = 0; i < cols_; ++i) { - dy_[i] += dz * (x[i] * reciprocal_xy_norm_prod - - z * y_[i] * reciprocal_y_norm_square); - } + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); } } - const T* x_norm_; const T* y_norm_; const T* x_; @@ -210,7 +191,6 @@ struct CosSimDxFunctor { const T* z_; const T* dz_; T* dx_; - T* dy_; const size_t cols_; }; @@ -239,33 +219,34 @@ class CosSimGradKernel : public framework::OpKernel { in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), out_grad_x->mutable_data(context.GetPlace()), cols); - ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, - in_y_norm->data(), functor); + ForEachZip(rows_x, functor); } if (out_grad_y) { CosSimGradFunctor functor( in_y_norm->data(), in_x_norm->data(), in_y->data(), in_x->data(), in_z->data(), in_grad_z->data(), out_grad_y->mutable_data(context.GetPlace()), cols); - ForEachZip(in_y_norm->data(), in_y_norm->data() + rows_x, - in_x_norm->data(), functor); + ForEachZip(rows_x, functor); } } else { if (out_grad_x) { - CosSimDxFunctor functor( + CosSimDxFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), - out_grad_x->mutable_data(context.GetPlace()), nullptr, cols); - ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, - in_y_norm->data(), functor); + out_grad_x->mutable_data(context.GetPlace()), cols); + ForEachZip(rows_x, functor); } if (out_grad_y) { - CosSimDxFunctor functor( + out_grad_y->mutable_data(context.GetPlace()); + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, out_grad_y, static_cast(0)); + + CosSimDyFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), - in_y->data(), in_z->data(), in_grad_z->data(), nullptr, - out_grad_y->mutable_data(context.GetPlace()), cols); - ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, - in_y_norm->data(), functor); + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_y->data(), cols); + ForEachZip(rows_x, functor); } } } -- GitLab