提交 784740d8 编写于 作者: C chengduoZH

refine cos-sim-op

上级 a91efdde
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_add_op.h"
namespace paddle {
namespace operators {
......@@ -27,6 +28,28 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, typename DeviceContext>
void Function_forward(T* out, T* x_norm, T* y_norm,
ElementIterator<T, DeviceContext>& x,
ElementIterator<T, DeviceContext>& y, int row, int col) {
for (int i = 0; i < row; ++i) {
T xx = 0;
T yy = 0;
T xy = 0;
for (int j = 0; j < col; ++j) {
xy += (*x) * (*y);
xx += (*x) * (*x);
yy += (*y) * (*y);
++y;
++x;
}
x_norm[i] = sqrt(xx);
y_norm[i] = sqrt(yy);
out[i] = xy / (x_norm[i] * y_norm[i]);
}
}
template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
public:
......@@ -41,32 +64,63 @@ class CosSimKernel : public framework::OpKernel<T> {
out_x_norm->mutable_data<T>(context.GetPlace());
out_y_norm->mutable_data<T>(context.GetPlace());
// convert Tensor to Eigen Tensor
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenVector<T>::Flatten(*out_z);
auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
// compute
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto row_along = Eigen::array<int, 1>({{1}});
x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt();
if (rows_x == rows_y) {
auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
z.device(place) = xy / x_norm / y_norm;
} else {
Eigen::DSizes<int, 2> bcast(rows_x, 1);
auto xy = (x * y.broadcast(bcast)).sum(row_along);
z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
}
int cols = framework::product(in_x->dims()) / rows_x;
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
cols, rows_x, cols);
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
cols, rows_x, cols);
Function_forward(out_z->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), x_iter, y_iter, rows_x, cols);
//
// // convert Tensor to Eigen Tensor
//// int rows_x = in_x->dims()[0];
//// int rows_y = in_y->dims()[0];
// auto x = EigenMatrix<T>::Reshape(*in_x, 1);
// auto y = EigenMatrix<T>::Reshape(*in_y, 1);
// auto z = EigenVector<T>::Flatten(*out_z);
// auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
// auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
//
// // compute
// auto& place =
// *context.template device_context<DeviceContext>().eigen_device();
// auto row_along = Eigen::array<int, 1>({{1}});
// x_norm.device(place) = x.square().sum(row_along).sqrt();
// y_norm.device(place) = y.square().sum(row_along).sqrt();
// if (rows_x == rows_y) {
// auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
// z.device(place) = xy / x_norm / y_norm;
// } else {
// Eigen::DSizes<int, 2> bcast(rows_x, 1);
// auto xy = (x * y.broadcast(bcast)).sum(row_along);
// z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
// }
}
};
template <typename T, typename DeviceContext>
void Function_element(T* result, ElementIterator<T, DeviceContext> dz,
ElementIterator<T, DeviceContext> y,
ElementIterator<T, DeviceContext> x_norm,
ElementIterator<T, DeviceContext> y_norm,
ElementIterator<T, DeviceContext> z,
ElementIterator<T, DeviceContext> x, int num, int block) {
for (int i = 0; i < num; ++i) {
result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) -
(*z) * (*x) / ((*x_norm) * (*x_norm)));
++dz;
++y;
++x_norm;
++y_norm;
++z;
++x;
}
}
template <typename DeviceContext, typename T>
class CosSimGradKernel : public framework::OpKernel<T> {
public:
......@@ -81,63 +135,50 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
// convert Tensor to Eigen Tensor
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenMatrix<T>::Reshape(*in_z, 1);
auto x_norm = EigenMatrix<T>::Reshape(*in_x_norm, 1);
auto y_norm = EigenMatrix<T>::Reshape(*in_y_norm, 1);
auto dz = EigenMatrix<T>::Reshape(*in_grad_z, 1);
// compute gradident
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x;
Eigen::DSizes<int, 2> bcast_cols(1, cols);
auto z_bcast = z.broadcast(bcast_cols);
auto dz_bcast = dz.broadcast(bcast_cols);
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
if (rows_x == rows_y) {
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
dy.device(place) = dz_bcast * grad;
}
} else {
Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
auto y_bcast = y.broadcast(bcast_rows);
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows))
.eval()
.broadcast(bcast_cols);
//////////////////////////////
// ##
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
cols, rows_x, cols);
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
cols, rows_x, cols);
auto z_iter = ElementIterator<T, DeviceContext>(in_z->data<T>(), rows_x, 1,
rows_x, cols);
auto dz_iter = ElementIterator<T, DeviceContext>(in_grad_z->data<T>(),
rows_x, 1, rows_x, cols);
auto x_norm_iter = ElementIterator<T, DeviceContext>(
in_x_norm->data<T>(), rows_x, 1, rows_x, cols);
auto y_norm_iter = ElementIterator<T, DeviceContext>(
in_y_norm->data<T>(), rows_y, 1, rows_x, cols);
// ##
//////////////////////////////
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
//////////////////////////////
// ##
Function_element(out_grad_x->data<T>(), dz_iter, y_iter, x_norm_iter,
y_norm_iter, z_iter, x_iter, rows_x * cols,
rows_x * cols);
// ##
//////////////////////////////
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenVector<T>::Flatten(*out_grad_y);
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
}
//////////////////////////////
// ##
Function_element(out_grad_y->data<T>(), dz_iter, x_iter, y_norm_iter,
x_norm_iter, z_iter, y_iter, rows_x * cols,
rows_y * cols);
// ##
//////////////////////////////
}
}
};
......
......@@ -131,6 +131,61 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
int post_;
};
template <typename T, typename Place>
class ElementIterator;
// Fixed(zcd) : Only support 2D
template <typename T>
class ElementIterator<T, platform::CPUDeviceContext> {
public:
ElementIterator(const T* ptr, int t_m, int t_n, int m, int n)
: ptr_(ptr),
index_(0),
i_(0),
j_(0),
t_m_(t_m),
t_n_(t_n),
m_(m),
n_(n) {}
ElementIterator<T, platform::CPUDeviceContext>& operator++() {
++j_;
if ((j_ == n_)) {
j_ = 0;
++i_;
}
int t_i = (t_m_ == 1) ? 0 : i_;
int t_j = (t_n_ == 1) ? 0 : j_;
index_ = t_i * t_n_ + t_j;
return *this;
}
bool operator==(
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
return (ptr_ + index_) == &(*rhs);
}
bool operator!=(
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
return (ptr_ + index_) != &(*rhs);
}
const T& operator*() { return ptr_[index_]; }
private:
// t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_m_ == 1)
const T* ptr_;
int index_;
int i_;
int j_;
int64_t t_m_;
int64_t t_n_;
int64_t m_;
int64_t n_;
};
#ifdef __NVCC__
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册