提交 24cf2fcd 编写于 作者: C chengduoZH

move cos_sim_functor to math

上级 4a11fdb4
......@@ -210,7 +210,8 @@ set(DEPS_OPS
save_op
load_op
send_op
recv_op)
recv_op
cos_sim_op)
if(WITH_DISTRIBUTE)
add_subdirectory(detail)
......@@ -256,6 +257,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
op_library(cos_sim_op DEPS cos_sim_functor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
......
......@@ -149,28 +149,6 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
}
};
template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
inline void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t row_id = 0; row_id < rows; ++row_id) {
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
auto dz_data = dz[row_id];
auto z_data = z[row_id];
auto* x_data = x + cols * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
};
} // namespace operators
} // namespace paddle
......
......@@ -14,51 +14,6 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/cos_sim_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows;
row_id += grid_size) {
T xy_norm_prod = x_norm[row_id] * y_norm_data;
T dz_data = dz[row_id];
T z_data = z[row_id];
const T* x_data = x + cols * row_id;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
inline void operator()(const platform::CUDADeviceContext& ctx,
const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz, const size_t rows,
const size_t cols, T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid(1, (rows + block_size - 1) / block_size);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/for_range.h"
......@@ -22,59 +23,6 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[row_id] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
} else { // This can be wrote in a better way.
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y_[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
}
}
T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
public:
......@@ -95,14 +43,14 @@ class CosSimKernel : public framework::OpKernel<T> {
int cols = framework::product(in_x->dims()) / rows_x;
if (rows_x == rows_y) {
CosSimFunctor<T, true> functor(
math::CosSimFunctor<T, true> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
} else {
CosSimFunctor<T, false> functor(
math::CosSimFunctor<T, false> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
......@@ -112,93 +60,6 @@ class CosSimKernel : public framework::OpKernel<T> {
}
};
template <typename T>
struct CosSimGradFunctor {
CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto* x = x_ + cols_ * row_id;
auto* y = y_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* x = x_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
inline void operator()(const DeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const;
};
template <typename DeviceContext, typename T>
class CosSimGradKernel : public framework::OpKernel<T> {
public:
......@@ -220,7 +81,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
if (rows_x == rows_y) {
if (out_grad_x) {
CosSimGradFunctor<T> functor(
math::CosSimGradFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
......@@ -230,7 +91,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
for_range(functor);
}
if (out_grad_y) {
CosSimGradFunctor<T> functor(
math::CosSimGradFunctor<T> functor(
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
......@@ -241,7 +102,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
}
} else {
if (out_grad_x) {
CosSimDxFunctor<T> functor(
math::CosSimDxFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
......@@ -256,7 +117,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
CosSimDyFunctor<DeviceContext, T> functor;
math::CosSimDyFunctor<DeviceContext, T> functor;
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
......
......@@ -16,6 +16,7 @@ if(WITH_GPU)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -30,6 +31,7 @@ else()
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t row_id = 0; row_id < rows; ++row_id) {
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
auto dz_data = dz[row_id];
auto z_data = z[row_id];
auto* x_data = x + cols * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
};
template class CosSimDyFunctor<platform::CPUDeviceContext, float>;
template class CosSimDyFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows;
row_id += grid_size) {
T xy_norm_prod = x_norm[row_id] * y_norm_data;
T dz_data = dz[row_id];
T z_data = z[row_id];
const T* x_data = x + cols * row_id;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid(1, (rows + block_size - 1) / block_size);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};
template class CosSimDyFunctor<platform::CUDADeviceContext, float>;
template class CosSimDyFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <stdlib.h>
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[row_id] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
} else { // This can be wrote in a better way.
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y_[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
}
}
T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
template <typename T>
struct CosSimGradFunctor {
CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto* x = x_ + cols_ * row_id;
auto* y = y_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* x = x_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm,
const T* x, const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) const;
};
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册