diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index fecb5a79b2397dd73d991a1a87efcf84d60ef882..3a7e67506d29f72ca3f73f6b2e5278324e2ce3c2 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/elementwise_add_op.h" namespace paddle { namespace operators { @@ -27,6 +28,28 @@ template using EigenVector = framework::EigenVector; +template +void Function_forward(T* out, T* x_norm, T* y_norm, + ElementIterator& x, + ElementIterator& y, int row, int col) { + for (int i = 0; i < row; ++i) { + T xx = 0; + T yy = 0; + T xy = 0; + for (int j = 0; j < col; ++j) { + xy += (*x) * (*y); + xx += (*x) * (*x); + yy += (*y) * (*y); + ++y; + ++x; + } + x_norm[i] = sqrt(xx); + y_norm[i] = sqrt(yy); + + out[i] = xy / (x_norm[i] * y_norm[i]); + } +} + template class CosSimKernel : public framework::OpKernel { public: @@ -41,32 +64,63 @@ class CosSimKernel : public framework::OpKernel { out_x_norm->mutable_data(context.GetPlace()); out_y_norm->mutable_data(context.GetPlace()); - // convert Tensor to Eigen Tensor int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenVector::Flatten(*out_z); - auto x_norm = EigenVector::Flatten(*out_x_norm); - auto y_norm = EigenVector::Flatten(*out_y_norm); - - // compute - auto& place = - *context.template device_context().eigen_device(); - auto row_along = Eigen::array({{1}}); - x_norm.device(place) = x.square().sum(row_along).sqrt(); - y_norm.device(place) = y.square().sum(row_along).sqrt(); - if (rows_x == rows_y) { - auto xy = (x * y).sum(Eigen::array({{1}})); - z.device(place) = xy / x_norm / y_norm; - } else { - Eigen::DSizes bcast(rows_x, 1); - auto xy = (x * y.broadcast(bcast)).sum(row_along); - z.device(place) = xy / x_norm / y_norm.broadcast(bcast); - } + + int cols = framework::product(in_x->dims()) / rows_x; + auto x_iter = ElementIterator(in_x->data(), rows_x, + cols, rows_x, cols); + auto y_iter = ElementIterator(in_y->data(), rows_y, + cols, rows_x, cols); + + Function_forward(out_z->data(), out_x_norm->data(), + out_y_norm->data(), x_iter, y_iter, rows_x, cols); + // + // // convert Tensor to Eigen Tensor + //// int rows_x = in_x->dims()[0]; + //// int rows_y = in_y->dims()[0]; + // auto x = EigenMatrix::Reshape(*in_x, 1); + // auto y = EigenMatrix::Reshape(*in_y, 1); + // auto z = EigenVector::Flatten(*out_z); + // auto x_norm = EigenVector::Flatten(*out_x_norm); + // auto y_norm = EigenVector::Flatten(*out_y_norm); + // + // // compute + // auto& place = + // *context.template device_context().eigen_device(); + // auto row_along = Eigen::array({{1}}); + // x_norm.device(place) = x.square().sum(row_along).sqrt(); + // y_norm.device(place) = y.square().sum(row_along).sqrt(); + // if (rows_x == rows_y) { + // auto xy = (x * y).sum(Eigen::array({{1}})); + // z.device(place) = xy / x_norm / y_norm; + // } else { + // Eigen::DSizes bcast(rows_x, 1); + // auto xy = (x * y.broadcast(bcast)).sum(row_along); + // z.device(place) = xy / x_norm / y_norm.broadcast(bcast); + // } } }; +template +void Function_element(T* result, ElementIterator dz, + ElementIterator y, + ElementIterator x_norm, + ElementIterator y_norm, + ElementIterator z, + ElementIterator x, int num, int block) { + for (int i = 0; i < num; ++i) { + result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) - + (*z) * (*x) / ((*x_norm) * (*x_norm))); + ++dz; + ++y; + ++x_norm; + ++y_norm; + ++z; + ++x; + } +} + template class CosSimGradKernel : public framework::OpKernel { public: @@ -81,63 +135,50 @@ class CosSimGradKernel : public framework::OpKernel { auto* out_grad_y = context.Output(framework::GradVarName("Y")); auto* in_grad_z = context.Input(framework::GradVarName("Out")); - // convert Tensor to Eigen Tensor - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenMatrix::Reshape(*in_z, 1); - auto x_norm = EigenMatrix::Reshape(*in_x_norm, 1); - auto y_norm = EigenMatrix::Reshape(*in_y_norm, 1); - auto dz = EigenMatrix::Reshape(*in_grad_z, 1); - // compute gradident int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - Eigen::DSizes bcast_cols(1, cols); - auto z_bcast = z.broadcast(bcast_cols); - auto dz_bcast = dz.broadcast(bcast_cols); - auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); - auto& place = - *context.template device_context().eigen_device(); - if (rows_x == rows_y) { - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); - auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); - // compute dx - if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; - } - // compute dy - if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenMatrix::Reshape(*out_grad_y, 1); - auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast; - dy.device(place) = dz_bcast * grad; - } - } else { - Eigen::DSizes bcast_rows(rows_x, 1); - Eigen::DSizes bcast_rows_cols(rows_x, cols); - auto y_bcast = y.broadcast(bcast_rows); - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols); - auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows)) - .eval() - .broadcast(bcast_cols); - // compute dx - if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; - } - // compute dy - if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenVector::Flatten(*out_grad_y); - auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast; - dy.device(place) = (dz_bcast * grad).sum(Eigen::array({{0}})); - } + + ////////////////////////////// + // ## + auto x_iter = ElementIterator(in_x->data(), rows_x, + cols, rows_x, cols); + auto y_iter = ElementIterator(in_y->data(), rows_y, + cols, rows_x, cols); + auto z_iter = ElementIterator(in_z->data(), rows_x, 1, + rows_x, cols); + auto dz_iter = ElementIterator(in_grad_z->data(), + rows_x, 1, rows_x, cols); + auto x_norm_iter = ElementIterator( + in_x_norm->data(), rows_x, 1, rows_x, cols); + auto y_norm_iter = ElementIterator( + in_y_norm->data(), rows_y, 1, rows_x, cols); + // ## + ////////////////////////////// + // compute dx + if (out_grad_x) { + out_grad_x->mutable_data(context.GetPlace()); + + ////////////////////////////// + // ## + Function_element(out_grad_x->data(), dz_iter, y_iter, x_norm_iter, + y_norm_iter, z_iter, x_iter, rows_x * cols, + rows_x * cols); + // ## + ////////////////////////////// + } + // compute dy + if (out_grad_y) { + out_grad_y->mutable_data(context.GetPlace()); + + ////////////////////////////// + // ## + Function_element(out_grad_y->data(), dz_iter, x_iter, y_norm_iter, + x_norm_iter, z_iter, y_iter, rows_x * cols, + rows_y * cols); + // ## + ////////////////////////////// } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 7ebfc7df8c117edd7bcf14cc5ae6ba3dc1302c03..33b7d06467e84056fe0dcb25f0f52bdb2e4519f9 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -131,6 +131,61 @@ class MidWiseTransformIterator { int post_; }; +template +class ElementIterator; + +// Fixed(zcd) : Only support 2D +template +class ElementIterator { + public: + ElementIterator(const T* ptr, int t_m, int t_n, int m, int n) + : ptr_(ptr), + index_(0), + i_(0), + j_(0), + t_m_(t_m), + t_n_(t_n), + m_(m), + n_(n) {} + + ElementIterator& operator++() { + ++j_; + + if ((j_ == n_)) { + j_ = 0; + ++i_; + } + int t_i = (t_m_ == 1) ? 0 : i_; + int t_j = (t_n_ == 1) ? 0 : j_; + index_ = t_i * t_n_ + t_j; + + return *this; + } + + bool operator==( + const ElementIterator& rhs) const { + return (ptr_ + index_) == &(*rhs); + } + + bool operator!=( + const ElementIterator& rhs) const { + return (ptr_ + index_) != &(*rhs); + } + + const T& operator*() { return ptr_[index_]; } + + private: + // t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_m_ == 1) + const T* ptr_; + int index_; + int i_; + int j_; + int64_t t_m_; + int64_t t_n_; + int64_t m_; + int64_t n_; +}; + #ifdef __NVCC__ template class RowwiseTransformIterator