Commit 784740d8 authored by chengduoZH

refine cos-sim-op

Parent: a91efdde
@@ -15,6 +15,7 @@
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/elementwise_add_op.h"
 
 namespace paddle {
 namespace operators {
@@ -27,6 +28,28 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
+template <typename T, typename DeviceContext>
+void Function_forward(T* out, T* x_norm, T* y_norm,
+                      ElementIterator<T, DeviceContext>& x,
+                      ElementIterator<T, DeviceContext>& y, int row, int col) {
+  for (int i = 0; i < row; ++i) {
+    T xx = 0;
+    T yy = 0;
+    T xy = 0;
+    for (int j = 0; j < col; ++j) {
+      xy += (*x) * (*y);
+      xx += (*x) * (*x);
+      yy += (*y) * (*y);
+      ++y;
+      ++x;
+    }
+    x_norm[i] = sqrt(xx);
+    y_norm[i] = sqrt(yy);
+    out[i] = xy / (x_norm[i] * y_norm[i]);
+  }
+}
+
 template <typename DeviceContext, typename T>
 class CosSimKernel : public framework::OpKernel<T> {
  public:
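For reference, Function_forward computes plain row-wise cosine similarity: for each row i, out[i] = <x_i, y_i> / (||x_i|| * ||y_i||). Below is a minimal standalone sketch of the same computation with raw pointers (an illustration only, not part of this commit; cos_sim_rows and its test values are invented for the example):

#include <cmath>
#include <cstdio>

// Row-wise cosine similarity over a pair of (rows, cols) buffers.
void cos_sim_rows(const float* x, const float* y, float* out,
                  float* x_norm, float* y_norm, int rows, int cols) {
  for (int i = 0; i < rows; ++i) {
    float xx = 0, yy = 0, xy = 0;
    for (int j = 0; j < cols; ++j) {
      float xv = x[i * cols + j];
      float yv = y[i * cols + j];
      xy += xv * yv;  // dot product of row i
      xx += xv * xv;  // squared norm of x_i
      yy += yv * yv;  // squared norm of y_i
    }
    x_norm[i] = std::sqrt(xx);
    y_norm[i] = std::sqrt(yy);
    out[i] = xy / (x_norm[i] * y_norm[i]);
  }
}

int main() {
  const float x[6] = {1, 0, 0, 1, 2, 2};
  const float y[6] = {0, 1, 0, 2, 4, 4};
  float out[2], xn[2], yn[2];
  cos_sim_rows(x, y, out, xn, yn, 2, 3);
  std::printf("%f %f\n", out[0], out[1]);  // 0 (orthogonal), 1 (parallel)
  return 0;
}

The ElementIterator arguments in the kernel below exist so that the same loop also covers the case where y has a single row (rows_y == 1) that is broadcast against every row of x.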
@@ -41,32 +64,63 @@ class CosSimKernel : public framework::OpKernel<T> {
     out_x_norm->mutable_data<T>(context.GetPlace());
     out_y_norm->mutable_data<T>(context.GetPlace());
 
-    // convert Tensor to Eigen Tensor
     int rows_x = in_x->dims()[0];
     int rows_y = in_y->dims()[0];
-    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
-    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
-    auto z = EigenVector<T>::Flatten(*out_z);
-    auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
-    auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
-
-    // compute
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto row_along = Eigen::array<int, 1>({{1}});
-    x_norm.device(place) = x.square().sum(row_along).sqrt();
-    y_norm.device(place) = y.square().sum(row_along).sqrt();
-    if (rows_x == rows_y) {
-      auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
-      z.device(place) = xy / x_norm / y_norm;
-    } else {
-      Eigen::DSizes<int, 2> bcast(rows_x, 1);
-      auto xy = (x * y.broadcast(bcast)).sum(row_along);
-      z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
-    }
+    int cols = framework::product(in_x->dims()) / rows_x;
+    auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
+                                                    cols, rows_x, cols);
+    auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
+                                                    cols, rows_x, cols);
+
+    Function_forward(out_z->data<T>(), out_x_norm->data<T>(),
+                     out_y_norm->data<T>(), x_iter, y_iter, rows_x, cols);
+    //
+    //    // convert Tensor to Eigen Tensor
+    ////    int rows_x = in_x->dims()[0];
+    ////    int rows_y = in_y->dims()[0];
+    //    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
+    //    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
+    //    auto z = EigenVector<T>::Flatten(*out_z);
+    //    auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
+    //    auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
+    //
+    //    // compute
+    //    auto& place =
+    //        *context.template device_context<DeviceContext>().eigen_device();
+    //    auto row_along = Eigen::array<int, 1>({{1}});
+    //    x_norm.device(place) = x.square().sum(row_along).sqrt();
+    //    y_norm.device(place) = y.square().sum(row_along).sqrt();
+    //    if (rows_x == rows_y) {
+    //      auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
+    //      z.device(place) = xy / x_norm / y_norm;
+    //    } else {
+    //      Eigen::DSizes<int, 2> bcast(rows_x, 1);
+    //      auto xy = (x * y.broadcast(bcast)).sum(row_along);
+    //      z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
+    //    }
   }
 };
 
+template <typename T, typename DeviceContext>
+void Function_element(T* result, ElementIterator<T, DeviceContext> dz,
+                      ElementIterator<T, DeviceContext> y,
+                      ElementIterator<T, DeviceContext> x_norm,
+                      ElementIterator<T, DeviceContext> y_norm,
+                      ElementIterator<T, DeviceContext> z,
+                      ElementIterator<T, DeviceContext> x, int num, int block) {
+  for (int i = 0; i < num; ++i) {
+    result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) -
+                                  (*z) * (*x) / ((*x_norm) * (*x_norm)));
+    ++dz;
+    ++y;
+    ++x_norm;
+    ++y_norm;
+    ++z;
+    ++x;
+  }
+}
+
 template <typename DeviceContext, typename T>
 class CosSimGradKernel : public framework::OpKernel<T> {
  public:
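The expression accumulated by Function_element is the analytic derivative of the cosine similarity. Written out for one row in my own notation (a sketch, not taken from the commit):

\[
  z = \frac{\sum_k x_k y_k}{\lVert x\rVert\,\lVert y\rVert}
  \quad\Longrightarrow\quad
  \frac{\partial z}{\partial x_j}
  = \frac{y_j}{\lVert x\rVert\,\lVert y\rVert}
    - \frac{x_j \sum_k x_k y_k}{\lVert x\rVert^{3}\,\lVert y\rVert}
  = \frac{y_j}{\lVert x\rVert\,\lVert y\rVert}
    - z\,\frac{x_j}{\lVert x\rVert^{2}}
\]

The last form is exactly (*y) / ((*x_norm) * (*y_norm)) - (*z) * (*x) / ((*x_norm) * (*x_norm)) above, and \(\partial z / \partial y_j\) follows by swapping x and y, which is why CosSimGradKernel below calls Function_element twice with the x- and y-related iterators exchanged.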
@@ -81,63 +135,50 @@ class CosSimGradKernel : public framework::OpKernel<T> {
     auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
     auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
 
-    // convert Tensor to Eigen Tensor
-    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
-    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
-    auto z = EigenMatrix<T>::Reshape(*in_z, 1);
-    auto x_norm = EigenMatrix<T>::Reshape(*in_x_norm, 1);
-    auto y_norm = EigenMatrix<T>::Reshape(*in_y_norm, 1);
-    auto dz = EigenMatrix<T>::Reshape(*in_grad_z, 1);
-
     // compute gradient
     int rows_x = in_x->dims()[0];
     int rows_y = in_y->dims()[0];
     int cols = framework::product(in_x->dims()) / rows_x;
-    Eigen::DSizes<int, 2> bcast_cols(1, cols);
-    auto z_bcast = z.broadcast(bcast_cols);
-    auto dz_bcast = dz.broadcast(bcast_cols);
-    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    if (rows_x == rows_y) {
-      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
-      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
-      // compute dx
-      if (out_grad_x) {
-        out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
-        auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
-        dx.device(place) = dz_bcast * grad;
-      }
-      // compute dy
-      if (out_grad_y) {
-        out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
-        auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
-        dy.device(place) = dz_bcast * grad;
-      }
-    } else {
-      Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
-      Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
-      auto y_bcast = y.broadcast(bcast_rows);
-      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
-      auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows))
-                                 .eval()
-                                 .broadcast(bcast_cols);
-      // compute dx
-      if (out_grad_x) {
-        out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
-        auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
-        dx.device(place) = dz_bcast * grad;
-      }
-      // compute dy
-      if (out_grad_y) {
-        out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenVector<T>::Flatten(*out_grad_y);
-        auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
-        dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
-      }
-    }
+
+    //////////////////////////////
+    // ##
+    auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
+                                                    cols, rows_x, cols);
+    auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
+                                                    cols, rows_x, cols);
+    auto z_iter = ElementIterator<T, DeviceContext>(in_z->data<T>(), rows_x, 1,
+                                                    rows_x, cols);
+    auto dz_iter = ElementIterator<T, DeviceContext>(in_grad_z->data<T>(),
+                                                     rows_x, 1, rows_x, cols);
+    auto x_norm_iter = ElementIterator<T, DeviceContext>(
+        in_x_norm->data<T>(), rows_x, 1, rows_x, cols);
+    auto y_norm_iter = ElementIterator<T, DeviceContext>(
+        in_y_norm->data<T>(), rows_y, 1, rows_x, cols);
+    // ##
+    //////////////////////////////
+    // compute dx
+    if (out_grad_x) {
+      out_grad_x->mutable_data<T>(context.GetPlace());
+
+      //////////////////////////////
+      // ##
+      Function_element(out_grad_x->data<T>(), dz_iter, y_iter, x_norm_iter,
+                       y_norm_iter, z_iter, x_iter, rows_x * cols,
+                       rows_x * cols);
+      // ##
+      //////////////////////////////
+    }
+    // compute dy
+    if (out_grad_y) {
+      out_grad_y->mutable_data<T>(context.GetPlace());
+
+      //////////////////////////////
+      // ##
+      Function_element(out_grad_y->data<T>(), dz_iter, x_iter, y_norm_iter,
+                       x_norm_iter, z_iter, y_iter, rows_x * cols,
+                       rows_y * cols);
+      // ##
+      //////////////////////////////
+    }
   }
 };
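Two properties of this kernel are worth noting. First, Function_element accumulates into the output with +=, so correctness assumes the gradient buffers start out zeroed. Second, when rows_y == 1, the i % block indexing folds every row's contribution into the single broadcast row of d(Y), mirroring the .sum(axis 0) of the removed Eigen branch. The formula itself can be validated with a central-difference check on one coordinate; here is an illustrative sketch (not part of this commit; every name in it is invented):

#include <cmath>
#include <cstdio>

// Cosine similarity of two length-n vectors.
static float cos_sim(const float* x, const float* y, int n) {
  float xx = 0, yy = 0, xy = 0;
  for (int k = 0; k < n; ++k) {
    xy += x[k] * y[k];
    xx += x[k] * x[k];
    yy += y[k] * y[k];
  }
  return xy / (std::sqrt(xx) * std::sqrt(yy));
}

int main() {
  float x[3] = {0.5f, -1.0f, 2.0f};
  const float y[3] = {1.5f, 0.25f, -0.75f};
  const int n = 3, j = 1;  // check d(cos)/dx_j at j = 1

  // Analytic gradient: y_j / (|x||y|) - z * x_j / |x|^2.
  float xx = 0, yy = 0;
  for (int k = 0; k < n; ++k) {
    xx += x[k] * x[k];
    yy += y[k] * y[k];
  }
  const float z = cos_sim(x, y, n);
  const float analytic =
      y[j] / (std::sqrt(xx) * std::sqrt(yy)) - z * x[j] / xx;

  // Numeric gradient: central difference with a small perturbation.
  const float eps = 1e-3f;
  const float saved = x[j];
  x[j] = saved + eps;
  const float zp = cos_sim(x, y, n);
  x[j] = saved - eps;
  const float zm = cos_sim(x, y, n);
  x[j] = saved;
  const float numeric = (zp - zm) / (2 * eps);

  std::printf("analytic=%g numeric=%g\n", analytic, numeric);
  return 0;
}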
@@ -131,6 +131,61 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
   int post_;
 };
 
+template <typename T, typename Place>
+class ElementIterator;
+
+// Fixed(zcd): only supports 2D
+template <typename T>
+class ElementIterator<T, platform::CPUDeviceContext> {
+ public:
+  ElementIterator(const T* ptr, int t_m, int t_n, int m, int n)
+      : ptr_(ptr),
+        index_(0),
+        i_(0),
+        j_(0),
+        t_m_(t_m),
+        t_n_(t_n),
+        m_(m),
+        n_(n) {}
+
+  ElementIterator<T, platform::CPUDeviceContext>& operator++() {
+    ++j_;
+    if (j_ == n_) {
+      j_ = 0;
+      ++i_;
+    }
+    int t_i = (t_m_ == 1) ? 0 : i_;
+    int t_j = (t_n_ == 1) ? 0 : j_;
+    index_ = t_i * t_n_ + t_j;
+    return *this;
+  }
+
+  bool operator==(
+      const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
+    return (ptr_ + index_) == &(*rhs);
+  }
+
+  bool operator!=(
+      const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
+    return (ptr_ + index_) != &(*rhs);
+  }
+
+  const T& operator*() const { return ptr_[index_]; }
+
+ private:
+  // invariant: t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_n_ == 1)
+  const T* ptr_;
+  int index_;
+  int i_;
+  int j_;
+  int64_t t_m_;
+  int64_t t_n_;
+  int64_t m_;
+  int64_t n_;
+};
+
 #ifdef __NVCC__
 template <typename T>
 class RowwiseTransformIterator<T, platform::CUDADeviceContext>
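ElementIterator walks a logical (m, n) iteration space over a tensor whose real shape is (t_m, t_n); whenever a real dimension is 1, the computed index collapses that axis (t_i or t_j becomes 0), so the single row or column is revisited on every pass. A standalone sketch of the same indexing scheme (illustration only; BroadcastIter is an invented stand-in for the Place-templated class):

#include <cstdio>

// Mirrors ElementIterator's 2-D broadcast indexing, minus the Place template.
struct BroadcastIter {
  const float* ptr;  // tensor data of real shape (t_m, t_n)
  int i = 0, j = 0;  // position in the logical (m, n) space
  int index = 0;     // resolved offset into ptr
  int t_m, t_n, m, n;
  BroadcastIter(const float* p, int tm, int tn, int mm, int nn)
      : ptr(p), t_m(tm), t_n(tn), m(mm), n(nn) {}
  BroadcastIter& operator++() {
    ++j;
    if (j == n) {
      j = 0;
      ++i;
    }
    int t_i = (t_m == 1) ? 0 : i;  // a size-1 axis repeats its only entry
    int t_j = (t_n == 1) ? 0 : j;
    index = t_i * t_n + t_j;
    return *this;
  }
  float operator*() const { return ptr[index]; }
};

int main() {
  // A real (1, 3) row broadcast over a logical (2, 3) space:
  const float row[3] = {10, 20, 30};
  BroadcastIter it(row, /*t_m=*/1, /*t_n=*/3, /*m=*/2, /*n=*/3);
  for (int k = 0; k < 2 * 3; ++k, ++it) std::printf("%g ", *it);
  std::printf("\n");  // prints: 10 20 30 10 20 30
  return 0;
}

This is how the cos-sim gradient above reuses one helper for both broadcast and non-broadcast inputs: x_norm and z are passed as (rows_x, 1) tensors iterated over (rows_x, cols), so each per-row scalar is repeated cols times.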