diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index 3a7e67506d29f72ca3f73f6b2e5278324e2ce3c2..e96592ab28ad3d831e3db8e99935ab9c3139c83b 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/elementwise_add_op.h" +#include "paddle/operators/elementwise_op_function.h" namespace paddle { namespace operators { @@ -28,27 +28,73 @@ template using EigenVector = framework::EigenVector; -template -void Function_forward(T* out, T* x_norm, T* y_norm, - ElementIterator& x, - ElementIterator& y, int row, int col) { - for (int i = 0; i < row; ++i) { - T xx = 0; +template +static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) { + // This method could be implemented in CUDA + for (; begin1 < last1; ++begin1, ++begin2) { + callback(*begin1, *begin2); + } +} + +template +struct CosSimFunctor { + CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + cols_(static_cast(cols)) {} + + inline void operator()(T& x_norm, T& y_norm) const { + size_t x_offset = &x_norm - x_norm_; + size_t y_offset = &y_norm - y_norm_; + + auto* x = x_ + cols_ * x_offset; + + T xx = 0, xy = 0; T yy = 0; - T xy = 0; - for (int j = 0; j < col; ++j) { - xy += (*x) * (*y); - xx += (*x) * (*x); - yy += (*y) * (*y); - ++y; - ++x; + if (same_row) { + auto* y = y_ + cols_ * y_offset; + for (size_t i = 0; i < cols_; ++i) { + xx += x[i] * x[i]; + yy += y[i] * y[i]; + xy += x[i] * y[i]; + } + xx = sqrt(xx); + yy = sqrt(yy); + x_norm_[x_offset] = xx; + y_norm_[y_offset] = yy; + z_[x_offset] = xy / (xx * yy); + } else { + auto* y = y_; + // if (yy == -1) { + // yy = 0; + // for (size_t i = 0; i < cols_; ++i) { + // yy += y[i] * y[i]; + // } + // y_norm[0] = sqrt(yy); + // } + for (size_t i = 0; i < cols_; ++i) { + xx += x[i] * x[i]; + yy += y[i] * y[i]; // only need + xy += x[i] * y[i]; + } + xx = sqrt(xx); + yy = sqrt(yy); + x_norm_[x_offset] = xx; + y_norm_[0] = yy; + z_[x_offset] = xy / (xx * yy); } - x_norm[i] = sqrt(xx); - y_norm[i] = sqrt(yy); - - out[i] = xy / (x_norm[i] * y_norm[i]); } -} + + T* x_norm_; + T* y_norm_; + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; template class CosSimKernel : public framework::OpKernel { @@ -68,58 +114,140 @@ class CosSimKernel : public framework::OpKernel { int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - auto x_iter = ElementIterator(in_x->data(), rows_x, - cols, rows_x, cols); - auto y_iter = ElementIterator(in_y->data(), rows_y, - cols, rows_x, cols); - - Function_forward(out_z->data(), out_x_norm->data(), - out_y_norm->data(), x_iter, y_iter, rows_x, cols); - // - // // convert Tensor to Eigen Tensor - //// int rows_x = in_x->dims()[0]; - //// int rows_y = in_y->dims()[0]; - // auto x = EigenMatrix::Reshape(*in_x, 1); - // auto y = EigenMatrix::Reshape(*in_y, 1); - // auto z = EigenVector::Flatten(*out_z); - // auto x_norm = EigenVector::Flatten(*out_x_norm); - // auto y_norm = EigenVector::Flatten(*out_y_norm); - // - // // compute - // auto& place = - // *context.template device_context().eigen_device(); - // auto row_along = Eigen::array({{1}}); - // x_norm.device(place) = x.square().sum(row_along).sqrt(); - // y_norm.device(place) = y.square().sum(row_along).sqrt(); - // if (rows_x == rows_y) { - // auto xy = (x * y).sum(Eigen::array({{1}})); - // z.device(place) = xy / x_norm / y_norm; - // } else { - // Eigen::DSizes bcast(rows_x, 1); - // auto xy = (x * y.broadcast(bcast)).sum(row_along); - // z.device(place) = xy / x_norm / y_norm.broadcast(bcast); - // } + + if (rows_x == rows_y) { + CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + ForEachZip(out_x_norm->data(), out_x_norm->data() + rows_x, + out_y_norm->data(), functor); + } else { + CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + ForEachZip(out_x_norm->data(), out_x_norm->data() + rows_x, + out_y_norm->data(), functor); + } } }; -template -void Function_element(T* result, ElementIterator dz, - ElementIterator y, - ElementIterator x_norm, - ElementIterator y_norm, - ElementIterator z, - ElementIterator x, int num, int block) { - for (int i = 0; i < num; ++i) { - result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) - - (*z) * (*x) / ((*x_norm) * (*x_norm))); - ++dz; - ++y; - ++x_norm; - ++y_norm; - ++z; - ++x; +template +struct CosSimGradFunctor { + CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + void operator()(const T& x_norm, const T& y_norm) const { + size_t x_offset = &x_norm - x_norm_; + size_t y_offset = &y_norm - y_norm_; + + auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset]; + // auto y_norm_square = y_norm_[y_offset] * y_norm_[y_offset]; + auto xy_norm_prod = x_norm_[x_offset] * y_norm_[y_offset]; + auto dz = dz_[x_offset]; + + auto* dx = dx_ + cols_ * x_offset; + auto* x = x_ + cols_ * x_offset; + auto* y = y_ + cols_ * y_offset; + auto z = z_[x_offset]; + + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y[i] / xy_norm_prod - z * x[i] / x_norm_square); + } } -} + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDxFunctor { + CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + void operator()(const T& x_norm, const T& y_norm) const { + size_t x_offset = &x_norm - x_norm_; + + auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset]; + auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0]; + auto dz = dz_[x_offset]; + auto z = z_[x_offset]; + + auto* dx = dx_ + cols_ * x_offset; + auto* x = x_ + cols_ * x_offset; + + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y_[i] / xy_norm_prod - z * x[i] / x_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDyFunctor { + CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dy, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dy_(dy), + cols_(static_cast(cols)) {} + + void operator()(const T& x_norm, const T& y_norm) const { + size_t x_offset = &x_norm - x_norm_; + + auto y_norm_square = y_norm_[0] * y_norm_[0]; + auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0]; + auto dz = dz_[x_offset]; + auto z = z_[x_offset]; + auto* x = x_ + cols_ * x_offset; + + for (size_t i = 0; i < cols_; ++i) { + dy_[i] += dz * (x[i] / xy_norm_prod - z * y_[i] / y_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dy_; + const size_t cols_; +}; template class CosSimGradKernel : public framework::OpKernel { @@ -140,45 +268,40 @@ class CosSimGradKernel : public framework::OpKernel { int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - ////////////////////////////// - // ## - auto x_iter = ElementIterator(in_x->data(), rows_x, - cols, rows_x, cols); - auto y_iter = ElementIterator(in_y->data(), rows_y, - cols, rows_x, cols); - auto z_iter = ElementIterator(in_z->data(), rows_x, 1, - rows_x, cols); - auto dz_iter = ElementIterator(in_grad_z->data(), - rows_x, 1, rows_x, cols); - auto x_norm_iter = ElementIterator( - in_x_norm->data(), rows_x, 1, rows_x, cols); - auto y_norm_iter = ElementIterator( - in_y_norm->data(), rows_y, 1, rows_x, cols); - // ## - ////////////////////////////// - // compute dx - if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - - ////////////////////////////// - // ## - Function_element(out_grad_x->data(), dz_iter, y_iter, x_norm_iter, - y_norm_iter, z_iter, x_iter, rows_x * cols, - rows_x * cols); - // ## - ////////////////////////////// - } - // compute dy - if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - - ////////////////////////////// - // ## - Function_element(out_grad_y->data(), dz_iter, x_iter, y_norm_iter, - x_norm_iter, z_iter, y_iter, rows_x * cols, - rows_y * cols); - // ## - ////////////////////////////// + if (rows_x == rows_y) { + if (out_grad_x) { + CosSimGradFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, + in_y_norm->data(), functor); + } + if (out_grad_y) { + CosSimGradFunctor functor( + in_y_norm->data(), in_x_norm->data(), in_y->data(), + in_x->data(), in_z->data(), in_grad_z->data(), + out_grad_y->mutable_data(context.GetPlace()), cols); + ForEachZip(in_y_norm->data(), in_y_norm->data() + rows_x, + in_x_norm->data(), functor); + } + } else { + if (out_grad_x) { + CosSimDxFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, + in_y_norm->data(), functor); + } + if (out_grad_y) { + CosSimDyFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_y->mutable_data(context.GetPlace()), cols); + ForEachZip(in_x_norm->data(), in_x_norm->data() + rows_x, + in_y_norm->data(), functor); + } } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 33b7d06467e84056fe0dcb25f0f52bdb2e4519f9..7ebfc7df8c117edd7bcf14cc5ae6ba3dc1302c03 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -131,61 +131,6 @@ class MidWiseTransformIterator { int post_; }; -template -class ElementIterator; - -// Fixed(zcd) : Only support 2D -template -class ElementIterator { - public: - ElementIterator(const T* ptr, int t_m, int t_n, int m, int n) - : ptr_(ptr), - index_(0), - i_(0), - j_(0), - t_m_(t_m), - t_n_(t_n), - m_(m), - n_(n) {} - - ElementIterator& operator++() { - ++j_; - - if ((j_ == n_)) { - j_ = 0; - ++i_; - } - int t_i = (t_m_ == 1) ? 0 : i_; - int t_j = (t_n_ == 1) ? 0 : j_; - index_ = t_i * t_n_ + t_j; - - return *this; - } - - bool operator==( - const ElementIterator& rhs) const { - return (ptr_ + index_) == &(*rhs); - } - - bool operator!=( - const ElementIterator& rhs) const { - return (ptr_ + index_) != &(*rhs); - } - - const T& operator*() { return ptr_[index_]; } - - private: - // t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_m_ == 1) - const T* ptr_; - int index_; - int i_; - int j_; - int64_t t_m_; - int64_t t_n_; - int64_t m_; - int64_t n_; -}; - #ifdef __NVCC__ template class RowwiseTransformIterator