提交 de26ae41 编写于 作者: C chengduoZH

add gpu code

上级 4f5e3d0d
......@@ -151,42 +151,26 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dy, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dy_(dy),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t offset) const {
auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
auto dz = dz_[offset];
auto z = z_[offset];
auto* x = x_ + cols_ * offset;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm_[0] * y_norm_[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dy_[i] += dz * (x[i] * reciprocal_xy_norm_prod -
z * y_[i] * reciprocal_y_norm_square);
inline void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t offset = 0; offset < rows; ++offset) {
auto xy_norm_prod = x_norm[offset] * y_norm[0];
auto dz_data = dz[offset];
auto z_data = z[offset];
auto* x_data = x + cols * offset;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dy_;
const size_t cols_;
};
} // namespace operators
} // namespace paddle
......
......@@ -20,45 +20,45 @@ namespace paddle {
namespace operators {
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dy, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dy_(dy),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t offset) const {
auto xy_norm_prod = x_norm_[offset] * y_norm_[0];
auto dz = dz_[offset];
auto z = z_[offset];
auto* x = x_ + cols_ * offset;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int offset = blockIdx.x * blockDim.x + threadIdx.x; offset < rows;
offset += grid_size) {
T xy_norm_prod = x_norm[offset] * y_norm_data;
T dz_data = dz[offset];
T z_data = z[offset];
const T* x_data = x + cols * offset;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm_[0] * y_norm_[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols_; ++i) {
T dy = dz * (x[i] * reciprocal_xy_norm_prod -
z * y_[i] * reciprocal_y_norm_square);
// platform::CudaAtomicAdd(dy_ + i, dy);
dy_[i] += dy;
T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dy_;
const size_t cols_;
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
inline void operator()(const platform::CUDADeviceContext& ctx,
const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz, const size_t rows,
const size_t cols, T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid(1, (rows + block_size - 1) / block_size);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};
template struct CosSimDyFunctor<platform::CUDADeviceContext, float>;
} // namespace operators
} // namespace paddle
......
......@@ -193,9 +193,10 @@ struct CosSimDxFunctor {
template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
CosSimDyFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dy, int cols);
inline HOSTDEVICE void operator()(size_t) const;
inline void operator()(const DeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const;
};
template <typename DeviceContext, typename T>
......@@ -255,14 +256,11 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
CosSimDyFunctor<DeviceContext, T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
CosSimDyFunctor<DeviceContext, T> functor;
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
static_cast<size_t>(cols), out_grad_y->data<T>());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册