提交 bcf0b56f 编写于 作者: C chengduoZH

refine iterator

上级 784740d8
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_add_op.h" #include "paddle/operators/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,27 +28,73 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,27 +28,73 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, typename DeviceContext> template <typename IT1, typename IT2, typename Callback>
void Function_forward(T* out, T* x_norm, T* y_norm, static void ForEachZip(IT1 begin1, IT1 last1, IT2 begin2, Callback callback) {
ElementIterator<T, DeviceContext>& x, // This method could be implemented in CUDA
ElementIterator<T, DeviceContext>& y, int row, int col) { for (; begin1 < last1; ++begin1, ++begin2) {
for (int i = 0; i < row; ++i) { callback(*begin1, *begin2);
T xx = 0;
T yy = 0;
T xy = 0;
for (int j = 0; j < col; ++j) {
xy += (*x) * (*y);
xx += (*x) * (*x);
yy += (*y) * (*y);
++y;
++x;
} }
x_norm[i] = sqrt(xx); }
y_norm[i] = sqrt(yy);
template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}
out[i] = xy / (x_norm[i] * y_norm[i]); inline void operator()(T& x_norm, T& y_norm) const {
size_t x_offset = &x_norm - x_norm_;
size_t y_offset = &y_norm - y_norm_;
auto* x = x_ + cols_ * x_offset;
T xx = 0, xy = 0;
T yy = 0;
if (same_row) {
auto* y = y_ + cols_ * y_offset;
for (size_t i = 0; i < cols_; ++i) {
xx += x[i] * x[i];
yy += y[i] * y[i];
xy += x[i] * y[i];
}
xx = sqrt(xx);
yy = sqrt(yy);
x_norm_[x_offset] = xx;
y_norm_[y_offset] = yy;
z_[x_offset] = xy / (xx * yy);
} else {
auto* y = y_;
// if (yy == -1) {
// yy = 0;
// for (size_t i = 0; i < cols_; ++i) {
// yy += y[i] * y[i];
// }
// y_norm[0] = sqrt(yy);
// }
for (size_t i = 0; i < cols_; ++i) {
xx += x[i] * x[i];
yy += y[i] * y[i]; // only need
xy += x[i] * y[i];
}
xx = sqrt(xx);
yy = sqrt(yy);
x_norm_[x_offset] = xx;
y_norm_[0] = yy;
z_[x_offset] = xy / (xx * yy);
}
} }
}
T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> { class CosSimKernel : public framework::OpKernel<T> {
...@@ -68,58 +114,140 @@ class CosSimKernel : public framework::OpKernel<T> { ...@@ -68,58 +114,140 @@ class CosSimKernel : public framework::OpKernel<T> {
int rows_y = in_y->dims()[0]; int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x; int cols = framework::product(in_x->dims()) / rows_x;
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
cols, rows_x, cols); if (rows_x == rows_y) {
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y, CosSimFunctor<T, true> functor(
cols, rows_x, cols); in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
Function_forward(out_z->data<T>(), out_x_norm->data<T>(), ForEachZip(out_x_norm->data<T>(), out_x_norm->data<T>() + rows_x,
out_y_norm->data<T>(), x_iter, y_iter, rows_x, cols); out_y_norm->data<T>(), functor);
// } else {
// // convert Tensor to Eigen Tensor CosSimFunctor<T, false> functor(
//// int rows_x = in_x->dims()[0]; in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
//// int rows_y = in_y->dims()[0]; out_y_norm->data<T>(), out_z->data<T>(), cols);
// auto x = EigenMatrix<T>::Reshape(*in_x, 1); ForEachZip(out_x_norm->data<T>(), out_x_norm->data<T>() + rows_x,
// auto y = EigenMatrix<T>::Reshape(*in_y, 1); out_y_norm->data<T>(), functor);
// auto z = EigenVector<T>::Flatten(*out_z); }
// auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
// auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
//
// // compute
// auto& place =
// *context.template device_context<DeviceContext>().eigen_device();
// auto row_along = Eigen::array<int, 1>({{1}});
// x_norm.device(place) = x.square().sum(row_along).sqrt();
// y_norm.device(place) = y.square().sum(row_along).sqrt();
// if (rows_x == rows_y) {
// auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
// z.device(place) = xy / x_norm / y_norm;
// } else {
// Eigen::DSizes<int, 2> bcast(rows_x, 1);
// auto xy = (x * y.broadcast(bcast)).sum(row_along);
// z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
// }
} }
}; };
template <typename T, typename DeviceContext> template <typename T>
void Function_element(T* result, ElementIterator<T, DeviceContext> dz, struct CosSimGradFunctor {
ElementIterator<T, DeviceContext> y, CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
ElementIterator<T, DeviceContext> x_norm, const T* z, const T* dz, T* dx, int cols)
ElementIterator<T, DeviceContext> y_norm, : x_norm_(x_norm),
ElementIterator<T, DeviceContext> z, y_norm_(y_norm),
ElementIterator<T, DeviceContext> x, int num, int block) { x_(x),
for (int i = 0; i < num; ++i) { y_(y),
result[i % block] += (*dz) * ((*y) / ((*x_norm) * (*y_norm)) - z_(z),
(*z) * (*x) / ((*x_norm) * (*x_norm))); dz_(dz),
++dz; dx_(dx),
++y; cols_(static_cast<size_t>(cols)) {}
++x_norm;
++y_norm; void operator()(const T& x_norm, const T& y_norm) const {
++z; size_t x_offset = &x_norm - x_norm_;
++x; size_t y_offset = &y_norm - y_norm_;
auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
// auto y_norm_square = y_norm_[y_offset] * y_norm_[y_offset];
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[y_offset];
auto dz = dz_[x_offset];
auto* dx = dx_ + cols_ * x_offset;
auto* x = x_ + cols_ * x_offset;
auto* y = y_ + cols_ * y_offset;
auto z = z_[x_offset];
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] / xy_norm_prod - z * x[i] / x_norm_square);
} }
} }
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
void operator()(const T& x_norm, const T& y_norm) const {
size_t x_offset = &x_norm - x_norm_;
auto x_norm_square = x_norm_[x_offset] * x_norm_[x_offset];
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
auto dz = dz_[x_offset];
auto z = z_[x_offset];
auto* dx = dx_ + cols_ * x_offset;
auto* x = x_ + cols_ * x_offset;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] / xy_norm_prod - z * x[i] / x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDyFunctor {
CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dy, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dy_(dy),
cols_(static_cast<size_t>(cols)) {}
void operator()(const T& x_norm, const T& y_norm) const {
size_t x_offset = &x_norm - x_norm_;
auto y_norm_square = y_norm_[0] * y_norm_[0];
auto xy_norm_prod = x_norm_[x_offset] * y_norm_[0];
auto dz = dz_[x_offset];
auto z = z_[x_offset];
auto* x = x_ + cols_ * x_offset;
for (size_t i = 0; i < cols_; ++i) {
dy_[i] += dz * (x[i] / xy_norm_prod - z * y_[i] / y_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dy_;
const size_t cols_;
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CosSimGradKernel : public framework::OpKernel<T> { class CosSimGradKernel : public framework::OpKernel<T> {
...@@ -140,45 +268,40 @@ class CosSimGradKernel : public framework::OpKernel<T> { ...@@ -140,45 +268,40 @@ class CosSimGradKernel : public framework::OpKernel<T> {
int rows_y = in_y->dims()[0]; int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x; int cols = framework::product(in_x->dims()) / rows_x;
////////////////////////////// if (rows_x == rows_y) {
// ##
auto x_iter = ElementIterator<T, DeviceContext>(in_x->data<T>(), rows_x,
cols, rows_x, cols);
auto y_iter = ElementIterator<T, DeviceContext>(in_y->data<T>(), rows_y,
cols, rows_x, cols);
auto z_iter = ElementIterator<T, DeviceContext>(in_z->data<T>(), rows_x, 1,
rows_x, cols);
auto dz_iter = ElementIterator<T, DeviceContext>(in_grad_z->data<T>(),
rows_x, 1, rows_x, cols);
auto x_norm_iter = ElementIterator<T, DeviceContext>(
in_x_norm->data<T>(), rows_x, 1, rows_x, cols);
auto y_norm_iter = ElementIterator<T, DeviceContext>(
in_y_norm->data<T>(), rows_y, 1, rows_x, cols);
// ##
//////////////////////////////
// compute dx
if (out_grad_x) { if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace()); CosSimGradFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
////////////////////////////// in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
// ## out_grad_x->mutable_data<T>(context.GetPlace()), cols);
Function_element(out_grad_x->data<T>(), dz_iter, y_iter, x_norm_iter, ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
y_norm_iter, z_iter, x_iter, rows_x * cols, in_y_norm->data<T>(), functor);
rows_x * cols); }
// ## if (out_grad_y) {
////////////////////////////// CosSimGradFunctor<T> functor(
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
ForEachZip(in_y_norm->data<T>(), in_y_norm->data<T>() + rows_x,
in_x_norm->data<T>(), functor);
}
} else {
if (out_grad_x) {
CosSimDxFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
in_y_norm->data<T>(), functor);
} }
// compute dy
if (out_grad_y) { if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace()); CosSimDyFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
////////////////////////////// in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
// ## out_grad_y->mutable_data<T>(context.GetPlace()), cols);
Function_element(out_grad_y->data<T>(), dz_iter, x_iter, y_norm_iter, ForEachZip(in_x_norm->data<T>(), in_x_norm->data<T>() + rows_x,
x_norm_iter, z_iter, y_iter, rows_x * cols, in_y_norm->data<T>(), functor);
rows_y * cols); }
// ##
//////////////////////////////
} }
} }
}; };
......
...@@ -131,61 +131,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> { ...@@ -131,61 +131,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
int post_; int post_;
}; };
template <typename T, typename Place>
class ElementIterator;
// Fixed(zcd) : Only support 2D
template <typename T>
class ElementIterator<T, platform::CPUDeviceContext> {
public:
ElementIterator(const T* ptr, int t_m, int t_n, int m, int n)
: ptr_(ptr),
index_(0),
i_(0),
j_(0),
t_m_(t_m),
t_n_(t_n),
m_(m),
n_(n) {}
ElementIterator<T, platform::CPUDeviceContext>& operator++() {
++j_;
if ((j_ == n_)) {
j_ = 0;
++i_;
}
int t_i = (t_m_ == 1) ? 0 : i_;
int t_j = (t_n_ == 1) ? 0 : j_;
index_ = t_i * t_n_ + t_j;
return *this;
}
bool operator==(
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
return (ptr_ + index_) == &(*rhs);
}
bool operator!=(
const ElementIterator<T, platform::CPUDeviceContext>& rhs) const {
return (ptr_ + index_) != &(*rhs);
}
const T& operator*() { return ptr_[index_]; }
private:
// t_m_ == m_ || t_n_ == n_ || (t_m_ == 1 && t_m_ == 1)
const T* ptr_;
int index_;
int i_;
int j_;
int64_t t_m_;
int64_t t_n_;
int64_t m_;
int64_t n_;
};
#ifdef __NVCC__ #ifdef __NVCC__
template <typename T> template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext> class RowwiseTransformIterator<T, platform::CUDADeviceContext>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册