未验证 提交 f58fe6d3 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #6601 from chengduoZH/profiling/cosine_op

Refine cos-sim-op
......@@ -229,6 +229,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(cos_sim_op DEPS cos_sim_functor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
......
......@@ -13,19 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
......@@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel<T> {
out_x_norm->mutable_data<T>(context.GetPlace());
out_y_norm->mutable_data<T>(context.GetPlace());
// convert Tensor to Eigen Tensor
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenVector<T>::Flatten(*out_z);
auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
// compute
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto row_along = Eigen::array<int, 1>({{1}});
x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt();
int cols = framework::product(in_x->dims()) / rows_x;
if (rows_x == rows_y) {
auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
z.device(place) = xy / x_norm / y_norm;
math::CosSimFunctor<T, true> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
} else {
Eigen::DSizes<int, 2> bcast(rows_x, 1);
auto xy = (x * y.broadcast(bcast)).sum(row_along);
z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
math::CosSimFunctor<T, false> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
}
}
};
......@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
// convert Tensor to Eigen Tensor
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenMatrix<T>::Reshape(*in_z, 1);
auto x_norm = EigenMatrix<T>::Reshape(*in_x_norm, 1);
auto y_norm = EigenMatrix<T>::Reshape(*in_y_norm, 1);
auto dz = EigenMatrix<T>::Reshape(*in_grad_z, 1);
// compute gradident
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x;
Eigen::DSizes<int, 2> bcast_cols(1, cols);
auto z_bcast = z.broadcast(bcast_cols);
auto dz_bcast = dz.broadcast(bcast_cols);
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
if (rows_x == rows_y) {
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
math::CosSimGradFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
dy.device(place) = dz_bcast * grad;
math::CosSimGradFunctor<T> functor(
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
} else {
Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
auto y_bcast = y.broadcast(bcast_rows);
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows))
.eval()
.broadcast(bcast_cols);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
math::CosSimDxFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenVector<T>::Flatten(*out_grad_y);
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
math::CosSimDyFunctor<DeviceContext, T> functor;
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
static_cast<size_t>(cols), out_grad_y->data<T>());
}
}
}
......
......@@ -16,6 +16,7 @@ if(WITH_GPU)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -30,6 +31,7 @@ else()
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t row_id = 0; row_id < rows; ++row_id) {
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
auto dz_data = dz[row_id];
auto z_data = z[row_id];
auto* x_data = x + cols * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
};
template class CosSimDyFunctor<platform::CPUDeviceContext, float>;
template class CosSimDyFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows;
row_id += grid_size) {
T xy_norm_prod = x_norm[row_id] * y_norm_data;
T dz_data = dz[row_id];
T z_data = z[row_id];
const T* x_data = x + cols * row_id;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid(1, (rows + block_size - 1) / block_size);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};
template class CosSimDyFunctor<platform::CUDADeviceContext, float>;
template class CosSimDyFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <stdlib.h>
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[row_id] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
} else { // This can be wrote in a better way.
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y_[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
}
}
T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
template <typename T>
struct CosSimGradFunctor {
CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto* x = x_ + cols_ * row_id;
auto* y = y_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* x = x_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm,
const T* x, const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) const;
};
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册